diff --git a/test/test_cost.py b/test/test_cost.py
index 065f9dff946..660b0b5b491 100644
--- a/test/test_cost.py
+++ b/test/test_cost.py
@@ -4217,9 +4217,6 @@ def test_discrete_sac_reduction(self, reduction):
                 assert loss[key].shape == torch.Size([])
 
 
-@pytest.mark.skipif(
-    not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
-)
 class TestCrossQ(LossModuleTestBase):
     seed = 0
 
diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 9cffa28f4a4..8d28dc0e0b1 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -46,10 +46,15 @@ class CrossQLoss(LossModule):
     Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING FOR GREATER
     SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX
 
+    This class has three loss functions that will be called sequentially by the `forward` method:
+    :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can
+    be called by the user in that order.
+
     Args:
         actor_network (ProbabilisticActor): stochastic actor
         qvalue_network (TensorDictModule): Q(s, a) parametric model.
             This module typically outputs a ``"state_action_value"`` entry.
+
     Keyword Args:
         num_qvalue_nets (integer, optional): number of Q-Value networks used.
             Defaults to ``2``.
@@ -331,6 +336,10 @@ def __init__(
 
     @property
     def target_entropy_buffer(self):
+        """The target entropy.
+
+        This value can be controlled via the `target_entropy` kwarg in the constructor.
+        """
         return self.target_entropy
 
     @property
@@ -467,6 +476,13 @@ def out_keys(self, values):
 
     @dispatch
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
+        """The forward method.
+
+        Computes successively the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and returns
+        a tensordict with these values along with the `"alpha"` value and the `"entropy"` value (detached).
+        To see what keys are expected in the input tensordict and what keys are expected as output, check the
+        class's `"in_keys"` and `"out_keys"` attributes.
+        """
         shape = None
         if tensordict.ndimension() > 1:
             shape = tensordict.shape
@@ -511,7 +527,17 @@ def _cached_detached_qvalue_params(self):
     def actor_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
-        """Compute the actor loss."""
+        """Compute the actor loss.
+
+
+        The actor loss should be computed after the :meth:`~.qvalue_loss` and before the :meth:`~.alpha_loss`, which requires the `log_prob` field of the `metadata` returned by this method.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+
+        Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"log_prob"` of the sampled action.
+        """
         with set_exploration_type(
             ExplorationType.RANDOM
         ), self.actor_network_params.to_module(self.actor_network):
@@ -540,7 +566,16 @@ def actor_loss(
     def qvalue_loss(
         self, tensordict: TensorDictBase
     ) -> Tuple[Tensor, Dict[str, Tensor]]:
-        """Compute the CrossQ-value loss."""
+        """Compute the q-value loss.
+
+        The q-value loss should be computed before the :meth:`~.actor_loss`.
+
+        Args:
+            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields
+                are required for this to be computed.
+
+        Returns: a differentiable tensor with the q-value loss along with a metadata dictionary containing the detached `"td_error"` to be used for prioritized sampling.
+        """
         # # compute next action
         with torch.no_grad():
             with set_exploration_type(
@@ -594,7 +629,15 @@ def qvalue_loss(
         return loss_qval, metadata
 
     def alpha_loss(self, log_prob: Tensor) -> Tensor:
-        """Compute the entropy loss."""
+        """Compute the entropy loss.
+
+        The entropy loss should be computed last.
+
+        Args:
+            log_prob (Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned in the `metadata`.
+
+        Returns: a differentiable tensor with the entropy loss.
+        """
         if self.target_entropy is not None:
             # we can compute this loss even if log_alpha is not a parameter
             alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)