diff --git a/torchrl/objectives/crossq.py b/torchrl/objectives/crossq.py
index 2372b9e3163..b79c032ac9e 100644
--- a/torchrl/objectives/crossq.py
+++ b/torchrl/objectives/crossq.py
@@ -521,12 +521,15 @@ def actor_loss(
         log_prob = dist.log_prob(a_reparm)
 
         td_q = tensordict.select(*self.qvalue_network.in_keys, strict=False)
+        self.qvalue_network.eval()
         td_q.set(self.tensor_keys.action, a_reparm)
         td_q = self._vmap_qnetworkN0(
             td_q,
             self._cached_detached_qvalue_params,
         )
+
         min_q = td_q.get(self.tensor_keys.state_action_value).min(0)[0].squeeze(-1)
+        self.qvalue_network.train()
 
         if log_prob.shape != min_q.shape:
             raise RuntimeError(
@@ -550,17 +553,6 @@ def qvalue_loss(
         next_tensordict.set(self.tensor_keys.action, next_action)
         next_sample_log_prob = next_dist.log_prob(next_action)
 
-        # TODO: separate forward pass seems faster than the combined.
-        # next_state_action_value = self._vmap_qnetworkN0(
-        #     next_tensordict.select(*self.qvalue_network.in_keys, strict=False),
-        #     self.qvalue_network_params,
-        # ).get(self.tensor_keys.state_action_value)
-
-        # current_state_action_value = self._vmap_qnetworkN0(
-        #     tensordict.select(*self.qvalue_network.in_keys, strict=False),
-        #     self.qvalue_network_params,
-        # ).get(self.tensor_keys.state_action_value)
-
         combined = torch.cat(
             [
                 tensordict.select(*self.qvalue_network.in_keys, strict=False),
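
For context: the first hunk wraps the actor-loss Q-value forward pass in `eval()` / `train()` calls, so that BatchNorm layers inside the Q-networks use running statistics rather than batch statistics when scoring the reparameterized action, and training mode is restored afterwards for the critic update. Below is a minimal standalone sketch of that pattern; it is not the TorchRL implementation, and the module layout, shapes, and the `actor_loss` helper are illustrative assumptions only.

```python
# Sketch only: CrossQ-style actor loss with BatchNorm critics toggled to
# eval mode for the policy forward pass, then back to train mode.
# All names and shapes here are assumptions, not TorchRL's API.
import torch
import torch.nn as nn

# Hypothetical ensemble of two Q-networks that contain BatchNorm.
qnets = nn.ModuleList(
    nn.Sequential(nn.Linear(6, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 1))
    for _ in range(2)
)


def actor_loss(obs, action, log_prob, alpha=0.2):
    # eval(): BatchNorm uses running statistics, so the actor update does
    # not perturb the statistics accumulated during critic training.
    qnets.eval()
    sa = torch.cat([obs, action], dim=-1)
    q_values = torch.stack([q(sa) for q in qnets], dim=0)  # [n_qnets, batch, 1]
    min_q = q_values.min(dim=0).values.squeeze(-1)         # [batch]
    qnets.train()  # restore training mode for the next critic update
    # SAC-style objective; the real loss additionally uses detached critic
    # parameters so that only the actor receives gradients.
    return (alpha * log_prob - min_q).mean()


# Toy usage with random tensors.
obs, action = torch.randn(8, 4), torch.tanh(torch.randn(8, 2))
log_prob = torch.randn(8)
actor_loss(obs, action, log_prob).backward()
```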