
[Feature] adding tensor classes annotation for loss functions #1905

Open

wants to merge 28 commits into base: main
Conversation

@SandishKumarHN commented Feb 13, 2024

Description

Follow-up from this pull request. Copy-pasted from there:

We plan on using https://github.com/Tensorclass to represent losses.

The advantage of a tensorclass for losses, instead of a tensordict, is that it lets us use all the features of tensordict while preserving type annotations and even enabling completion.

Changes:

  • Check the out_keys of the loss;
  • Create a tensorclass with the respective fields;
  • Type the forward as returning that class (and/or a tensordict);
  • Add a return_tensorclass argument to the constructor, False by default;
  • Update the docstrings (not done);
  • Write a small test to check that things work as expected (this test should be new and not parametrized; if we add one more parameter to the existing tests, the code will be much longer and harder to follow, and the tests will take a long time to run).
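For context, a minimal sketch of what the proposed return type could look like. A plain dataclass stands in here for tensordict's @tensorclass decorator, and the field names are illustrative (taken from the A2C out_keys), not the final API:

```python
from dataclasses import dataclass
from typing import Optional

import torch


# Stand-in for tensordict's @tensorclass; the real implementation would
# decorate the class with @tensorclass instead. Field names are illustrative.
@dataclass
class A2CLosses:
    loss_objective: torch.Tensor
    loss_critic: torch.Tensor
    loss_entropy: Optional[torch.Tensor] = None
    entropy: Optional[torch.Tensor] = None


losses = A2CLosses(
    loss_objective=torch.tensor(1.0),
    loss_critic=torch.tensor(0.5),
)
# Typed attribute access (with IDE completion) instead of string keys:
print(float(losses.loss_objective))  # 1.0
```

The point of the feature is exactly this kind of attribute access: `losses.loss_critic` instead of `td_out["loss_critic"]`, with the type checker aware of which fields exist.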

Types of changes

What types of changes does your code introduce? Remove all that do not apply:

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds core functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)
  • Documentation (update in the documentation)
  • Example (update in the folder of examples)

Checklist

Go over all the following points, and put an x in all the boxes that apply.
If you are unsure about any of these, don't hesitate to ask. We are here to help!

  • I have read the CONTRIBUTION guide (required)
  • My change requires a change to the documentation.
  • I have updated the tests accordingly (required for a bug fix or a new feature).
  • I have updated the documentation accordingly.

pytorch-bot bot commented Feb 13, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/rl/1905

Note: Links to docs will display an error until the docs builds have been completed.

❌ 14 New Failures

As of commit 9b5f4e6 with merge base 87f3437:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Feb 13, 2024
@SandishKumarHN force-pushed the tensorclass-losses branch 3 times, most recently from 3c3bc29 to 0e993e4 on February 13, 2024 04:34
@vmoens added the enhancement label Feb 13, 2024
@vmoens vmoens changed the title adding tensor classes annotation for loss functions [Feature] adding tensor classes annotation for loss functions Feb 13, 2024
@vmoens (Contributor) left a comment

Looks promising!
I'd suggest adding the args to the docstrings, and making a single test for each loss checking that we can get the tensorclass and that we can access the losses if return_tensorclass is True.

@@ -234,6 +250,7 @@ def __init__(
functional: bool = True,
actor: ProbabilisticTensorDictSequential = None,
critic: ProbabilisticTensorDictSequential = None,
return_tensorclass: bool = False,
Contributor

Should be added to the docstrings

Author

working on it.

Author

@vmoens I added doctests for the tensorclass changes, but I see some doctest issues and blockers. Can you please help me resolve them?

  • There are some existing doctest failures; we might need a separate task to address them.
  • What should the aggregate_loss be for each loss within the tensorclass?
  • There are some existing errors like:
  1. ```Cannot interpret 'torch.int64' as a data type```
  2. ```'key "action_value" not found in TensorDict with keys [\'done\', \'logits\', \'observation\', \'reward\', \'state_value\', \'terminated\']'```
  3. ```NameError: name 'actor' is not defined```
  4. etc.

@SandishKumarHN (Author)

@vmoens can you review once? The build errors are resource issues not related to the PR.

@vmoens (Contributor) left a comment

Thanks!

New classes should be manually added to the doc
See docs/source/reference/objectives.rst

The feature seems untested to me, am I right?

Comment on lines 494 to 498

loss_function="l2",

delay_value=delay_value,

Contributor

not sure why we need these

@@ -33,6 +35,34 @@
)


class LossContainerBase:
"""ContainerBase class loss tensorclass's."""
Contributor

That isn't very explicit. We should say what this class is about.

Also I think it should live in the common.py file.

Contributor

I'm also wondering if we should not just make the base a tensorclass and inherit from it without creating new tensorclasses?

Author

If I try to make the base a tensorclass, I get the error below.

**********************************************************************
File "/home/sandish/rl/torchrl/objectives/a2c.py", line 144, in a2c.A2CLoss
Failed example:
    loss(data)
Exception raised:
    Traceback (most recent call last):
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/doctest.py", line 1334, in __run
        exec(compile(example.source, filename, "single",
      File "<doctest a2c.A2CLoss[21]>", line 1, in <module>
        loss(data)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
        return self._call_impl(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1561, in _call_impl
        result = forward_call(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/_contextlib.py", line 126, in decorate_context
        return func(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/nn/common.py", line 291, in wrapper
        return func(_self, tensordict, *args, **kwargs)
      File "/home/sandish/rl/torchrl/objectives/a2c.py", line 503, in forward
        return A2CLosses._from_tensordict(td_out)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/tensorclass.py", line 327, in wrapper
        raise ValueError(
    ValueError: Keys from the tensordict ({'loss_entropy', 'loss_objective', 'entropy', 'loss_critic'}) must correspond to the class attributes (set()).

Comment on lines 61 to 63
@property
def aggregate_loss(self):
return self.loss_critic + self.loss_objective + self.loss_entropy
Contributor

No need to recode this

Comment on lines 43 to 48
def aggregate_loss(self):
result = 0.0
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result
Contributor

Should be a property
Should always return a tensor
Something like

result = torch.zeros((), device=self.device)
...
return result
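Fleshing out that suggestion as a self-contained sketch: the mixin and class names here (LossContainerMixin, DemoLosses) are hypothetical, and a plain class stands in for the tensorclass, whose device attribute and generated field list the real code would use instead of vars():

```python
import torch


class LossContainerMixin:
    @property
    def aggregate_loss(self) -> torch.Tensor:
        """Sum of all "loss_*" attributes as a single scalar tensor."""
        # In the real tensorclass, self.device would come from tensordict's
        # machinery; getattr with a default keeps this sketch standalone.
        result = torch.zeros((), device=getattr(self, "device", None))
        for key in vars(self):
            if key.startswith("loss_"):
                result = result + getattr(self, key)
        return result


class DemoLosses(LossContainerMixin):
    def __init__(self):
        self.loss_actor = torch.tensor(2.0)
        self.loss_qvalue = torch.tensor(3.0)


print(float(DemoLosses().aggregate_loss))  # 5.0
```

Seeding the sum with a zero tensor (rather than the float 0.0) guarantees the property always returns a tensor on the right device, which is what the review asks for.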

Contributor

Missing docstring for this method.

@@ -455,7 +497,7 @@ def _cached_detach_critic_network_params(self):
return self.critic_network_params.detach()

@dispatch()
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
Contributor

Suggested change
def forward(self, tensordict: TensorDictBase) -> A2CLosses:
def forward(self, tensordict: TensorDictBase) -> A2CLosses | TensorDictBase:
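The suggested union return type pairs with the return_tensorclass flag added in the constructor. A minimal sketch of that dispatch, using a plain dict as a stand-in for TensorDict and hypothetical names (TypedLosses, LossModule; _from_tensordict mirrors tensorclass's conversion method):

```python
from typing import Any, Dict, Union


class TypedLosses:
    """Hypothetical stand-in for a @tensorclass loss container."""

    def __init__(self, data: Dict[str, Any]):
        self.data = data

    @classmethod
    def _from_tensordict(cls, td: Dict[str, Any]) -> "TypedLosses":
        # tensorclass exposes a similar constructor for TensorDict inputs.
        return cls(dict(td))


class LossModule:
    def __init__(self, return_tensorclass: bool = False):
        # False by default, so existing callers keep receiving a TensorDict.
        self.return_tensorclass = return_tensorclass

    def forward(self, tensordict: Dict[str, Any]) -> Union[TypedLosses, Dict[str, Any]]:
        td_out = {"loss_objective": 1.0, "loss_critic": 0.5}  # computed losses
        if self.return_tensorclass:
            return TypedLosses._from_tensordict(td_out)
        return td_out


out = LossModule(return_tensorclass=True).forward({})
print(type(out).__name__)  # TypedLosses
```

Keeping the tensordict path as the default preserves backward compatibility, which is why the reviewer asks for the union annotation rather than the tensorclass alone.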

Comment on lines 49 to 71
class LossContainerBase:
"""ContainerBase class loss tensorclass's."""

__getitem__ = TensorDictBase.__getitem__

def aggregate_loss(self):
result = 0.0
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result


@tensorclass
class SACLosses(LossContainerBase):
"""The tensorclass for The SACLoss Loss class."""

loss_actor: torch.Tensor
loss_value: torch.Tensor
loss_qvalue: torch.Tensor
alpha: torch.Tensor | None = None
loss_alpha: torch.Tensor | None = None
entropy: torch.Tensor | None = None
Contributor

ditto

@@ -541,7 +581,7 @@ def out_keys(self, values):
self._out_keys = values

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> SACLosses:
Contributor

Suggested change
def forward(self, tensordict: TensorDictBase) -> SACLosses:
def forward(self, tensordict: TensorDictBase) -> SACLosses | TensorDictBase:

Comment on lines 618 to 619
out["loss_value"] = loss_value.mean()
return TensorDict(out, [])
Contributor

Why this change?

Comment on lines 32 to 54
class LossContainerBase:
"""ContainerBase class loss tensorclass's."""

__getitem__ = TensorDictBase.__getitem__

def aggregate_loss(self):
result = 0.0
for key in self.__dataclass_attr__:
if key.startswith("loss_"):
result += getattr(self, key)
return result


@tensorclass
class TD3Losses(LossContainerBase):
"""The tensorclass for The TD3 Loss class."""

loss_actor: torch.Tensor
loss_qvalue: torch.Tensor
target_value: torch.Tensor | None = None
state_action_value_actor: torch.Tensor | None = None
pred_value: torch.Tensor | None = None
next_state_value: torch.Tensor | None = None
Contributor

ditto

@@ -453,7 +492,7 @@ def value_loss(self, tensordict):
return loss_qval, metadata

@dispatch
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
def forward(self, tensordict: TensorDictBase) -> TD3Losses:
Contributor

Suggested change
def forward(self, tensordict: TensorDictBase) -> TD3Losses:
def forward(self, tensordict: TensorDictBase) -> TD3Losses | TensorDictBase:


@SandishKumarHN (Author) commented Feb 29, 2024

@vmoens I addressed most of your comments above, but doctests are failing with the errors below, which are not caused by this PR's changes.

File "/home/sandish/rl/torchrl/objectives/cql.py", line 128, in cql.CQLLoss
Failed example:
    loss = CQLLoss(actor, qvalue)
Exception raised:
    Traceback (most recent call last):
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/doctest.py", line 1334, in __run
        exec(compile(example.source, filename, "single",
      File "<doctest cql.CQLLoss[16]>", line 1, in <module>
        loss = CQLLoss(actor, qvalue)
      File "/home/sandish/rl/torchrl/objectives/cql.py", line 321, in __init__
        self.convert_to_functional(
      File "/home/sandish/rl/torchrl/objectives/common.py", line 289, in convert_to_functional
        params.apply(
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/nn/params.py", line 125, in new_func
        out = meth(*args, **kwargs)
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/base.py", line 3824, in apply
        return self._apply_nest(
      File "/home/sandish/.conda/envs/torch_rl/lib/python3.9/site-packages/tensordict/_td.py", line 659, in _apply_nest
        out = TensorDict(
    TypeError: __init__() got an unexpected keyword argument 'filter_empty'
  File "/pytorch/rl/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1704, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'DDPGLoss' object has no attribute 'reduction'

@vmoens (Contributor) left a comment

Great progress!
I'd like to brainstorm about the naming of those classes: A2CLosses seems a bit awkward when we have A2CLoss, which is a totally different thing.
I don't think we should rename "loss" to "loss_objective" as part of this PR.

torchrl/objectives/a2c.py (outdated)
@@ -300,6 +322,8 @@ def __init__(
if gamma is not None:
raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR)
self.loss_critic_type = loss_critic_type
self.return_tensorclass = return_tensorclass
self.reduction = reduction
Contributor

duplicate

torchrl/objectives/common.py (outdated)
@@ -32,6 +34,15 @@
VTrace,
)

@tensorclass
Contributor

doesn't it work if we make the base class a tensorclass?

Author

No, it doesn't work.

@@ -298,7 +309,7 @@ def _compare_and_expand(param):
# set the functional module: we need to convert the params to non-differentiable params
# otherwise they will appear twice in parameters
with params.apply(
self._make_meta_params, device=torch.device("meta"), filter_empty=False
self._make_meta_params, device=torch.device("meta")
Contributor

why is this removed?

torchrl/objectives/common.py (outdated)
Comment on lines 41 to 42
loss_objective: torch.Tensor
loss: torch.Tensor
Contributor

why two?

torchrl/objectives/dqn.py (outdated)
@SandishKumarHN (Author)

@vmoens I made changes based on your review. reduction is still not added to the test_cost.py file, so all of the failures are related to that.

@SandishKumarHN force-pushed the tensorclass-losses branch 2 times, most recently from ad2b6c3 to 8b5e0ff on March 18, 2024 16:12
Labels
CLA Signed: This label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed.
enhancement: New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants