Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 25, 2025
2 parents 55934fb + 6e83807 commit 1024d61
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 13 deletions.
5 changes: 4 additions & 1 deletion test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -8409,7 +8409,6 @@ def test_ppo_composite_no_aggregate(
if isinstance(loss_fn, KLPENPPOLoss):
kl = loss.pop("kl_approx")
assert (kl != 0).any()

loss_critic = loss["loss_critic"]
loss_objective = loss["loss_objective"] + loss.get("loss_entropy", 0.0)
loss_critic.backward(retain_graph=True)
Expand Down Expand Up @@ -8637,12 +8636,16 @@ def test_ppo_shared_seq(
)

loss = loss_fn(td).exclude("entropy")
if composite_action_dist:
loss = loss.exclude("composite_entropy")

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
grad = TensorDict(dict(model.named_parameters()), []).apply(
lambda x: x.grad.clone()
)
loss2 = loss_fn2(td).exclude("entropy")
if composite_action_dist:
loss2 = loss2.exclude("composite_entropy")

model.zero_grad()
sum(val for key, val in loss2.items() if key.startswith("loss_")).backward()
Expand Down
79 changes: 79 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,6 +689,85 @@ def test_single_trans_env_check(self, out_keys):
)
check_env_specs(env)

@pytest.mark.parametrize("cat_dim", [-1, -2, -3])
@pytest.mark.parametrize("cat_N", [3, 10])
@pytest.mark.parametrize("device", get_default_devices())
def test_with_permute_no_env(self, cat_dim, cat_N, device):
torch.manual_seed(cat_dim * cat_N)
pixels = torch.randn(8, 5, 3, 10, 4, device=device)

a = TensorDict(
{
"pixels": pixels,
},
[
pixels.shape[0],
],
device=device,
)

t0 = Compose(
CatFrames(N=cat_N, dim=cat_dim),
)

def get_rand_perm(ndim):
cat_dim_perm = cat_dim
# Ensure that the permutation moves the cat_dim
while cat_dim_perm == cat_dim:
perm_pos = torch.randperm(ndim)
perm = perm_pos - ndim
cat_dim_perm = (perm == cat_dim).nonzero().item() - ndim
perm_inv = perm_pos.argsort() - ndim
return perm.tolist(), perm_inv.tolist(), cat_dim_perm

perm, perm_inv, cat_dim_perm = get_rand_perm(pixels.dim() - 1)

t1 = Compose(
PermuteTransform(perm, in_keys=["pixels"]),
CatFrames(N=cat_N, dim=cat_dim_perm),
PermuteTransform(perm_inv, in_keys=["pixels"]),
)

b = t0._call(a.clone())
c = t1._call(a.clone())
assert (b == c).all()

@pytest.mark.skipif(not _has_gym, reason="Test executed on gym")
@pytest.mark.parametrize("cat_dim", [-1, -2])
def test_with_permute_env(self, cat_dim):
env0 = TransformedEnv(
GymEnv("Pendulum-v1"),
Compose(
UnsqueezeTransform(-1, in_keys=["observation"]),
CatFrames(N=4, dim=cat_dim, in_keys=["observation"]),
),
)

env1 = TransformedEnv(
GymEnv("Pendulum-v1"),
Compose(
UnsqueezeTransform(-1, in_keys=["observation"]),
PermuteTransform((-1, -2), in_keys=["observation"]),
CatFrames(N=4, dim=-3 - cat_dim, in_keys=["observation"]),
PermuteTransform((-1, -2), in_keys=["observation"]),
),
)

torch.manual_seed(0)
env0.set_seed(0)
td0 = env0.reset()

torch.manual_seed(0)
env1.set_seed(0)
td1 = env1.reset()

assert (td0 == td1).all()

td0 = env0.step(td0.update(env0.full_action_spec.rand()))
td1 = env0.step(td0.update(env1.full_action_spec.rand()))

assert (td0 == td1).all()

def test_serial_trans_env_check(self):
env = SerialEnv(
2,
Expand Down
34 changes: 22 additions & 12 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,7 +494,7 @@ def _forward_value_estimator_keys(self, **kwargs) -> None:
def reset(self) -> None:
pass

def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
def _get_entropy(self, dist: d.Distribution) -> torch.Tensor | TensorDict:
try:
entropy = dist.entropy()
except NotImplementedError:
Expand All @@ -513,13 +513,11 @@ def get_entropy_bonus(self, dist: d.Distribution) -> torch.Tensor:
log_prob = log_prob.select(*self.tensor_keys.sample_log_prob)

entropy = -log_prob.mean(0)
if is_tensor_collection(entropy):
entropy = _sum_td_features(entropy)
return entropy.unsqueeze(-1)

def _log_weight(
self, tensordict: TensorDictBase
) -> Tuple[torch.Tensor, d.Distribution]:
) -> Tuple[torch.Tensor, d.Distribution, torch.Tensor]:

with self.actor_network_params.to_module(
self.actor_network
Expand Down Expand Up @@ -681,10 +679,14 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
log_weight = log_weight.view(advantage.shape)
neg_loss = log_weight.exp() * advantage
td_out = TensorDict({"loss_objective": -neg_loss}, batch_size=[])
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
Expand Down Expand Up @@ -956,8 +958,8 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
# ESS for logging
with torch.no_grad():
# In theory, ESS should be computed on particles sampled from the same source. Here we sample according
# to different, unrelated trajectories, which is not standard. Still it can give a idea of the dispersion
# of the weights.
# to different, unrelated trajectories, which is not standard. Still, it can give an idea of the weights'
# dispersion.
lw = log_weight.squeeze()
if not isinstance(lw, torch.Tensor):
lw = _sum_td_features(lw)
Expand All @@ -976,11 +978,15 @@ def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
gain = _sum_td_features(gain)
td_out = TensorDict({"loss_objective": -gain}, batch_size=[])
td_out.set("clip_fraction", clip_fraction)
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict)
Expand Down Expand Up @@ -1282,14 +1288,18 @@ def forward(self, tensordict: TensorDictBase) -> TensorDict:
{
"loss_objective": -neg_loss,
"kl": kl.detach(),
"kl_approx": kl_approx.detach().mean(),
},
batch_size=[],
)

if self.entropy_bonus:
entropy = self.get_entropy_bonus(dist)
entropy = self._get_entropy(dist)
if is_tensor_collection(entropy):
# Reports the entropy of each action head.
td_out.set("composite_entropy", entropy.detach())
entropy = _sum_td_features(entropy)
td_out.set("entropy", entropy.detach().mean()) # for logging
td_out.set("kl_approx", kl_approx.detach().mean()) # for logging
td_out.set("loss_entropy", -self.entropy_coef * entropy)
if self.critic_coef is not None:
loss_critic, value_clip_fraction = self.loss_critic(tensordict_copy)
Expand Down

0 comments on commit 1024d61

Please sign in to comment.