From f592f54155fb9380aa35960d73b205087ac4b178 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 25 Jun 2024 18:01:59 +0200 Subject: [PATCH 01/18] init td3 bc --- sota-implementations/td3+bc/config.yaml | 43 ++ sota-implementations/td3+bc/td3+bc.py | 151 +++++++ sota-implementations/td3+bc/utils.py | 273 ++++++++++++ torchrl/objectives/__init__.py | 1 + torchrl/objectives/td3_bc.py | 534 ++++++++++++++++++++++++ 5 files changed, 1002 insertions(+) create mode 100644 sota-implementations/td3+bc/config.yaml create mode 100644 sota-implementations/td3+bc/td3+bc.py create mode 100644 sota-implementations/td3+bc/utils.py create mode 100644 torchrl/objectives/td3_bc.py diff --git a/sota-implementations/td3+bc/config.yaml b/sota-implementations/td3+bc/config.yaml new file mode 100644 index 00000000000..bc8816257e7 --- /dev/null +++ b/sota-implementations/td3+bc/config.yaml @@ -0,0 +1,43 @@ +# task and env +env: + name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency + task: "" + library: gym + seed: 42 + max_episode_steps: 1000 + +# replay buffer +replay_buffer: + dataset: halfcheetah-medium-v2 + batch_size: 256 + +# optim +optim: + gradient_steps: 100000 + gamma: 0.99 + loss_function: l2 + lr: 3.0e-4 + weight_decay: 0.0 + adam_eps: 1e-4 + batch_size: 256 + target_update_polyak: 0.995 + policy_update_delay: 2 + policy_noise: 0.2 + noise_clip: 0.5 + alpha: 2.5 + +# network +network: + hidden_sizes: [256, 256] + activation: relu + device: null + +# logging +logger: + backend: wandb + project_name: td3+bc_${replay_buffer.dataset} + group_name: null + exp_name: TD3+BC_${replay_buffer.dataset} + mode: online + eval_iter: 5000 + video: False diff --git a/sota-implementations/td3+bc/td3+bc.py b/sota-implementations/td3+bc/td3+bc.py new file mode 100644 index 00000000000..65e28b3c6db --- /dev/null +++ b/sota-implementations/td3+bc/td3+bc.py @@ -0,0 +1,151 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""IQL Example. + +This is a self-contained example of an offline IQL training script. + +The helper functions are coded in the utils.py associated with this script. 
+ +""" +import time + +import hydra +import numpy as np +import torch +import tqdm +from torchrl._utils import logger as torchrl_logger + +from torchrl.envs import set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.record.loggers import generate_exp_name, get_logger + +from utils import ( + dump_video, + log_metrics, + make_environment, + make_loss_module, + make_offline_replay_buffer, + make_optimizer, + make_td3_agent, +) + + +@hydra.main(config_path="", config_name="offline_config") +def main(cfg: "DictConfig"): # noqa: F821 + set_gym_backend(cfg.env.backend).set() + + # Create logger + exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name) + logger = None + if cfg.logger.backend: + logger = get_logger( + logger_type=cfg.logger.backend, + logger_name="td3bc_logging", + experiment_name=exp_name, + wandb_kwargs={ + "mode": cfg.logger.mode, + "config": dict(cfg), + "project": cfg.logger.project_name, + "group": cfg.logger.group_name, + }, + ) + + # Set seeds + torch.manual_seed(cfg.env.seed) + np.random.seed(cfg.env.seed) + device = cfg.optim.device + if device in ("", None): + if torch.cuda.is_available(): + device = "cuda:0" + else: + device = "cpu" + device = torch.device(device) + + # Creante env + train_env, eval_env = make_environment( + cfg, + cfg.logger.eval_envs, + logger=logger, + ) + + # Create replay buffer + replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) + + # Create agent + model, actor_model_explore = make_td3_agent(cfg, train_env, eval_env, device) + + # Create loss + loss_module, target_net_updater = make_loss_module(cfg.loss, model) + + # Create optimizer + optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) + + pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) + + gradient_steps = cfg.optim.gradient_steps + evaluation_interval = cfg.logger.eval_iter + eval_steps = cfg.logger.eval_steps + delayed_updates = cfg.optim.policy_update_delay + update_counter = 0 + # Training loop + start_time = time.time() + for i in range(gradient_steps): + pbar.update(1) + # Update actor every delayed_updates + update_counter += 1 + update_actor = update_counter % delayed_updates == 0 + + # Sample from replay buffer + sampled_tensordict = replay_buffer.sample() + if sampled_tensordict.device != device: + sampled_tensordict = sampled_tensordict.to(device, non_blocking=True) + else: + sampled_tensordict = sampled_tensordict.clone() + + # Compute loss + q_loss, *_ = loss_module.value_loss(sampled_tensordict) + + # Update critic + optimizer_critic.zero_grad() + q_loss.backward() + optimizer_critic.step() + q_loss.item() + + # Update actor + if update_actor: + actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict) + optimizer_actor.zero_grad() + actor_loss.backward() + optimizer_actor.step() + + # Update target params + target_net_updater.step() + + # log metrics + to_log = { + "loss_actor": actor_loss.item(), + "loss_qvalue": q_loss.item(), + "bc_loss": actorloss_metadata.bc_loss.item(), + "lambda": actorloss_metadata.actor_loss.item(), + } + + # evaluation + if i % evaluation_interval == 0: + with set_exploration_type(ExplorationType.MODE), torch.no_grad(): + eval_td = eval_env.rollout( + max_steps=eval_steps, policy=model[0], auto_cast_to_device=True + ) + eval_env.apply(dump_video) + eval_reward = eval_td["next", "reward"].sum(1).mean().item() + to_log["evaluation_reward"] = eval_reward + if logger is not None: + log_metrics(logger, to_log, i) + + pbar.close() + 
torchrl_logger.info(f"Training time: {time.time() - start_time}") + + +if __name__ == "__main__": + main() diff --git a/sota-implementations/td3+bc/utils.py b/sota-implementations/td3+bc/utils.py new file mode 100644 index 00000000000..498b90fbf58 --- /dev/null +++ b/sota-implementations/td3+bc/utils.py @@ -0,0 +1,273 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import functools + +import torch + +from torch import nn, optim +from torchrl.data.datasets.d4rl import D4RLExperienceReplay +from torchrl.data.replay_buffers import SamplerWithoutReplacement +from torchrl.envs import ( + CatTensors, + Compose, + DMControlEnv, + DoubleToFloat, + EnvCreator, + InitTracker, + ParallelEnv, + RewardSum, + StepCounter, + TransformedEnv, +) +from torchrl.envs.libs.gym import GymEnv, set_gym_backend +from torchrl.envs.utils import ExplorationType, set_exploration_type +from torchrl.modules import ( + AdditiveGaussianWrapper, + MLP, + SafeModule, + SafeSequential, + TanhModule, + ValueOperator, +) + +from torchrl.objectives import SoftUpdate +from torchrl.objectives.td3_bc import TD3BCLoss +from torchrl.record import VideoRecorder + + +# ==================================================================== +# Environment utils +# ----------------- + + +def env_maker(cfg, device="cpu", from_pixels=False): + lib = cfg.env.library + if lib in ("gym", "gymnasium"): + with set_gym_backend(lib): + return GymEnv( + cfg.env.name, + device=device, + from_pixels=from_pixels, + pixels_only=False, + ) + elif lib == "dm_control": + env = DMControlEnv( + cfg.env.name, cfg.env.task, from_pixels=from_pixels, pixels_only=False + ) + return TransformedEnv( + env, CatTensors(in_keys=env.observation_spec.keys(), out_key="observation") + ) + else: + raise NotImplementedError(f"Unknown lib {lib}.") + + +def apply_env_transforms(env, max_episode_steps): + transformed_env = TransformedEnv( + env, + Compose( + StepCounter(max_steps=max_episode_steps), + InitTracker(), + DoubleToFloat(), + RewardSum(), + ), + ) + return transformed_env + + +def make_environment(cfg, logger=None): + """Make environments for training and evaluation.""" + partial = functools.partial(env_maker, cfg=cfg) + parallel_env = ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(partial), + serial_for_single=True, + ) + parallel_env.set_seed(cfg.env.seed) + + train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) + + partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) + trsf_clone = train_env.transform.clone() + if cfg.logger.video: + trsf_clone.insert( + 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + ) + eval_env = TransformedEnv( + ParallelEnv( + cfg.collector.env_per_collector, + EnvCreator(partial), + serial_for_single=True, + ), + trsf_clone, + ) + return train_env, eval_env + + +# ==================================================================== +# Replay buffer +# --------------------------- + + +def make_offline_replay_buffer(rb_cfg): + data = D4RLExperienceReplay( + dataset_id=rb_cfg.dataset, + split_trajs=False, + batch_size=rb_cfg.batch_size, + sampler=SamplerWithoutReplacement(drop_last=False), + prefetch=4, + direct_download=True, + ) + + data.append_transform(DoubleToFloat()) + + return data + + +# ==================================================================== +# Model +# ----- + + +def make_td3_agent(cfg, train_env, 
eval_env, device): + """Make TD3 agent.""" + # Define Actor Network + in_keys = ["observation"] + action_spec = train_env.action_spec + if train_env.batch_size: + action_spec = action_spec[(0,) * len(train_env.batch_size)] + actor_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": action_spec.shape[-1], + "activation_class": get_activation(cfg), + } + + actor_net = MLP(**actor_net_kwargs) + + in_keys_actor = in_keys + actor_module = SafeModule( + actor_net, + in_keys=in_keys_actor, + out_keys=[ + "param", + ], + ) + actor = SafeSequential( + actor_module, + TanhModule( + in_keys=["param"], + out_keys=["action"], + spec=action_spec, + ), + ) + + # Define Critic Network + qvalue_net_kwargs = { + "num_cells": cfg.network.hidden_sizes, + "out_features": 1, + "activation_class": get_activation(cfg), + } + + qvalue_net = MLP( + **qvalue_net_kwargs, + ) + + qvalue = ValueOperator( + in_keys=["action"] + in_keys, + module=qvalue_net, + ) + + model = nn.ModuleList([actor, qvalue]).to(device) + + # init nets + with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): + td = eval_env.reset() + td = td.to(device) + for net in model: + net(td) + del td + eval_env.close() + + # Exploration wrappers: + actor_model_explore = AdditiveGaussianWrapper( + model[0], + sigma_init=1, + sigma_end=1, + mean=0, + std=0.1, + spec=action_spec, + ).to(device) + return model, actor_model_explore + + +# ==================================================================== +# TD3 Loss +# --------- + + +def make_loss_module(cfg, model): + """Make loss module and target network updater.""" + # Create TD3 loss + loss_module = TD3BCLoss( + actor_network=model[0], + qvalue_network=model[1], + num_qvalue_nets=2, + loss_function=cfg.optim.loss_function, + delay_actor=True, + delay_qvalue=True, + action_spec=model[0][1].spec, + policy_noise=cfg.optim.policy_noise, + noise_clip=cfg.optim.noise_clip, + alpha=cfg.optim.alpha, + ) + loss_module.make_value_estimator(gamma=cfg.optim.gamma) + + # Define Target Network Updater + target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) + return loss_module, target_net_updater + + +def make_optimizer(cfg, loss_module): + critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) + actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + + optimizer_actor = optim.Adam( + actor_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + optimizer_critic = optim.Adam( + critic_params, + lr=cfg.optim.lr, + weight_decay=cfg.optim.weight_decay, + eps=cfg.optim.adam_eps, + ) + return optimizer_actor, optimizer_critic + + +# ==================================================================== +# General utils +# --------- + + +def log_metrics(logger, metrics, step): + for metric_name, metric_value in metrics.items(): + logger.log_scalar(metric_name, metric_value, step) + + +def get_activation(cfg): + if cfg.network.activation == "relu": + return nn.ReLU + elif cfg.network.activation == "tanh": + return nn.Tanh + elif cfg.network.activation == "leaky_relu": + return nn.LeakyReLU + else: + raise NotImplementedError + + +def dump_video(module): + if isinstance(module, VideoRecorder): + module.dump() diff --git a/torchrl/objectives/__init__.py b/torchrl/objectives/__init__.py index f8d2bd1d977..674c06123ad 100644 --- a/torchrl/objectives/__init__.py +++ b/torchrl/objectives/__init__.py @@ -17,6 +17,7 @@ from .reinforce import ReinforceLoss from .sac import 
DiscreteSACLoss, SACLoss from .td3 import TD3Loss +from .td3_bc import TD3BCLoss from .utils import ( default_value_kwargs, distance_loss, diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py new file mode 100644 index 00000000000..62d2e4f2585 --- /dev/null +++ b/torchrl/objectives/td3_bc.py @@ -0,0 +1,534 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch + +from tensordict import TensorDict, TensorDictBase, TensorDictParams +from tensordict.nn import dispatch, TensorDictModule +from tensordict.utils import NestedKey +from torchrl.data.tensor_specs import BoundedTensorSpec, CompositeSpec, TensorSpec + +from torchrl.envs.utils import step_mdp +from torchrl.objectives.common import LossModule + +from torchrl.objectives.utils import ( + _cache_values, + _GAMMA_LMBDA_DEPREC_ERROR, + _reduce, + _vmap_func, + default_value_kwargs, + distance_loss, + ValueEstimators, +) +from torchrl.objectives.value import TD0Estimator, TD1Estimator, TDLambdaEstimator + + +class TD3BCLoss(LossModule): + r"""TD3+BC Loss Module. + + Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to + Offline Reinforcement Learning" ` + + Args: + actor_network (TensorDictModule): the actor to be trained + qvalue_network (TensorDictModule): a single Q-value network that will + be multiplicated as many times as needed. + + Keyword Args: + bounds (tuple of float, optional): the bounds of the action space. + Exclusive with action_spec. Either this or ``action_spec`` must + be provided. + action_spec (TensorSpec, optional): the action spec. + Exclusive with bounds. Either this or ``bounds`` must be provided. + num_qvalue_nets (int, optional): Number of Q-value networks to be + trained. Default is ``10``. + policy_noise (float, optional): Standard deviation for the target + policy action noise. Default is ``0.2``. + noise_clip (float, optional): Clipping range value for the sampled + target policy action noise. Default is ``0.5``. + alpha (float, optional): Weight for the behavioral cloning loss. + priority_key (str, optional): Key where to write the priority value + for prioritized replay buffers. Default is + `"td_error"`. + loss_function (str, optional): loss function to be used for the Q-value. + Can be one of ``"smooth_l1"``, ``"l2"``, + ``"l1"``, Default is ``"smooth_l1"``. + delay_actor (bool, optional): whether to separate the target actor + networks from the actor networks used for + data collection. Default is ``True``. + delay_qvalue (bool, optional): Whether to separate the target Q value + networks from the Q value networks used + for data collection. Default is ``True``. + spec (TensorSpec, optional): the action tensor spec. If not provided + and the target entropy is ``"auto"``, it will be retrieved from + the actor. + separate_losses (bool, optional): if ``True``, shared parameters between + policy and critic will only be trained on the policy loss. + Defaults to ``False``, ie. gradients are propagated to shared + parameters for both policy and critic losses. + reduction (str, optional): Specifies the reduction to apply to the output: + ``"none"`` | ``"mean"`` | ``"sum"``. 
``"none"``: no reduction will be applied, + ``"mean"``: the sum of the output will be divided by the number of + elements in the output, ``"sum"``: the output will be summed. Default: ``"mean"``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal + >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator + >>> from torchrl.modules.tensordict_module.common import SafeModule + >>> from torchrl.objectives.td3 import TD3Loss + >>> from tensordict import TensorDict + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> data = TensorDict({ + ... "observation": torch.randn(*batch, n_obs), + ... "action": action, + ... ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool), + ... ("next", "reward"): torch.randn(*batch, 1), + ... ("next", "observation"): torch.randn(*batch, n_obs), + ... }, batch) + >>> loss(data) + TensorDict( + fields={ + loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + batch_size=torch.Size([]), + device=None, + is_shared=False) + + This class is compatible with non-tensordict based modules too and can be + used without recurring to any tensordict-related primitive. In this case, + the expected keyword arguments are: + ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network + The return value is a tuple of tensors in the following order: + ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. + + Examples: + >>> import torch + >>> from torch import nn + >>> from torchrl.data import BoundedTensorSpec + >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator + >>> from torchrl.objectives.td3 import TD3Loss + >>> n_act, n_obs = 4, 3 + >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) + >>> module = nn.Linear(n_obs, n_act) + >>> actor = Actor( + ... module=module, + ... spec=spec) + >>> class ValueClass(nn.Module): + ... def __init__(self): + ... super().__init__() + ... self.linear = nn.Linear(n_obs + n_act, 1) + ... def forward(self, obs, act): + ... 
return self.linear(torch.cat([obs, act], -1)) + >>> module = ValueClass() + >>> qvalue = ValueOperator( + ... module=module, + ... in_keys=['observation', 'action']) + >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue") + >>> batch = [2, ] + >>> action = spec.rand(batch) + >>> loss_actor, loss_qvalue = loss( + ... observation=torch.randn(*batch, n_obs), + ... action=action, + ... next_done=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_terminated=torch.zeros(*batch, 1, dtype=torch.bool), + ... next_reward=torch.randn(*batch, 1), + ... next_observation=torch.randn(*batch, n_obs)) + >>> loss_actor.backward() + + """ + + @dataclass + class _AcceptedKeys: + """Maintains default values for all configurable tensordict keys. + + This class defines which tensordict keys can be set using '.set_keys(key_name=key_value)' and their + default values. + + Attributes: + action (NestedKey): The input tensordict key where the action is expected. + Defaults to ``"action"``. + state_action_value (NestedKey): The input tensordict key where the state action value is expected. + Will be used for the underlying value estimator. Defaults to ``"state_action_value"``. + priority (NestedKey): The input tensordict key where the target priority is written to. + Defaults to ``"td_error"``. + reward (NestedKey): The input tensordict key where the reward is expected. + Will be used for the underlying value estimator. Defaults to ``"reward"``. + done (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is done. Will be used for the underlying value estimator. + Defaults to ``"done"``. + terminated (NestedKey): The key in the input TensorDict that indicates + whether a trajectory is terminated. Will be used for the underlying value estimator. + Defaults to ``"terminated"``. 
+ """ + + action: NestedKey = "action" + state_action_value: NestedKey = "state_action_value" + priority: NestedKey = "td_error" + reward: NestedKey = "reward" + done: NestedKey = "done" + terminated: NestedKey = "terminated" + + default_keys = _AcceptedKeys() + default_value_estimator = ValueEstimators.TD0 + out_keys = [ + "loss_actor", + "loss_qvalue", + "pred_value", + "state_action_value_actor", + "next_state_value", + "target_value", + ] + + actor_network: TensorDictModule + qvalue_network: TensorDictModule + actor_network_params: TensorDictParams + qvalue_network_params: TensorDictParams + target_actor_network_params: TensorDictParams + target_qvalue_network_params: TensorDictParams + + def __init__( + self, + actor_network: TensorDictModule, + qvalue_network: TensorDictModule, + *, + action_spec: TensorSpec = None, + bounds: Optional[Tuple[float]] = None, + num_qvalue_nets: int = 2, + policy_noise: float = 0.2, + noise_clip: float = 0.5, + alpha: float = 2.5, + loss_function: str = "smooth_l1", + delay_actor: bool = True, + delay_qvalue: bool = True, + gamma: float = None, + priority_key: str = None, + separate_losses: bool = False, + reduction: str = None, + ) -> None: + if reduction is None: + reduction = "mean" + super().__init__() + self._in_keys = None + self._set_deprecated_ctor_keys(priority=priority_key) + + self.delay_actor = delay_actor + self.delay_qvalue = delay_qvalue + + self.convert_to_functional( + actor_network, + "actor_network", + create_target_params=self.delay_actor, + ) + if separate_losses: + # we want to make sure there are no duplicates in the params: the + # params of critic must be refs to actor if they're shared + policy_params = list(actor_network.parameters()) + else: + policy_params = None + self.convert_to_functional( + qvalue_network, + "qvalue_network", + num_qvalue_nets, + create_target_params=self.delay_qvalue, + compare_against=policy_params, + ) + + for p in self.parameters(): + device = p.device + break + else: + device = None + self.num_qvalue_nets = num_qvalue_nets + self.loss_function = loss_function + self.policy_noise = policy_noise + self.noise_clip = noise_clip + if not ((action_spec is not None) ^ (bounds is not None)): + raise ValueError( + "One of 'bounds' and 'action_spec' must be provided, " + f"but not both or none. Got bounds={bounds} and action_spec={action_spec}." + ) + elif action_spec is not None: + if isinstance(action_spec, CompositeSpec): + if ( + isinstance(self.tensor_keys.action, tuple) + and len(self.tensor_keys.action) > 1 + ): + action_container_shape = action_spec[ + self.tensor_keys.action[:-1] + ].shape + else: + action_container_shape = action_spec.shape + action_spec = action_spec[self.tensor_keys.action][ + (0,) * len(action_container_shape) + ] + if not isinstance(action_spec, BoundedTensorSpec): + raise ValueError( + f"action_spec is not of type BoundedTensorSpec but {type(action_spec)}." 
+ ) + low = action_spec.space.low + high = action_spec.space.high + else: + low, high = bounds + if not isinstance(low, torch.Tensor): + low = torch.tensor(low) + if not isinstance(high, torch.Tensor): + high = torch.tensor(high, device=low.device, dtype=low.dtype) + if (low > high).any(): + raise ValueError("Got a low bound higher than a high bound.") + if device is not None: + low = low.to(device) + high = high.to(device) + self.register_buffer("max_action", high) + self.register_buffer("min_action", low) + if gamma is not None: + raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) + self._vmap_qvalue_network00 = _vmap_func( + self.qvalue_network, randomness=self.vmap_randomness + ) + self._vmap_actor_network00 = _vmap_func( + self.actor_network, randomness=self.vmap_randomness + ) + self.reduction = reduction + + def _forward_value_estimator_keys(self, **kwargs) -> None: + if self._value_estimator is not None: + self._value_estimator.set_keys( + value=self._tensor_keys.state_action_value, + reward=self.tensor_keys.reward, + done=self.tensor_keys.done, + terminated=self.tensor_keys.terminated, + ) + self._set_in_keys() + + def _set_in_keys(self): + keys = [ + self.tensor_keys.action, + ("next", self.tensor_keys.reward), + ("next", self.tensor_keys.done), + ("next", self.tensor_keys.terminated), + *self.actor_network.in_keys, + *[("next", key) for key in self.actor_network.in_keys], + *self.qvalue_network.in_keys, + ] + self._in_keys = list(set(keys)) + + @property + def in_keys(self): + if self._in_keys is None: + self._set_in_keys() + return self._in_keys + + @in_keys.setter + def in_keys(self, values): + self._in_keys = values + + @property + @_cache_values + def _cached_detach_qvalue_network_params(self): + return self.qvalue_network_params.detach() + + @property + @_cache_values + def _cached_stack_actor_params(self): + return torch.stack( + [self.actor_network_params, self.target_actor_network_params], 0 + ) + + def actor_loss(self, tensordict): + tensordict_actor_grad = tensordict.select( + *self.actor_network.in_keys, strict=False + ) + with self.actor_network_params.to_module(self.actor_network): + tensordict_actor_grad = self.actor_network(tensordict_actor_grad) + actor_loss_td = tensordict_actor_grad.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *tensordict_actor_grad.batch_size + ) # for actor loss + state_action_value_actor = ( + self._vmap_qvalue_network00( + actor_loss_td, + self._cached_detach_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + bc_loss = torch.nn.functional.mse_loss( + tensordict_actor_grad.get(self.tensor_keys.action), + tensordict.get(self.tensor_keys.action), + ) + lmbd = self.alpha / state_action_value_actor[0].abs().mean().detach() + + loss_actor = -lmbd * state_action_value_actor[0] + bc_loss + + metadata = { + "state_action_value_actor": state_action_value_actor.detach(), + "bc_loss": bc_loss, + "lmbd": lmbd, + } + loss_actor = _reduce(loss_actor, reduction=self.reduction) + return loss_actor, metadata + + def value_loss(self, tensordict): + tensordict = tensordict.clone(False) + + act = tensordict.get(self.tensor_keys.action) + + # computing early for reprod + noise = (torch.randn_like(act) * self.policy_noise).clamp( + -self.noise_clip, self.noise_clip + ) + + with torch.no_grad(): + next_td_actor = step_mdp(tensordict).select( + *self.actor_network.in_keys, strict=False + ) # next_observation -> + with self.target_actor_network_params.to_module(self.actor_network): + 
next_td_actor = self.actor_network(next_td_actor) + next_action = (next_td_actor.get(self.tensor_keys.action) + noise).clamp( + self.min_action, self.max_action + ) + next_td_actor.set( + self.tensor_keys.action, + next_action, + ) + next_val_td = next_td_actor.select( + *self.qvalue_network.in_keys, strict=False + ).expand( + self.num_qvalue_nets, *next_td_actor.batch_size + ) # for next value estimation + next_target_q1q2 = ( + self._vmap_qvalue_network00( + next_val_td, + self.target_qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + # min over the next target qvalues + next_target_qvalue = next_target_q1q2.min(0)[0] + + # set next target qvalues + tensordict.set( + ("next", self.tensor_keys.state_action_value), + next_target_qvalue.unsqueeze(-1), + ) + + qval_td = tensordict.select(*self.qvalue_network.in_keys, strict=False).expand( + self.num_qvalue_nets, + *tensordict.batch_size, + ) + # preditcted current qvalues + current_qvalue = ( + self._vmap_qvalue_network00( + qval_td, + self.qvalue_network_params, + ) + .get(self.tensor_keys.state_action_value) + .squeeze(-1) + ) + + # compute target values for the qvalue loss (reward + gamma * next_target_qvalue * (1 - done)) + target_value = self.value_estimator.value_estimate(tensordict).squeeze(-1) + + td_error = (current_qvalue - target_value).pow(2) + loss_qval = distance_loss( + current_qvalue, + target_value.expand_as(current_qvalue), + loss_function=self.loss_function, + ).sum(0) + metadata = { + "td_error": td_error, + "next_state_value": next_target_qvalue.detach(), + "pred_value": current_qvalue.detach(), + "target_value": target_value.detach(), + } + loss_qval = _reduce(loss_qval, reduction=self.reduction) + return loss_qval, metadata + + @dispatch + def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + tensordict_save = tensordict + loss_actor, metadata_actor = self.actor_loss(tensordict) + loss_qval, metadata_value = self.value_loss(tensordict_save) + tensordict_save.set( + self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] + ) + if not loss_qval.shape == loss_actor.shape: + raise RuntimeError( + f"QVal and actor loss have different shape: {loss_qval.shape} and {loss_actor.shape}" + ) + td_out = TensorDict( + source={ + "loss_actor": loss_actor, + "loss_qvalue": loss_qval, + **metadata_actor, + **metadata_value, + }, + batch_size=[], + ) + return td_out + + def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams): + if value_type is None: + value_type = self.default_value_estimator + self.value_type = value_type + hp = dict(default_value_kwargs(value_type)) + if hasattr(self, "gamma"): + hp["gamma"] = self.gamma + hp.update(hyperparams) + # we do not need a value network bc the next state value is already passed + if value_type == ValueEstimators.TD1: + self._value_estimator = TD1Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.TD0: + self._value_estimator = TD0Estimator(value_network=None, **hp) + elif value_type == ValueEstimators.GAE: + raise NotImplementedError( + f"Value type {value_type} it not implemented for loss {type(self)}." 
+ ) + elif value_type == ValueEstimators.TDLambda: + self._value_estimator = TDLambdaEstimator(value_network=None, **hp) + else: + raise NotImplementedError(f"Unknown value type {value_type}") + + tensor_keys = { + "value": self.tensor_keys.state_action_value, + "reward": self.tensor_keys.reward, + "done": self.tensor_keys.done, + "terminated": self.tensor_keys.terminated, + } + self._value_estimator.set_keys(**tensor_keys) From 31175039ec436d2a7d2bd678998edb0c995666f8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 10:19:21 +0200 Subject: [PATCH 02/18] add tests --- .../linux_examples/scripts/run_test.sh | 4 +- sota-implementations/td3+bc/config.yaml | 2 + sota-implementations/td3+bc/td3+bc.py | 28 +- sota-implementations/td3+bc/utils.py | 61 +- test/test_cost.py | 745 +++++++++++++++++- torchrl/objectives/td3_bc.py | 23 +- 6 files changed, 797 insertions(+), 66 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 075489b208d..92499b99a1b 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -51,7 +51,9 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ logger.backend= - +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3+bc/td3_bc.py \ + optim.gradient_steps=55 \ + logger.backend= # ==================================================================================== # # ================================ Gymnasium ========================================= # diff --git a/sota-implementations/td3+bc/config.yaml b/sota-implementations/td3+bc/config.yaml index bc8816257e7..bf93b20d082 100644 --- a/sota-implementations/td3+bc/config.yaml +++ b/sota-implementations/td3+bc/config.yaml @@ -40,4 +40,6 @@ logger: exp_name: TD3+BC_${replay_buffer.dataset} mode: online eval_iter: 5000 + eval_steps: 1000 + eval_envs: 1 video: False diff --git a/sota-implementations/td3+bc/td3+bc.py b/sota-implementations/td3+bc/td3+bc.py index 65e28b3c6db..c8c3b24b724 100644 --- a/sota-implementations/td3+bc/td3+bc.py +++ b/sota-implementations/td3+bc/td3+bc.py @@ -2,9 +2,9 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. -"""IQL Example. +"""TD3+BC Example. -This is a self-contained example of an offline IQL training script. +This is a self-contained example of an offline RL TD3+BC training script. The helper functions are coded in the utils.py associated with this script. 
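For reference, the actor objective that `TD3BCLoss.actor_loss` above computes on tensordicts can be written on plain tensors as the minimal sketch below. The function name and the plain-tensor interface are illustrative only; in the loss module, `q_pi` is the first critic's value of the actor's action, evaluated with detached critic parameters, and `alpha` is the constructor argument.

import torch


def td3_bc_actor_objective(q_pi, pi_action, data_action, alpha=2.5):
    # lambda normalizes the Q term by the average absolute Q value (detached)
    lmbd = alpha / q_pi.abs().mean().detach()
    # behavioral-cloning term: stay close to the dataset actions
    bc_loss = torch.nn.functional.mse_loss(pi_action, data_action)
    # maximize lambda * Q(s, pi(s)) while regularizing towards the data actions
    return (-lmbd * q_pi + bc_loss).mean()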
@@ -32,9 +32,9 @@ ) -@hydra.main(config_path="", config_name="offline_config") +@hydra.main(config_path="", config_name="config") def main(cfg: "DictConfig"): # noqa: F821 - set_gym_backend(cfg.env.backend).set() + set_gym_backend(cfg.env.library).set() # Create logger exp_name = generate_exp_name("TD3BC-offline", cfg.logger.exp_name) @@ -55,7 +55,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Set seeds torch.manual_seed(cfg.env.seed) np.random.seed(cfg.env.seed) - device = cfg.optim.device + device = cfg.network.device if device in ("", None): if torch.cuda.is_available(): device = "cuda:0" @@ -64,9 +64,8 @@ def main(cfg: "DictConfig"): # noqa: F821 device = torch.device(device) # Creante env - train_env, eval_env = make_environment( + eval_env = make_environment( cfg, - cfg.logger.eval_envs, logger=logger, ) @@ -74,10 +73,10 @@ def main(cfg: "DictConfig"): # noqa: F821 replay_buffer = make_offline_replay_buffer(cfg.replay_buffer) # Create agent - model, actor_model_explore = make_td3_agent(cfg, train_env, eval_env, device) + model, _ = make_td3_agent(cfg, eval_env, device) # Create loss - loss_module, target_net_updater = make_loss_module(cfg.loss, model) + loss_module, target_net_updater = make_loss_module(cfg.optim, model) # Create optimizer optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) @@ -113,6 +112,8 @@ def main(cfg: "DictConfig"): # noqa: F821 optimizer_critic.step() q_loss.item() + to_log = {"q_loss": q_loss.item()} + # Update actor if update_actor: actor_loss, actorloss_metadata = loss_module.actor_loss(sampled_tensordict) @@ -123,13 +124,8 @@ def main(cfg: "DictConfig"): # noqa: F821 # Update target params target_net_updater.step() - # log metrics - to_log = { - "loss_actor": actor_loss.item(), - "loss_qvalue": q_loss.item(), - "bc_loss": actorloss_metadata.bc_loss.item(), - "lambda": actorloss_metadata.actor_loss.item(), - } + to_log["actor_loss"] = actor_loss.item() + to_log.update(actorloss_metadata) # evaluation if i % evaluation_interval == 0: diff --git a/sota-implementations/td3+bc/utils.py b/sota-implementations/td3+bc/utils.py index 498b90fbf58..2b4c1f96146 100644 --- a/sota-implementations/td3+bc/utils.py +++ b/sota-implementations/td3+bc/utils.py @@ -80,7 +80,7 @@ def make_environment(cfg, logger=None): """Make environments for training and evaluation.""" partial = functools.partial(env_maker, cfg=cfg) parallel_env = ParallelEnv( - cfg.collector.env_per_collector, + cfg.logger.eval_envs, EnvCreator(partial), serial_for_single=True, ) @@ -88,21 +88,21 @@ def make_environment(cfg, logger=None): train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) - partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) - trsf_clone = train_env.transform.clone() - if cfg.logger.video: - trsf_clone.insert( - 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) - ) - eval_env = TransformedEnv( - ParallelEnv( - cfg.collector.env_per_collector, - EnvCreator(partial), - serial_for_single=True, - ), - trsf_clone, - ) - return train_env, eval_env + # partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) + # trsf_clone = train_env.transform.clone() + # if cfg.logger.video: + # trsf_clone.insert( + # 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) + # ) + # eval_env = TransformedEnv( + # ParallelEnv( + # cfg.logger.eval_envs, + # EnvCreator(partial), + # serial_for_single=True, + # ), + # trsf_clone, + # ) + return train_env # 
==================================================================== @@ -130,7 +130,7 @@ def make_offline_replay_buffer(rb_cfg): # ----- -def make_td3_agent(cfg, train_env, eval_env, device): +def make_td3_agent(cfg, train_env, device): """Make TD3 agent.""" # Define Actor Network in_keys = ["observation"] @@ -182,12 +182,11 @@ def make_td3_agent(cfg, train_env, eval_env, device): # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = eval_env.reset() + td = train_env.reset() td = td.to(device) for net in model: net(td) del td - eval_env.close() # Exploration wrappers: actor_model_explore = AdditiveGaussianWrapper( @@ -213,18 +212,18 @@ def make_loss_module(cfg, model): actor_network=model[0], qvalue_network=model[1], num_qvalue_nets=2, - loss_function=cfg.optim.loss_function, + loss_function=cfg.loss_function, delay_actor=True, delay_qvalue=True, action_spec=model[0][1].spec, - policy_noise=cfg.optim.policy_noise, - noise_clip=cfg.optim.noise_clip, - alpha=cfg.optim.alpha, + policy_noise=cfg.policy_noise, + noise_clip=cfg.noise_clip, + alpha=cfg.alpha, ) - loss_module.make_value_estimator(gamma=cfg.optim.gamma) + loss_module.make_value_estimator(gamma=cfg.gamma) # Define Target Network Updater - target_net_updater = SoftUpdate(loss_module, eps=cfg.optim.target_update_polyak) + target_net_updater = SoftUpdate(loss_module, eps=cfg.target_update_polyak) return loss_module, target_net_updater @@ -234,15 +233,15 @@ def make_optimizer(cfg, loss_module): optimizer_actor = optim.Adam( actor_params, - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.adam_eps, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, ) optimizer_critic = optim.Adam( critic_params, - lr=cfg.optim.lr, - weight_decay=cfg.optim.weight_decay, - eps=cfg.optim.adam_eps, + lr=cfg.lr, + weight_decay=cfg.weight_decay, + eps=cfg.adam_eps, ) return optimizer_actor, optimizer_critic diff --git a/test/test_cost.py b/test/test_cost.py index 76fc4e651f4..24d22ec4902 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -114,6 +114,7 @@ PPOLoss, QMixerLoss, SACLoss, + TD3BCLoss, TD3Loss, ) from torchrl.objectives.common import LossModule @@ -261,9 +262,9 @@ def __init__(self): self.vmap_model = _vmap_func( self.model, (None, 0), - randomness="error" - if vmap_randomness == "error" - else self.vmap_randomness, + randomness=( + "error" if vmap_randomness == "error" else self.vmap_randomness + ), ) def forward(self, td): @@ -319,9 +320,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], @@ -2714,6 +2715,732 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == torch.Size([]) +@pytest.mark.skipif( + not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" +) +class TestTD3BC(LossModuleTestBase): + seed = 0 + + def _create_mock_actor( + self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + in_keys=None, + out_keys=None, + dropout=0.0, + ): + # Actor + action_spec = BoundedTensorSpec( + -torch.ones(action_dim), torch.ones(action_dim), (action_dim,) + ) + module = nn.Sequential( + nn.Linear(obs_dim, obs_dim), + nn.Dropout(dropout), + nn.Linear(obs_dim, action_dim), + ) + actor = Actor( + spec=action_spec, module=module, in_keys=in_keys, out_keys=out_keys + ) + return actor.to(device) + + def _create_mock_value( + 
self, + batch=2, + obs_dim=3, + action_dim=4, + device="cpu", + out_keys=None, + action_key="action", + observation_key="observation", + ): + # Actor + class ValueClass(nn.Module): + def __init__(self): + super().__init__() + self.linear = nn.Linear(obs_dim + action_dim, 1) + + def forward(self, obs, act): + return self.linear(torch.cat([obs, act], -1)) + + module = ValueClass() + value = ValueOperator( + module=module, + in_keys=[observation_key, action_key], + out_keys=out_keys, + ) + return value.to(device) + + def _create_mock_distributional_actor( + self, batch=2, obs_dim=3, action_dim=4, atoms=5, vmin=1, vmax=5 + ): + raise NotImplementedError + + def _create_mock_common_layer_setup( + self, n_obs=3, n_act=4, ncells=4, batch=2, n_hidden=2 + ): + common = MLP( + num_cells=ncells, + in_features=n_obs, + depth=3, + out_features=n_hidden, + ) + actor_net = MLP( + num_cells=ncells, + in_features=n_hidden, + depth=1, + out_features=2 * n_act, + ) + value = MLP( + in_features=n_hidden + n_act, + num_cells=ncells, + depth=1, + out_features=1, + ) + batch = [batch] + td = TensorDict( + { + "obs": torch.randn(*batch, n_obs), + "action": torch.randn(*batch, n_act), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + "next": { + "obs": torch.randn(*batch, n_obs), + "reward": torch.randn(*batch, 1), + "done": torch.zeros(*batch, 1, dtype=torch.bool), + "terminated": torch.zeros(*batch, 1, dtype=torch.bool), + }, + }, + batch, + ) + common = Mod(common, in_keys=["obs"], out_keys=["hidden"]) + actor = ProbSeq( + common, + Mod(actor_net, in_keys=["hidden"], out_keys=["param"]), + Mod(NormalParamExtractor(), in_keys=["param"], out_keys=["loc", "scale"]), + ProbMod( + in_keys=["loc", "scale"], + out_keys=["action"], + distribution_class=TanhNormal, + return_log_prob=True, + ), + ) + value_head = Mod( + value, in_keys=["hidden", "action"], out_keys=["state_action_value"] + ) + value = Seq(common, value_head) + return actor, value, common, td + + def _create_mock_data_td3bc( + self, + batch=8, + obs_dim=3, + action_dim=4, + atoms=None, + device="cpu", + action_key="action", + observation_key="observation", + reward_key="reward", + done_key="done", + terminated_key="terminated", + ): + # create a tensordict + obs = torch.randn(batch, obs_dim, device=device) + next_obs = torch.randn(batch, obs_dim, device=device) + if atoms: + raise NotImplementedError + else: + action = torch.randn(batch, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, 1, device=device) + done = torch.zeros(batch, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch,), + source={ + observation_key: obs, + "next": { + observation_key: next_obs, + done_key: done, + terminated_key: terminated, + reward_key: reward, + }, + action_key: action, + }, + device=device, + ) + return td + + def _create_seq_mock_data_td3bc( + self, batch=8, T=4, obs_dim=3, action_dim=4, atoms=None, device="cpu" + ): + # create a tensordict + total_obs = torch.randn(batch, T + 1, obs_dim, device=device) + obs = total_obs[:, :T] + next_obs = total_obs[:, 1:] + if atoms: + action = torch.randn(batch, T, atoms, action_dim, device=device).clamp( + -1, 1 + ) + else: + action = torch.randn(batch, T, action_dim, device=device).clamp(-1, 1) + reward = torch.randn(batch, T, 1, device=device) + done = torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + terminated = torch.zeros(batch, T, 1, 
dtype=torch.bool, device=device) + mask = ~torch.zeros(batch, T, 1, dtype=torch.bool, device=device) + td = TensorDict( + batch_size=(batch, T), + source={ + "observation": obs * mask.to(obs.dtype), + "next": { + "observation": next_obs * mask.to(obs.dtype), + "reward": reward * mask.to(obs.dtype), + "done": done, + "terminated": terminated, + }, + "collector": {"mask": mask}, + "action": action * mask.to(obs.dtype), + }, + names=[None, "time"], + device=device, + ) + return td + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) + @pytest.mark.parametrize("td_est", list(ValueEstimators) + [None]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + @pytest.mark.parametrize("dropout", [0.0, 0.1]) + def test_td3bc( + self, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + td_est, + use_action_spec, + dropout, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device, dropout=dropout) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3bc(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None + else: + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + if td_est in (ValueEstimators.GAE, ValueEstimators.VTrace): + with pytest.raises(NotImplementedError): + loss_fn.make_value_estimator(td_est) + return + if td_est is not None: + loss_fn.make_value_estimator(td_est) + with ( + pytest.warns( + UserWarning, + match="No target network updater has been associated with this loss module", + ) + if (delay_actor or delay_qvalue) + else contextlib.nullcontext() + ): + with _check_td_steady(td): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ).backward() + named_parameters = list(loss_fn.named_parameters()) + named_buffers = list(loss_fn.named_buffers()) + + assert len({p for n, p in named_parameters}) == len(list(named_parameters)) + assert len({p for n, p in named_buffers}) == len(list(named_buffers)) + + for 
name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize( + "delay_actor, delay_qvalue", [(False, False), (True, True)] + ) + @pytest.mark.parametrize("policy_noise", [0.1]) + @pytest.mark.parametrize("noise_clip", [0.1]) + @pytest.mark.parametrize("alpha", [0.1]) + @pytest.mark.parametrize("use_action_spec", [True, False]) + def test_td3bc_state_dict( + self, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + use_action_spec, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + if use_action_spec: + action_spec = actor.spec + bounds = None + else: + bounds = (-1, 1) + action_spec = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + sd = loss_fn.state_dict() + loss_fn2 = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_actor=delay_actor, + delay_qvalue=delay_qvalue, + ) + loss_fn2.load_state_dict(sd) + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("separate_losses", [False, True]) + def test_td3bc_separate_losses( + self, + device, + separate_losses, + n_act=4, + ): + torch.manual_seed(self.seed) + actor, value, common, td = self._create_mock_common_layer_setup(n_act=n_act) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=BoundedTensorSpec(shape=(n_act,), low=-1, high=1), + loss_function="l2", + separate_losses=separate_losses, + ) + with pytest.warns(UserWarning, match="No target network updater has been"): + loss = loss_fn(td) + + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + # check that losses are independent + for k in loss.keys(): + if not k.startswith("loss"): + continue + loss[k].sum().backward(retain_graph=True) + if k == "loss_actor": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + elif k == "loss_qvalue": + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.actor_network_params.values(True, True) + ) + if separate_losses: + common_layers_no = len(list(common.parameters())) + common_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + ) + assert all( + (p.grad is None) or (p.grad == 0).all() + for p in common_layers + ) + qvalue_layers = itertools.islice( + loss_fn.qvalue_network_params.values(True, True), + common_layers_no, + None, + ) + assert not any( + (p.grad 
is None) or (p.grad == 0).all() + for p in qvalue_layers + ) + else: + assert not any( + (p.grad is None) or (p.grad == 0).all() + for p in loss_fn.qvalue_network_params.values(True, True) + ) + + else: + raise NotImplementedError(k) + loss_fn.zero_grad() + + @pytest.mark.skipif(not _has_functorch, reason="functorch not installed") + @pytest.mark.parametrize("n", range(1, 4)) + @pytest.mark.parametrize("device", get_default_devices()) + @pytest.mark.parametrize("delay_actor,delay_qvalue", [(False, False), (True, True)]) + @pytest.mark.parametrize("policy_noise", [0.1, 1.0]) + @pytest.mark.parametrize("noise_clip", [0.1, 1.0]) + @pytest.mark.parametrize("alpha", [0.1, 6.0]) + def test_td3bc_batcher( + self, + n, + delay_actor, + delay_qvalue, + device, + policy_noise, + noise_clip, + alpha, + gamma=0.9, + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_seq_mock_data_td3bc(device=device) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + policy_noise=policy_noise, + noise_clip=noise_clip, + alpha=alpha, + delay_qvalue=delay_qvalue, + delay_actor=delay_actor, + ) + + ms = MultiStep(gamma=gamma, n_steps=n).to(device) + + td_clone = td.clone() + ms_td = ms(td_clone) + + torch.manual_seed(0) + np.random.seed(0) + + with ( + pytest.warns(UserWarning, match="No target network updater has been") + if (delay_qvalue or delay_actor) + else contextlib.nullcontext() + ), _check_td_steady(ms_td): + loss_ms = loss_fn(ms_td) + assert loss_fn.tensor_keys.priority in ms_td.keys() + + if delay_qvalue or delay_actor: + SoftUpdate(loss_fn, eps=0.5) + + with torch.no_grad(): + torch.manual_seed(0) # log-prob is computed with a random action + np.random.seed(0) + loss = loss_fn(td) + + if n == 1: + assert_allclose_td(td, ms_td.select(*list(td.keys(True, True)))) + _loss = sum( + [item for name, item in loss.items() if name.startswith("loss_")] + ) + _loss_ms = sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ) + assert ( + abs(_loss - _loss_ms) < 1e-3 + ), f"found abs(loss-loss_ms) = {abs(loss - loss_ms):4.5f} for n=0" + else: + with pytest.raises(AssertionError): + assert_allclose_td(loss, loss_ms) + + sum( + [item for name, item in loss_ms.items() if name.startswith("loss_")] + ).backward() + named_parameters = loss_fn.named_parameters() + + for name, p in named_parameters: + if not name.startswith("target_"): + assert ( + p.grad is not None and p.grad.norm() > 0.0 + ), f"parameter {name} (shape: {p.shape}) has a null gradient" + else: + assert ( + p.grad is None or p.grad.norm() == 0.0 + ), f"target parameter {name} (shape: {p.shape}) has a non-null gradient" + + # Check param update effect on targets + target_actor = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + target_actor2 = loss_fn.target_actor_network_params.clone().values( + include_nested=True, leaves_only=True + ) + target_qvalue2 = loss_fn.target_qvalue_network_params.clone().values( + include_nested=True, leaves_only=True + ) + if loss_fn.delay_actor: + assert all((p1 == p2).all() for p1, p2 in zip(target_actor, target_actor2)) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_actor, target_actor2) + ) + if loss_fn.delay_qvalue: + 
assert all( + (p1 == p2).all() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + else: + assert not any( + (p1 == p2).any() for p1, p2 in zip(target_qvalue, target_qvalue2) + ) + + # check that policy is updated after parameter update + actorp_set = set(actor.parameters()) + loss_fnp_set = set(loss_fn.parameters()) + assert len(actorp_set.intersection(loss_fnp_set)) == len(actorp_set) + parameters = [p.clone() for p in actor.parameters()] + for p in loss_fn.parameters(): + if p.requires_grad: + p.data += torch.randn_like(p) + assert all((p1 != p2).all() for p1, p2 in zip(parameters, actor.parameters())) + + @pytest.mark.parametrize( + "td_est", [ValueEstimators.TD1, ValueEstimators.TD0, ValueEstimators.TDLambda] + ) + def test_td3bc_tensordict_keys(self, td_est): + actor = self._create_mock_actor() + value = self._create_mock_value() + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + ) + + default_keys = { + "priority": "td_error", + "state_action_value": "state_action_value", + "action": "action", + "reward": "reward", + "done": "done", + "terminated": "terminated", + } + + self.tensordict_keys_test( + loss_fn, + default_keys=default_keys, + td_est=td_est, + ) + + value = self._create_mock_value(out_keys=["state_action_value_test"]) + loss_fn = TD3BCLoss( + actor, + value, + action_spec=actor.spec, + ) + key_mapping = { + "state_action_value": ("value", "state_action_value_test"), + "reward": ("reward", "reward_test"), + "done": ("done", ("done", "test")), + "terminated": ("terminated", ("terminated", "test")), + } + self.set_advantage_keys_through_loss_test(loss_fn, td_est, key_mapping) + + @pytest.mark.parametrize("spec", [True, False]) + @pytest.mark.parametrize("bounds", [True, False]) + def test_constructor(self, spec, bounds): + actor = self._create_mock_actor() + value = self._create_mock_value() + action_spec = actor.spec if spec else None + bounds = (-1, 1) if bounds else None + if (bounds is not None and action_spec is not None) or ( + bounds is None and action_spec is None + ): + with pytest.raises(ValueError, match="but not both"): + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + ) + return + TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + ) + + # TODO: test for action_key, atm the action key of the TD3+BC loss is not configurable, + # since it is used in it's constructor + @pytest.mark.parametrize("observation_key", ["observation", "observation2"]) + @pytest.mark.parametrize("reward_key", ["reward", "reward2"]) + @pytest.mark.parametrize("done_key", ["done", "done2"]) + @pytest.mark.parametrize("terminated_key", ["terminated", "terminated2"]) + def test_td3bc_notensordict( + self, observation_key, reward_key, done_key, terminated_key + ): + torch.manual_seed(self.seed) + actor = self._create_mock_actor(in_keys=[observation_key]) + qvalue = self._create_mock_value( + observation_key=observation_key, out_keys=["state_action_value"] + ) + td = self._create_mock_data_td3bc( + observation_key=observation_key, + reward_key=reward_key, + done_key=done_key, + terminated_key=terminated_key, + ) + loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) + loss.set_keys(reward=reward_key, done=done_key, terminated=terminated_key) + + kwargs = { + observation_key: td.get(observation_key), + f"next_{reward_key}": td.get(("next", reward_key)), + f"next_{done_key}": td.get(("next", done_key)), + f"next_{terminated_key}": td.get(("next", terminated_key)), + f"next_{observation_key}": td.get(("next", 
observation_key)), + "action": td.get("action"), + } + td = TensorDict(kwargs, td.batch_size).unflatten_keys("_") + + with pytest.warns(UserWarning, match="No target network updater has been"): + torch.manual_seed(0) + loss_val_td = loss(td) + torch.manual_seed(0) + loss_val = loss(**kwargs) + loss_val_reconstruct = TensorDict(dict(zip(loss.out_keys, loss_val)), []) + assert_allclose_td(loss_val_reconstruct, loss_val_td) + + # test select + loss.select_out_keys("loss_actor", "loss_qvalue") + torch.manual_seed(0) + if torch.__version__ >= "2.0.0": + loss_actor, loss_qvalue = loss(**kwargs) + else: + with pytest.raises( + RuntimeError, + match="You are likely using tensordict.nn.dispatch with keyword arguments", + ): + loss_actor, loss_qvalue = loss(**kwargs) + return + + assert loss_actor == loss_val_td["loss_actor"] + assert loss_qvalue == loss_val_td["loss_qvalue"] + + @pytest.mark.parametrize("reduction", [None, "none", "mean", "sum"]) + def test_td3bc_reduction(self, reduction): + torch.manual_seed(self.seed) + device = ( + torch.device("cpu") + if torch.cuda.device_count() == 0 + else torch.device("cuda") + ) + actor = self._create_mock_actor(device=device) + value = self._create_mock_value(device=device) + td = self._create_mock_data_td3bc(device=device) + action_spec = actor.spec + bounds = None + loss_fn = TD3BCLoss( + actor, + value, + action_spec=action_spec, + bounds=bounds, + loss_function="l2", + delay_qvalue=False, + delay_actor=False, + reduction=reduction, + ) + loss_fn.make_value_estimator() + loss = loss_fn(td) + if reduction == "none": + for key in loss.keys(): + if key.startswith("loss"): + assert loss[key].shape == td.shape + else: + for key in loss.keys(): + if not key.startswith("loss"): + continue + assert loss[key].shape == torch.Size([]) + + @pytest.mark.skipif( not _has_functorch, reason=f"functorch not installed: {FUNCTORCH_ERR}" ) @@ -5686,9 +6413,9 @@ def _create_mock_actor( spec=CompositeSpec( { "action": action_spec, - "action_value" - if action_value_key is None - else action_value_key: None, + ( + "action_value" if action_value_key is None else action_value_key + ): None, "chosen_action_value": None, }, shape=[], diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 62d2e4f2585..c53d9ad26fc 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -84,7 +84,7 @@ class TD3BCLoss(LossModule): >>> from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal >>> from torchrl.modules.tensordict_module.actors import Actor, ProbabilisticActor, ValueOperator >>> from torchrl.modules.tensordict_module.common import SafeModule - >>> from torchrl.objectives.td3 import TD3Loss + >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> from tensordict import TensorDict >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) @@ -102,7 +102,7 @@ class TD3BCLoss(LossModule): >>> qvalue = ValueOperator( ... module=module, ... 
in_keys=['observation', 'action']) - >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) >>> batch = [2, ] >>> action = spec.rand(batch) >>> data = TensorDict({ @@ -116,12 +116,14 @@ class TD3BCLoss(LossModule): >>> loss(data) TensorDict( fields={ + bc_loss: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), + lmbd: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), loss_qvalue: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - next_state_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - pred_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - state_action_value_actor: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False), - target_value: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)}, + next_state_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False), + pred_value: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + state_action_value_actor: Tensor(shape=torch.Size([2, 2]), device=cpu, dtype=torch.float32, is_shared=False), + target_value: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False)}, batch_size=torch.Size([]), device=None, is_shared=False) @@ -131,14 +133,14 @@ class TD3BCLoss(LossModule): the expected keyword arguments are: ``["action", "next_reward", "next_done", "next_terminated"]`` + in_keys of the actor and qvalue network The return value is a tuple of tensors in the following order: - ``["loss_actor", "loss_qvalue", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. + ``["loss_actor", "loss_qvalue", "bc_loss", "lmbd", "pred_value", "state_action_value_actor", "next_state_value", "target_value",]``. Examples: >>> import torch >>> from torch import nn >>> from torchrl.data import BoundedTensorSpec >>> from torchrl.modules.tensordict_module.actors import Actor, ValueOperator - >>> from torchrl.objectives.td3 import TD3Loss + >>> from torchrl.objectives.td3_bc import TD3BCLoss >>> n_act, n_obs = 4, 3 >>> spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,)) >>> module = nn.Linear(n_obs, n_act) @@ -155,7 +157,7 @@ class TD3BCLoss(LossModule): >>> qvalue = ValueOperator( ... module=module, ...
in_keys=['observation', 'action']) - >>> loss = TD3Loss(actor, qvalue, action_spec=actor.spec) + >>> loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec) >>> _ = loss.select_out_keys("loss_actor", "loss_qvalue") >>> batch = [2, ] >>> action = spec.rand(batch) @@ -206,6 +208,8 @@ class _AcceptedKeys: out_keys = [ "loss_actor", "loss_qvalue", + "bc_loss", + "lmbd", "pred_value", "state_action_value_actor", "next_state_value", @@ -275,6 +279,7 @@ def __init__( self.loss_function = loss_function self.policy_noise = policy_noise self.noise_clip = noise_clip + self.alpha = alpha if not ((action_spec is not None) ^ (bounds is not None)): raise ValueError( "One of 'bounds' and 'action_spec' must be provided, " From c996c66c5ebcb789ed8a0f1fcc2caf5d76bdd82a Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 10:29:04 +0200 Subject: [PATCH 03/18] update examples, tuts and demos in readme --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index 4f2cd68b0f2..207365f095a 100644 --- a/README.md +++ b/README.md @@ -501,6 +501,7 @@ A series of [examples](https://github.com/pytorch/rl/blob/main/examples/) are pr - [IQL](https://github.com/pytorch/rl/blob/main/sota-implementations/iql/iql_offline.py) - [CQL](https://github.com/pytorch/rl/blob/main/sota-implementations/cql/cql_offline.py) - [TD3](https://github.com/pytorch/rl/blob/main/sota-implementations/td3/td3.py) +- [TD3+BC](https://github.com/pytorch/rl/blob/main/sota-implementations/td3+bc/td3+bc.py) - [A2C](https://github.com/pytorch/rl/blob/main/examples/a2c_old/a2c.py) - [PPO](https://github.com/pytorch/rl/blob/main/sota-implementations/ppo/ppo.py) - [SAC](https://github.com/pytorch/rl/blob/main/sota-implementations/sac/sac.py) From acbdd006aed6d63dc5d0511ce93ca3fb2a3573d2 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 10:30:55 +0200 Subject: [PATCH 04/18] update docs --- docs/source/reference/objectives.rst | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/docs/source/reference/objectives.rst b/docs/source/reference/objectives.rst index c2f43d8e9b6..ef9bc1ee907 100644 --- a/docs/source/reference/objectives.rst +++ b/docs/source/reference/objectives.rst @@ -160,6 +160,15 @@ TD3 TD3Loss +TD3+BC +------ + +.. autosummary:: + :toctree: generated/ + :template: rl_template_noinherit.rst + + TD3BCLoss + PPO --- From 7fa64a7008e1e5108a072f57d8f5268f4c19cdb0 Mon Sep 17 00:00:00 2001 From: BY571 Date: Wed, 26 Jun 2024 10:47:11 +0200 Subject: [PATCH 05/18] update sota checks --- sota-check/run_td3bc.sh | 26 ++++++++++++++++++++++++++ sota-check/submitit-release-check.sh | 1 + 2 files changed, 27 insertions(+) create mode 100644 sota-check/run_td3bc.sh diff --git a/sota-check/run_td3bc.sh b/sota-check/run_td3bc.sh new file mode 100644 index 00000000000..26a22b197b8 --- /dev/null +++ b/sota-check/run_td3bc.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +#SBATCH --job-name=td3bc_offline +#SBATCH --ntasks=32 +#SBATCH --cpus-per-task=1 +#SBATCH --gres=gpu:1 +#SBATCH --output=slurm_logs/td3bc_offline_%j.txt +#SBATCH --error=slurm_errors/td3bc_offline_%j.txt + +current_commit=$(git rev-parse --short HEAD) +project_name="torchrl-example-check-$current_commit" +group_name="td3bc_offline" +export PYTHONPATH=$(dirname $(dirname $PWD)) +python $PYTHONPATH/sota-implementations/td3+bc/td3+bc.py \ + logger.backend=wandb \ + logger.project_name="$project_name" \ + logger.group_name="$group_name" + +# Capture the exit status of the Python command +exit_status=$?
+# Write the exit status to a file +if [ $exit_status -eq 0 ]; then + echo "${group_name}_${SLURM_JOB_ID}=success" >> report.log +else + echo "${group_name}_${SLURM_JOB_ID}=error" >> report.log +fi diff --git a/sota-check/submitit-release-check.sh b/sota-check/submitit-release-check.sh index cad2783c653..515ac06a50b 100755 --- a/sota-check/submitit-release-check.sh +++ b/sota-check/submitit-release-check.sh @@ -65,6 +65,7 @@ scripts=( run_ppo_mujoco.sh run_sac.sh run_td3.sh + run_td3bc.sh run_dt.sh run_dt_online.sh ) From 4fa034433a106acd98c70547e245f02075279eae Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:54:32 +0200 Subject: [PATCH 06/18] rename td3 bc sota-implementation --- sota-implementations/{td3+bc => td3_bc}/config.yaml | 0 sota-implementations/{td3+bc/td3+bc.py => td3_bc/td3_bc.py} | 0 sota-implementations/{td3+bc => td3_bc}/utils.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename sota-implementations/{td3+bc => td3_bc}/config.yaml (100%) rename sota-implementations/{td3+bc/td3+bc.py => td3_bc/td3_bc.py} (100%) rename sota-implementations/{td3+bc => td3_bc}/utils.py (100%) diff --git a/sota-implementations/td3+bc/config.yaml b/sota-implementations/td3_bc/config.yaml similarity index 100% rename from sota-implementations/td3+bc/config.yaml rename to sota-implementations/td3_bc/config.yaml diff --git a/sota-implementations/td3+bc/td3+bc.py b/sota-implementations/td3_bc/td3_bc.py similarity index 100% rename from sota-implementations/td3+bc/td3+bc.py rename to sota-implementations/td3_bc/td3_bc.py diff --git a/sota-implementations/td3+bc/utils.py b/sota-implementations/td3_bc/utils.py similarity index 100% rename from sota-implementations/td3+bc/utils.py rename to sota-implementations/td3_bc/utils.py From 13d700dcd3a0d44967660d1f263660e7e826fcd8 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:54:53 +0200 Subject: [PATCH 07/18] update sota-check --- sota-check/run_td3bc.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-check/run_td3bc.sh b/sota-check/run_td3bc.sh index 26a22b197b8..0fefb3ecd6f 100644 --- a/sota-check/run_td3bc.sh +++ b/sota-check/run_td3bc.sh @@ -11,7 +11,7 @@ current_commit=$(git rev-parse --short HEAD) project_name="torchrl-example-check-$current_commit" group_name="td3bc_offline" export PYTHONPATH=$(dirname $(dirname $PWD)) -python $PYTHONPATH/sota-implementations/td3+bc/td3+bc.py \ +python $PYTHONPATH/sota-implementations/td3_bc/td3_bc.py \ logger.backend=wandb \ logger.project_name="$project_name" \ logger.group_name="$group_name" From 9f27c077adab7a3d98d9dc5c3351f9fe13b62d0c Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:55:54 +0200 Subject: [PATCH 08/18] update test td3_bc naming --- .github/unittest/linux_examples/scripts/run_test.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index 92499b99a1b..fd1b60cda3a 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -51,7 +51,7 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3+bc/td3_bc.py \ +python .github/unittest/helpers/coverage_run_parallel.py 
sota-implementations/td3_bc/td3_bc.py \ optim.gradient_steps=55 \ logger.backend= # ==================================================================================== # From fe0a51fe44d69879410ed2ab8ada2738f08be2ec Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:58:17 +0200 Subject: [PATCH 09/18] update pbar def and loop --- sota-implementations/td3_bc/td3_bc.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index c8c3b24b724..988d6a93651 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -81,16 +81,15 @@ def main(cfg: "DictConfig"): # noqa: F821 # Create optimizer optimizer_actor, optimizer_critic = make_optimizer(cfg.optim, loss_module) - pbar = tqdm.tqdm(total=cfg.optim.gradient_steps) - gradient_steps = cfg.optim.gradient_steps evaluation_interval = cfg.logger.eval_iter eval_steps = cfg.logger.eval_steps delayed_updates = cfg.optim.policy_update_delay update_counter = 0 + pbar = tqdm.tqdm(range(gradient_steps)) # Training loop start_time = time.time() - for i in range(gradient_steps): + for i in pbar: pbar.update(1) # Update actor every delayed_updates update_counter += 1 From d0810f387fac8271418a478417a14c70ac291b0f Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:58:59 +0200 Subject: [PATCH 10/18] fix --- sota-implementations/td3_bc/td3_bc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 988d6a93651..46fb923403e 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -98,7 +98,7 @@ def main(cfg: "DictConfig"): # noqa: F821 # Sample from replay buffer sampled_tensordict = replay_buffer.sample() if sampled_tensordict.device != device: - sampled_tensordict = sampled_tensordict.to(device, non_blocking=True) + sampled_tensordict = sampled_tensordict.to(device) else: sampled_tensordict = sampled_tensordict.clone() From eedd42856a8f2de6862b20e8c8688699156bfc73 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 11:59:55 +0200 Subject: [PATCH 11/18] remove eval env creation in utils --- sota-implementations/td3_bc/utils.py | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index 2b4c1f96146..db6eba24440 100644 --- a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -87,21 +87,6 @@ def make_environment(cfg, logger=None): parallel_env.set_seed(cfg.env.seed) train_env = apply_env_transforms(parallel_env, cfg.env.max_episode_steps) - - # partial = functools.partial(env_maker, cfg=cfg, from_pixels=cfg.logger.video) - # trsf_clone = train_env.transform.clone() - # if cfg.logger.video: - # trsf_clone.insert( - # 0, VideoRecorder(logger, tag="rendering/test", in_keys=["pixels"]) - # ) - # eval_env = TransformedEnv( - # ParallelEnv( - # cfg.logger.eval_envs, - # EnvCreator(partial), - # serial_for_single=True, - # ), - # trsf_clone, - # ) return train_env From 6de6d574e94804bb139c960796f05893548e0078 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 12:01:57 +0200 Subject: [PATCH 12/18] fixes in utils --- sota-implementations/td3_bc/utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sota-implementations/td3_bc/utils.py b/sota-implementations/td3_bc/utils.py index db6eba24440..3772eefccde 100644 --- 
a/sota-implementations/td3_bc/utils.py +++ b/sota-implementations/td3_bc/utils.py @@ -167,7 +167,7 @@ def make_td3_agent(cfg, train_env, device): # init nets with torch.no_grad(), set_exploration_type(ExplorationType.RANDOM): - td = train_env.reset() + td = train_env.fake_tensordict() td = td.to(device) for net in model: net(td) @@ -213,8 +213,8 @@ def make_loss_module(cfg, model): def make_optimizer(cfg, loss_module): - critic_params = list(loss_module.qvalue_network_params.flatten_keys().values()) - actor_params = list(loss_module.actor_network_params.flatten_keys().values()) + critic_params = list(loss_module.qvalue_network_params.values(True, True)) + actor_params = list(loss_module.actor_network_params.values(True, True)) optimizer_actor = optim.Adam( actor_params, From 325104c7b572a83e4f9cac1cdcd733e47b350469 Mon Sep 17 00:00:00 2001 From: BY571 Date: Thu, 27 Jun 2024 12:08:39 +0200 Subject: [PATCH 13/18] fixes objective --- torchrl/objectives/td3_bc.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index c53d9ad26fc..2b53464e174 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -42,17 +42,18 @@ class TD3BCLoss(LossModule): Keyword Args: bounds (tuple of float, optional): the bounds of the action space. - Exclusive with action_spec. Either this or ``action_spec`` must + Exclusive with ``action_spec``. Either this or ``action_spec`` must be provided. action_spec (TensorSpec, optional): the action spec. - Exclusive with bounds. Either this or ``bounds`` must be provided. + Exclusive with ``bounds``. Either this or ``bounds`` must be provided. num_qvalue_nets (int, optional): Number of Q-value networks to be - trained. Default is ``10``. + trained. Default is ``2``. policy_noise (float, optional): Standard deviation for the target policy action noise. Default is ``0.2``. noise_clip (float, optional): Clipping range value for the sampled target policy action noise. Default is ``0.5``. alpha (float, optional): Weight for the behavioral cloning loss. + Defaults to ``2.5``. priority_key (str, optional): Key where to write the priority value for prioritized replay buffers. Default is `"td_error"`. 
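
For reference, the constructor contract documented by the commit above (exactly one of action_spec/bounds, two Q-value networks by default, alpha defaulting to 2.5) can be exercised with a minimal sketch along the lines of the TD3BCLoss docstring examples earlier in this patch. This is an illustration only, not part of the patch; the QValueNet helper is a stand-in mirroring the docstring's toy value network, and the shapes are arbitrary.

# Minimal usage sketch (illustration only, not part of the patch), mirroring the
# toy modules from the TD3BCLoss docstring examples above.
import torch
from torch import nn
from tensordict import TensorDict
from torchrl.data import BoundedTensorSpec
from torchrl.modules.tensordict_module.actors import Actor, ValueOperator
from torchrl.objectives.td3_bc import TD3BCLoss

n_act, n_obs, batch = 4, 3, [2]
spec = BoundedTensorSpec(-torch.ones(n_act), torch.ones(n_act), (n_act,))
actor = Actor(spec=spec, module=nn.Linear(n_obs, n_act), in_keys=["observation"])


class QValueNet(nn.Module):
    # stand-in critic: Q(s, a) computed from concatenated observation and action
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(n_obs + n_act, 1)

    def forward(self, obs, act):
        return self.linear(torch.cat([obs, act], -1))


qvalue = ValueOperator(module=QValueNet(), in_keys=["observation", "action"])

# exactly one of `action_spec` / `bounds` may be passed; alpha=2.5 matches the default
loss = TD3BCLoss(actor, qvalue, action_spec=actor.spec, alpha=2.5)
# equivalent alternative: TD3BCLoss(actor, qvalue, bounds=(-1.0, 1.0), alpha=2.5)
# in training code a target updater (e.g. SoftUpdate(loss, eps=0.995)) would be
# attached; calling the loss without one only raises a warning

data = TensorDict(
    {
        "observation": torch.randn(*batch, n_obs),
        "action": spec.rand(batch),
        ("next", "observation"): torch.randn(*batch, n_obs),
        ("next", "reward"): torch.randn(*batch, 1),
        ("next", "done"): torch.zeros(*batch, 1, dtype=torch.bool),
        ("next", "terminated"): torch.zeros(*batch, 1, dtype=torch.bool),
    },
    batch,
)
out = loss(data)  # TensorDict with loss_actor, loss_qvalue, bc_loss, lmbd, ...
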
From 9a3c393d579e94564cb010eea4e5ac99321c0228 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Fri, 5 Jul 2024 17:51:11 +0100 Subject: [PATCH 14/18] amend --- .github/unittest/linux_examples/scripts/run_test.sh | 6 +++--- sota-implementations/td3_bc/config.yaml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/unittest/linux_examples/scripts/run_test.sh b/.github/unittest/linux_examples/scripts/run_test.sh index fd1b60cda3a..a984b37faa9 100755 --- a/.github/unittest/linux_examples/scripts/run_test.sh +++ b/.github/unittest/linux_examples/scripts/run_test.sh @@ -51,12 +51,12 @@ python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/iq python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/cql/cql_offline.py \ optim.gradient_steps=55 \ logger.backend= -python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \ - optim.gradient_steps=55 \ - logger.backend= # ==================================================================================== # # ================================ Gymnasium ========================================= # +python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/td3_bc/td3_bc.py \ + optim.gradient_steps=55 \ + logger.backend= python .github/unittest/helpers/coverage_run_parallel.py sota-implementations/impala/impala_single_node.py \ collector.total_frames=80 \ collector.frames_per_batch=20 \ diff --git a/sota-implementations/td3_bc/config.yaml b/sota-implementations/td3_bc/config.yaml index bf93b20d082..54275a94bc2 100644 --- a/sota-implementations/td3_bc/config.yaml +++ b/sota-implementations/td3_bc/config.yaml @@ -2,7 +2,7 @@ env: name: HalfCheetah-v4 # Use v4 to get rid of mujoco-py dependency task: "" - library: gym + library: gymnasium seed: 42 max_episode_steps: 1000 From 5476e36e2bd2dd4949a55f10ebd0606c6b5f3e28 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 15:17:14 +0200 Subject: [PATCH 15/18] remove gamma from cost constructor --- torchrl/objectives/td3_bc.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index 2b53464e174..c136fca5f3f 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -19,7 +19,6 @@ from torchrl.objectives.utils import ( _cache_values, - _GAMMA_LMBDA_DEPREC_ERROR, _reduce, _vmap_func, default_value_kwargs, @@ -238,7 +237,6 @@ def __init__( loss_function: str = "smooth_l1", delay_actor: bool = True, delay_qvalue: bool = True, - gamma: float = None, priority_key: str = None, separate_losses: bool = False, reduction: str = None, @@ -319,8 +317,6 @@ def __init__( high = high.to(device) self.register_buffer("max_action", high) self.register_buffer("min_action", low) - if gamma is not None: - raise TypeError(_GAMMA_LMBDA_DEPREC_ERROR) self._vmap_qvalue_network00 = _vmap_func( self.qvalue_network, randomness=self.vmap_randomness ) From f93e007f92b391286789f33b7f8622a56c70972b Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 15:27:09 +0200 Subject: [PATCH 16/18] Update test/test_cost.py Co-authored-by: Vincent Moens --- test/test_cost.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/test_cost.py b/test/test_cost.py index 24d22ec4902..2f187c8e3ba 100644 --- a/test/test_cost.py +++ b/test/test_cost.py @@ -2715,9 +2715,6 @@ def test_td3_reduction(self, reduction): assert loss[key].shape == torch.Size([]) -@pytest.mark.skipif( - not _has_functorch, reason=f"functorch not installed: 
{FUNCTORCH_ERR}" -) class TestTD3BC(LossModuleTestBase): seed = 0 From 235ca515e83d1d56d467c7d401d151d89c9dd7f3 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 18:22:32 +0200 Subject: [PATCH 17/18] update docstrings --- torchrl/objectives/td3_bc.py | 45 ++++++++++++++++++++++++++++++++---- 1 file changed, 40 insertions(+), 5 deletions(-) diff --git a/torchrl/objectives/td3_bc.py b/torchrl/objectives/td3_bc.py index c136fca5f3f..93845bb00bd 100644 --- a/torchrl/objectives/td3_bc.py +++ b/torchrl/objectives/td3_bc.py @@ -32,7 +32,14 @@ class TD3BCLoss(LossModule): r"""TD3+BC Loss Module. Implementation of the TD3+BC loss presented in the paper `"A Minimalist Approach to - Offline Reinforcement Learning" ` + Offline Reinforcement Learning" `. + + This class incorporates two loss functions, executed sequentially within the `forward` method: + + 1. :meth:`~.qvalue_loss` + 2. :meth:`~.actor_loss` + + Users also have the option to call these functions directly in the same order if preferred. Args: actor_network (TensorDictModule): the actor to be trained @@ -370,6 +377,17 @@ def _cached_stack_actor_params(self): ) def actor_loss(self, tensordict): + """Compute the actor loss. + + The actor loss should be computed after the :meth:`~.qvalue_loss` and is usually delayed 1-3 critic updates. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the actor loss along with a metadata dictionary containing the detached `"bc_loss"` + used in the combined actor loss as well as the detached `"state_action_value_actor"` used to calculate the lambda + value, and the lambda value `"lmbd"` itself. + """ tensordict_actor_grad = tensordict.select( *self.actor_network.in_keys, strict=False ) @@ -398,14 +416,24 @@ def actor_loss(self, tensordict): loss_actor = -lmbd * state_action_value_actor[0] + bc_loss metadata = { - "state_action_value_actor": state_action_value_actor.detach(), - "bc_loss": bc_loss, + "state_action_value_actor": state_action_value_actor[0].detach(), + "bc_loss": bc_loss.detach(), "lmbd": lmbd, } loss_actor = _reduce(loss_actor, reduction=self.reduction) return loss_actor, metadata - def value_loss(self, tensordict): + def qvalue_loss(self, tensordict): + """Compute the q-value loss. + + The q-value loss should be computed before the :meth:`~.actor_loss`. + + Args: + tensordict (TensorDictBase): the input data for the loss. Check the class's `in_keys` to see what fields + are required for this to be computed. + Returns: a differentiable tensor with the qvalue loss along with a metadata dictionary containing + the detached `"td_error"` to be used for prioritized sampling, the detached `"next_state_value"`, the detached `"pred_value"`, and the detached `"target_value"`. + """ tensordict = tensordict.clone(False) act = tensordict.get(self.tensor_keys.action) @@ -484,9 +512,16 @@ def value_loss(self, tensordict): @dispatch def forward(self, tensordict: TensorDictBase) -> TensorDictBase: + """The forward method. + + Computes successively the :meth:`~.actor_loss`, :meth:`~.qvalue_loss`, and returns + a tensordict with these values. + To see what keys are expected in the input tensordict and what keys are expected as output, check the + class's `"in_keys"` and `"out_keys"` attributes. 
+ """ tensordict_save = tensordict loss_actor, metadata_actor = self.actor_loss(tensordict) - loss_qval, metadata_value = self.value_loss(tensordict_save) + loss_qval, metadata_value = self.qvalue_loss(tensordict_save) tensordict_save.set( self.tensor_keys.priority, metadata_value.pop("td_error").detach().max(0)[0] ) From acc9baf52e8eced45cefd35c75fccc4ca6d76cf7 Mon Sep 17 00:00:00 2001 From: BY571 Date: Tue, 9 Jul 2024 18:26:59 +0200 Subject: [PATCH 18/18] update example script --- sota-implementations/td3_bc/td3_bc.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sota-implementations/td3_bc/td3_bc.py b/sota-implementations/td3_bc/td3_bc.py index 46fb923403e..7c43fdc1a12 100644 --- a/sota-implementations/td3_bc/td3_bc.py +++ b/sota-implementations/td3_bc/td3_bc.py @@ -103,7 +103,7 @@ def main(cfg: "DictConfig"): # noqa: F821 sampled_tensordict = sampled_tensordict.clone() # Compute loss - q_loss, *_ = loss_module.value_loss(sampled_tensordict) + q_loss, *_ = loss_module.qvalue_loss(sampled_tensordict) # Update critic optimizer_critic.zero_grad()