Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
vmoens committed Jan 10, 2025
1 parent 86ab9b7 commit 14e639d
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 25 deletions.
6 changes: 6 additions & 0 deletions examples/agents/composite_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,9 @@ def forward(self, x):
data = TensorDict({"x": torch.rand(10)}, [])
module(data)
print(actor(data))


# TODO:
# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
# 2. Must multi-head require an action_key to be a list of keys (I guess so)
# 3. Using maps in the Actor
60 changes: 52 additions & 8 deletions test/test_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -7908,11 +7908,11 @@ def _create_mock_actor(
obs_dim=3,
action_dim=4,
device="cpu",
action_key="action",
action_key=None,
observation_key="observation",
sample_log_prob_key="sample_log_prob",
composite_action_dist=False,
aggregate_probabilities=True,
aggregate_probabilities=None,
):
# Actor
action_spec = Bounded(
Expand All @@ -7922,13 +7922,17 @@ def _create_mock_actor(
action_spec = Composite({action_key: {"action1": action_spec}})
net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
if composite_action_dist:
if action_key is None:
action_key = ("action", "action1")
else:
action_key = (action_key, "action1")
distribution_class = functools.partial(
CompositeDistribution,
distribution_map={
"action1": TanhNormal,
},
name_map={
"action1": (action_key, "action1"),
"action1": action_key,
},
log_prob_key=sample_log_prob_key,
aggregate_probabilities=aggregate_probabilities,
Expand All @@ -7939,6 +7943,8 @@ def _create_mock_actor(
]
actor_in_keys = ["params"]
else:
if action_key is None:
action_key = "action"
distribution_class = TanhNormal
module_out_keys = actor_in_keys = ["loc", "scale"]
module = TensorDictModule(
Expand Down Expand Up @@ -8149,8 +8155,8 @@ def _create_seq_mock_data_ppo(
action_dim=4,
atoms=None,
device="cpu",
sample_log_prob_key="sample_log_prob",
action_key="action",
sample_log_prob_key=None,
action_key=None,
composite_action_dist=False,
):
# create a tensordict
Expand All @@ -8172,6 +8178,17 @@ def _create_seq_mock_data_ppo(
params_scale = torch.rand_like(action) / 10
loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
if sample_log_prob_key is None:
if composite_action_dist:
sample_log_prob_key = ("action", "action1_log_prob")
else:
sample_log_prob_key = "sample_log_prob"

if action_key is None:
if composite_action_dist:
action_key = ("action", "action1")
else:
action_key = "action"
td = TensorDict(
batch_size=(batch, T),
source={
Expand All @@ -8183,7 +8200,7 @@ def _create_seq_mock_data_ppo(
"reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
},
"collector": {"mask": mask},
action_key: {"action1": action} if composite_action_dist else action,
action_key: action,
sample_log_prob_key: (
torch.randn_like(action[..., 1]) / 10
).masked_fill_(~mask, 0.0),
Expand Down Expand Up @@ -8263,6 +8280,13 @@ def test_ppo(
loss_critic_type="l2",
functional=functional,
)
if composite_action_dist:
loss_fn.set_keys(
action=("action", "action1"),
sample_log_prob=[("action", "action1_log_prob")],
)
if advantage is not None:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
if advantage is not None:
advantage(td)
else:
Expand Down Expand Up @@ -8356,7 +8380,9 @@ def test_ppo_composite_no_aggregate(
loss_critic_type="l2",
functional=functional,
)
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
if advantage is not None:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)
else:
if td_est is not None:
Expand Down Expand Up @@ -8464,7 +8490,12 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
)

if advantage is not None:
if composite_action_dist:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)

if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
Expand Down Expand Up @@ -8571,7 +8602,14 @@ def test_ppo_shared_seq(
)

if advantage is not None:
if composite_action_dist:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)

if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])

loss = loss_fn(td).exclude("entropy")

sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
Expand Down Expand Up @@ -8659,7 +8697,11 @@ def zero_param(p):
# assert len(list(floss_fn.parameters())) == 0
with params.to_module(loss_fn):
if advantage is not None:
if composite_action_dist:
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
advantage(td)
if composite_action_dist:
loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
loss = loss_fn(td)

loss_critic = loss["loss_critic"]
Expand Down Expand Up @@ -8760,8 +8802,8 @@ def test_ppo_tensordict_keys_run(
"advantage": "advantage_test",
"value_target": "value_target_test",
"value": "state_value_test",
"sample_log_prob": "sample_log_prob_test",
"action": "action_test",
"sample_log_prob": ('action_test', 'action1_log_prob') if composite_action_dist else "sample_log_prob_test",
"action": ("action_test", "action") if composite_action_dist else "action_test",
}

td = self._create_seq_mock_data_ppo(
Expand Down Expand Up @@ -8809,6 +8851,8 @@ def test_ppo_tensordict_keys_run(
raise NotImplementedError

loss_fn = loss_class(actor, value, loss_critic_type="l2")
if composite_action_dist:
tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
loss_fn.set_keys(**tensor_keys)
if advantage is not None:
# collect tensordict key names for the advantage module
Expand Down
45 changes: 28 additions & 17 deletions torchrl/objectives/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from __future__ import annotations

import contextlib
import warnings

from copy import deepcopy
from dataclasses import dataclass
Expand Down Expand Up @@ -531,26 +532,35 @@ def _log_weight(
raise RuntimeError(
f"tensordict stored {self.tensor_keys.action} requires grad."
)
if isinstance(action, torch.Tensor):
if isinstance(dist, CompositeDistribution):
is_composite = True
aggregate = dist.aggregate_probabilities
if aggregate is None:
aggregate = False
include_sum = dist.include_sum
if include_sum is None:
include_sum = False
kwargs = {
"inplace": False,
"aggregate_probabilities": aggregate,
"include_sum": include_sum,
}
else:
is_composite = False
kwargs = {}
if not is_composite:
log_prob = dist.log_prob(action)
else:
if isinstance(dist, CompositeDistribution):
is_composite = True
aggregate = dist.aggregate_probabilities
if aggregate is None:
aggregate = False
include_sum = dist.include_sum
if include_sum is None:
include_sum = False
kwargs = {
"inplace": False,
"aggregate_probabilities": aggregate,
"include_sum": include_sum,
}
else:
is_composite = False
kwargs = {}
log_prob: TensorDictBase = dist.log_prob(tensordict, **kwargs)
if not is_tensor_collection(prev_log_prob):
# this isn't great, in general multihead actions should have a composite log-prob too
warnings.warn(
"You are using a composite distribution, yet your log-probability is a tensor. "
"This usually happens whenever the CompositeDistribution has aggregate_probabilities=True "
"or include_sum=True. These options should be avoided: leaf log-probs should be written "
"independently and PPO will take care of the aggregation.",
category=UserWarning,
)
if (
is_composite
and not is_tensor_collection(prev_log_prob)
Expand All @@ -559,6 +569,7 @@ def _log_weight(
log_prob = _sum_td_features(log_prob)
log_prob.view_as(prev_log_prob)

print(log_prob , prev_log_prob)
log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
if is_tensor_collection(kl_approx):
Expand Down

0 comments on commit 14e639d

Please sign in to comment.