
Commit

Apply suggestions from code review
vmoens authored Jul 9, 2024
1 parent d3c8b0e commit d3e0bb1
Showing 2 changed files with 46 additions and 6 deletions.
test/test_cost.py (3 changes: 0 additions & 3 deletions)

@@ -4217,9 +4217,6 @@ def test_discrete_sac_reduction(self, reduction):
assert loss[key].shape == torch.Size([])


@pytest.mark.skipif(
    not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}"
)
class TestCrossQ(LossModuleTestBase):
    seed = 0

torchrl/objectives/crossq.py (49 changes: 46 additions & 3 deletions)

@@ -46,10 +46,15 @@ class CrossQLoss(LossModule):
    Presented in "CROSSQ: BATCH NORMALIZATION IN DEEP REINFORCEMENT LEARNING
    FOR GREATER SAMPLE EFFICIENCY AND SIMPLICITY" https://openreview.net/pdf?id=PczQtTsTIX

    This class has three loss functions that will be called sequentially by the `forward` method:
    :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`. Alternatively, they can
    be called by the user in that order.

    Args:
        actor_network (ProbabilisticActor): stochastic actor
        qvalue_network (TensorDictModule): Q(s, a) parametric model.
            This module typically outputs a ``"state_action_value"`` entry.

    Keyword Args:
        num_qvalue_nets (integer, optional): number of Q-Value networks used.
            Defaults to ``2``.
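As an aid to reading the diff, here is a minimal construction sketch matching the signature described above. It is a hedged sketch, not the module's documented example: the network sizes, key names and the `n_obs`/`n_act` values are assumptions chosen only for illustration.

```python
import torch
from tensordict.nn import NormalParamExtractor, TensorDictModule
from torchrl.modules import MLP, ProbabilisticActor, TanhNormal, ValueOperator
from torchrl.objectives.crossq import CrossQLoss

n_obs, n_act = 4, 2  # assumed dimensions, for illustration only

# Stochastic actor: observation -> (loc, scale) -> TanhNormal policy.
actor = ProbabilisticActor(
    module=TensorDictModule(
        torch.nn.Sequential(
            MLP(in_features=n_obs, out_features=2 * n_act),
            NormalParamExtractor(),
        ),
        in_keys=["observation"],
        out_keys=["loc", "scale"],
    ),
    in_keys=["loc", "scale"],
    distribution_class=TanhNormal,
)

# Q(s, a) model: writes a "state_action_value" entry, as the Args above expect.
qvalue = ValueOperator(
    module=MLP(in_features=n_obs + n_act, out_features=1),
    in_keys=["observation", "action"],
)

loss_module = CrossQLoss(actor_network=actor, qvalue_network=qvalue, num_qvalue_nets=2)
```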
@@ -331,6 +336,10 @@ def __init__(

    @property
    def target_entropy_buffer(self):
        """The target entropy.

        This value can be controlled via the `target_entropy` kwarg in the constructor.
        """
        return self.target_entropy

    @property
@@ -467,6 +476,13 @@ def out_keys(self, values):

    @dispatch
    def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
        """The forward method.

        Successively computes the :meth:`~.qvalue_loss`, :meth:`~.actor_loss` and :meth:`~.alpha_loss`, and
        returns a tensordict with these values along with the `"alpha"` value and the (detached) `"entropy"` value.

        To see which keys are expected in the input tensordict and which keys are written as output, check the
        class's `"in_keys"` and `"out_keys"` attributes.
        """
        shape = None
        if tensordict.ndimension() > 1:
            shape = tensordict.shape
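A hedged usage sketch of the `forward` path this docstring describes; `replay_buffer`, `optimizer` and the batch size are assumptions, and the exact output keys should be checked against `loss_module.out_keys`.

```python
# Hedged sketch: one gradient step driven by forward().
data = replay_buffer.sample(256)     # a TensorDict carrying the module's in_keys (assumed buffer)
loss_td = loss_module(data)          # dispatches to forward(): qvalue, actor, then alpha loss

# Sum every differentiable "loss_*" entry; "alpha" and "entropy" are detached diagnostics.
loss = sum(loss_td[key] for key in loss_td.keys() if key.startswith("loss_"))
loss.backward()
optimizer.step()
optimizer.zero_grad()
```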
@@ -511,7 +527,17 @@ def _cached_detached_qvalue_params(self):
    def actor_loss(
        self, tensordict: TensorDictBase
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the actor loss."""
        """Compute the actor loss.

        The actor loss should be computed after the :meth:`~.qvalue_loss` and before the
        :meth:`~.alpha_loss`, which requires the `log_prob` field of the `metadata` returned by this method.

        Args:
            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what
                fields are required for this to be computed.

        Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the
            detached `"log_prob"` of the sampled action.
        """
        with set_exploration_type(
            ExplorationType.RANDOM
        ), self.actor_network_params.to_module(self.actor_network):
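To make the documented ordering concrete, a hedged sketch of driving the three losses by hand, reusing the assumed `loss_module` and sampled `data` from the sketches above:

```python
# Hedged sketch of the manual call order: q-value loss first, then the actor loss,
# whose metadata carries the detached log-probability consumed by the alpha loss.
loss_qvalue, q_metadata = loss_module.qvalue_loss(data)
loss_actor, a_metadata = loss_module.actor_loss(data)
loss_alpha = loss_module.alpha_loss(a_metadata["log_prob"])

# Reduce before backward in case the individual terms are returned per-sample.
total = (loss_qvalue + loss_actor + loss_alpha).mean()
total.backward()
```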
@@ -540,7 +566,16 @@ def actor_loss(
    def qvalue_loss(
        self, tensordict: TensorDictBase
    ) -> Tuple[Tensor, Dict[str, Tensor]]:
        """Compute the CrossQ-value loss."""
        """Compute the q-value loss.

        The q-value loss should be computed before the :meth:`~.actor_loss`.

        Args:
            tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what
                fields are required for this to be computed.

        Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing the
            detached `"td_error"` to be used for prioritized sampling.
        """
        # compute next action
        with torch.no_grad():
            with set_exploration_type(
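Since the metadata carries a detached `"td_error"`, one plausible follow-up is refreshing priorities in a prioritized buffer. A hedged sketch, assuming `replay_buffer` is a `TensorDictPrioritizedReplayBuffer` whose priority key is `"td_error"`:

```python
# Hedged sketch: write the detached TD error back into the sampled batch and
# let the prioritized buffer recompute its sampling weights from it.
loss_qvalue, metadata = loss_module.qvalue_loss(data)
data.set("td_error", metadata["td_error"])
replay_buffer.update_tensordict_priority(data)  # assumes a TensorDictPrioritizedReplayBuffer
```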
@@ -594,7 +629,15 @@ def qvalue_loss(
        return loss_qval, metadata

    def alpha_loss(self, log_prob: Tensor) -> Tensor:
        """Compute the entropy loss."""
        """Compute the entropy loss.

        The entropy loss should be computed last.

        Args:
            log_prob (Tensor): a log-probability as computed by the :meth:`~.actor_loss` and returned
                in the `metadata`.

        Returns: a differentiable tensor with the entropy loss.
        """
        if self.target_entropy is not None:
            # we can compute this loss even if log_alpha is not a parameter
            alpha_loss = -self.log_alpha * (log_prob + self.target_entropy)
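A hedged sketch of steering this term via the constructor's `target_entropy` kwarg (mentioned in the `target_entropy_buffer` docstring above); the `-|A|` value and the reuse of `a_metadata` from the earlier sketch are assumptions:

```python
# Hedged sketch: pin the entropy target explicitly instead of the default.
loss_module = CrossQLoss(
    actor_network=actor,
    qvalue_network=qvalue,
    target_entropy=-float(n_act),  # assumed: the usual -|A| heuristic
)
# log_prob taken from the actor_loss metadata, as the Args above describe.
loss_alpha = loss_module.alpha_loss(log_prob=a_metadata["log_prob"])
```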
