
Commit

Update
[ghstack-poisoned]
vmoens committed Jan 10, 2025
2 parents c10adf1 + 6144a91 commit 443330b
Showing 3 changed files with 60 additions and 8 deletions.
6 changes: 6 additions & 0 deletions examples/agents/composite_actor.py
@@ -50,3 +50,9 @@ def forward(self, x):
 data = TensorDict({"x": torch.rand(10)}, [])
 module(data)
 print(actor(data))
+
+
+# TODO:
+# 1. Use ("action", "action0") + ("action", "action1") vs ("agent0", "action") + ("agent1", "action")
+# 2. Must multi-head require an action_key to be a list of keys? (I guess so)
+# 3. Using maps in the Actor
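
As context for the key-layout question in TODO item 1, a minimal sketch (not part of this commit) contrasting the two schemes with tensordict's CompositeDistribution; the head names, shapes, and the use of Normal are illustrative assumptions:

import torch
from tensordict import TensorDict
from tensordict.nn import CompositeDistribution
from torch import distributions as d

# Parameters for two independent action heads, nested under the head names.
params = TensorDict(
    {
        "action0": {"loc": torch.zeros(3), "scale": torch.ones(3)},
        "action1": {"loc": torch.zeros(3), "scale": torch.ones(3)},
    },
    [],
)
dist_map = {"action0": d.Normal, "action1": d.Normal}

# Scheme A: both heads nested under a shared "action" entry.
dist_a = CompositeDistribution(
    params,
    distribution_map=dist_map,
    name_map={"action0": ("action", "action0"), "action1": ("action", "action1")},
)
print(dist_a.sample())  # samples written at ("action", "action0") / ("action", "action1")

# Scheme B: one "action" entry per agent.
dist_b = CompositeDistribution(
    params,
    distribution_map=dist_map,
    name_map={"action0": ("agent0", "action"), "action1": ("agent1", "action")},
)
print(dist_b.sample())  # samples written at ("agent0", "action") / ("agent1", "action")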
60 changes: 52 additions & 8 deletions test/test_cost.py
@@ -7908,11 +7908,11 @@ def _create_mock_actor(
         obs_dim=3,
         action_dim=4,
         device="cpu",
-        action_key="action",
+        action_key=None,
         observation_key="observation",
         sample_log_prob_key="sample_log_prob",
         composite_action_dist=False,
-        aggregate_probabilities=True,
+        aggregate_probabilities=None,
     ):
         # Actor
         action_spec = Bounded(
@@ -7922,13 +7922,17 @@
         action_spec = Composite({action_key: {"action1": action_spec}})
         net = nn.Sequential(nn.Linear(obs_dim, 2 * action_dim), NormalParamExtractor())
         if composite_action_dist:
+            if action_key is None:
+                action_key = ("action", "action1")
+            else:
+                action_key = (action_key, "action1")
             distribution_class = functools.partial(
                 CompositeDistribution,
                 distribution_map={
                     "action1": TanhNormal,
                 },
                 name_map={
-                    "action1": (action_key, "action1"),
+                    "action1": action_key,
                 },
                 log_prob_key=sample_log_prob_key,
                 aggregate_probabilities=aggregate_probabilities,
@@ -7939,6 +7943,8 @@
             ]
             actor_in_keys = ["params"]
         else:
+            if action_key is None:
+                action_key = "action"
             distribution_class = TanhNormal
             module_out_keys = actor_in_keys = ["loc", "scale"]
         module = TensorDictModule(
@@ -8149,8 +8155,8 @@ def _create_seq_mock_data_ppo(
         action_dim=4,
         atoms=None,
         device="cpu",
-        sample_log_prob_key="sample_log_prob",
-        action_key="action",
+        sample_log_prob_key=None,
+        action_key=None,
         composite_action_dist=False,
     ):
         # create a tensordict
@@ -8173,6 +8179,17 @@
         action = action.masked_fill_(~mask.unsqueeze(-1), 0.0)
         loc = params_mean.masked_fill_(~mask.unsqueeze(-1), 0.0)
         scale = params_scale.masked_fill_(~mask.unsqueeze(-1), 0.0)
+        if sample_log_prob_key is None:
+            if composite_action_dist:
+                sample_log_prob_key = ("action", "action1_log_prob")
+            else:
+                sample_log_prob_key = "sample_log_prob"
+
+        if action_key is None:
+            if composite_action_dist:
+                action_key = ("action", "action1")
+            else:
+                action_key = "action"
         td = TensorDict(
             batch_size=(batch, T),
             source={
@@ -8184,7 +8201,7 @@
                     "reward": reward.masked_fill_(~mask.unsqueeze(-1), 0.0),
                 },
                 "collector": {"mask": mask},
-                action_key: {"action1": action} if composite_action_dist else action,
+                action_key: action,
             },
             device=device,
             names=[None, "time"],
@@ -8261,6 +8278,13 @@ def test_ppo(
             loss_critic_type="l2",
             functional=functional,
         )
+        if composite_action_dist:
+            loss_fn.set_keys(
+                action=("action", "action1"),
+                sample_log_prob=[("action", "action1_log_prob")],
+            )
+            if advantage is not None:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
         if advantage is not None:
             advantage(td)
         else:
@@ -8354,7 +8378,9 @@ def test_ppo_composite_no_aggregate(
             loss_critic_type="l2",
             functional=functional,
         )
+        loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
         if advantage is not None:
+            advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
         else:
             if td_est is not None:
@@ -8462,7 +8488,12 @@ def test_ppo_shared(self, loss_class, device, advantage, composite_action_dist):
         )

         if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
         loss = loss_fn(td)

         loss_critic = loss["loss_critic"]
@@ -8569,7 +8600,14 @@ def test_ppo_shared_seq(
         )

         if advantage is not None:
+            if composite_action_dist:
+                advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
             advantage(td)
+
+        if composite_action_dist:
+            loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+            loss_fn2.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
+
         loss = loss_fn(td).exclude("entropy")

         sum(val for key, val in loss.items() if key.startswith("loss_")).backward()
@@ -8657,7 +8695,11 @@ def zero_param(p):
         # assert len(list(floss_fn.parameters())) == 0
         with params.to_module(loss_fn):
             if advantage is not None:
+                if composite_action_dist:
+                    advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])
                 advantage(td)
+            if composite_action_dist:
+                loss_fn.set_keys(action=("action", "action1"), sample_log_prob=[("action", "action1_log_prob")])
             loss = loss_fn(td)

             loss_critic = loss["loss_critic"]
@@ -8758,8 +8800,8 @@ def test_ppo_tensordict_keys_run(
             "advantage": "advantage_test",
             "value_target": "value_target_test",
             "value": "state_value_test",
-            "sample_log_prob": "sample_log_prob_test",
-            "action": "action_test",
+            "sample_log_prob": ("action_test", "action1_log_prob") if composite_action_dist else "sample_log_prob_test",
+            "action": ("action_test", "action") if composite_action_dist else "action_test",
         }

         td = self._create_seq_mock_data_ppo(
@@ -8807,6 +8849,8 @@ def test_ppo_tensordict_keys_run(
             raise NotImplementedError

         loss_fn = loss_class(actor, value, loss_critic_type="l2")
+        if composite_action_dist:
+            tensor_keys["sample_log_prob"] = [tensor_keys["sample_log_prob"]]
         loss_fn.set_keys(**tensor_keys)
         if advantage is not None:
             # collect tensordict key names for the advantage module
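The test changes above all exercise one user-facing pattern: with a composite action distribution, both the loss module and the advantage module must be re-pointed at the nested action and per-head log-prob entries. A hedged sketch of that wiring (assuming an `actor` and `critic` built as in these tests; the GAE hyperparameters are illustrative):

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

loss_fn = ClipPPOLoss(actor_network=actor, critic_network=critic)
loss_fn.set_keys(
    action=("action", "action1"),
    # For composite distributions, the tests pass a list of per-head log-prob keys.
    sample_log_prob=[("action", "action1_log_prob")],
)

advantage = GAE(gamma=0.99, lmbda=0.95, value_network=critic)
advantage.set_keys(sample_log_prob=[("action", "action1_log_prob")])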
2 changes: 2 additions & 0 deletions torchrl/objectives/ppo.py
@@ -5,6 +5,7 @@
 from __future__ import annotations

 import contextlib
+import warnings

 from copy import deepcopy
 from dataclasses import dataclass
@@ -556,6 +557,7 @@ def _log_weight(
             log_prob = _sum_td_features(log_prob)
             log_prob.view_as(prev_log_prob)

+        print(log_prob , prev_log_prob)
         log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
         kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
         if is_tensor_collection(kl_approx):
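For context on the `_log_weight` lines touched here: the PPO importance ratio is formed in log space from the freshly computed and the stored log-probabilities, and `kl_approx` is the first-order (k1) estimator of KL(pi_old || pi_new). A standalone sketch with illustrative tensors:

import torch

log_prob = torch.tensor([-1.0, -2.0])       # log pi_new(a|s), evaluated now
prev_log_prob = torch.tensor([-1.2, -1.8])  # log pi_old(a|s), stored at collection time

log_weight = (log_prob - prev_log_prob).unsqueeze(-1)
ratio = log_weight.exp()                    # pi_new / pi_old, used in the clipped objective
kl_approx = (prev_log_prob - log_prob).unsqueeze(-1)
print(ratio, kl_approx.mean())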
