diff --git a/README.md b/README.md index 3e4c607e..815956d6 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ See documentation for the full list of included features. - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) +- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/common/utils.rst b/docs/common/utils.rst new file mode 100644 index 00000000..79fd9bf2 --- /dev/null +++ b/docs/common/utils.rst @@ -0,0 +1,7 @@ +.. _utils: + +Utils +===== + +.. automodule:: sb3_contrib.common.utils + :members: diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 879a84e1..81770b06 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -9,7 +9,8 @@ along with some useful characteristics: support for discrete/continuous actions, Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ TQC ✔️ ❌ ❌ ❌ ✔️ -QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ +TRPO ✔️ ✔️ ✔️ ✔️ ✔️ +QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 7d707888..3af55613 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -44,3 +44,16 @@ Train a PPO with invalid action masking agent on a toy environment. model = MaskablePPO("MlpPolicy", env, verbose=1) model.learn(5000) model.save("qrdqn_cartpole") + + TRPO + ---- + + Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment. + + .. code-block:: python + + from sb3_contrib import TRPO + + model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) + model.learn(total_timesteps=100_000, log_interval=4) + model.save("trpo_pendulum") diff --git a/docs/index.rst b/docs/index.rst index 8e37c71b..ac610bd0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/tqc + modules/trpo modules/qrdqn modules/ppo_mask @@ -39,6 +40,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :maxdepth: 1 :caption: Common + common/utils common/wrappers .. toctree:: diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 0723975f..1b117e6b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,8 +4,9 @@ Changelog ========== -Release 1.3.1a6 (WIP) +Release 1.3.1a7 (WIP) ------------------------------- +**Add TRPO** Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -15,6 +16,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- Added ``TRPO`` (@cyprienc) - Added experimental support to train off-policy algorithms with multiple envs (note: ``HerReplayBuffer`` currently not supported) Bug Fixes: @@ -34,7 +36,7 @@ Documentation: Release 1.3.0 (2021-10-23) ------------------------------- -**Invalid action masking for PPO** +**Add Invalid action masking for PPO** .. 
warning:: @@ -52,6 +54,7 @@ New Features: - Added ``MaskablePPO`` algorithm (@kronion) - ``MaskablePPO`` Dictionary Observation support (@glmcdona) + Bug Fixes: ^^^^^^^^^^ @@ -75,9 +78,6 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.2.0 -New Features: -^^^^^^^^^^^^^ - Bug Fixes: ^^^^^^^^^^ - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright) @@ -221,4 +221,4 @@ Stable-Baselines3 is currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona +@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc diff --git a/docs/modules/trpo.rst b/docs/modules/trpo.rst new file mode 100644 index 00000000..4ff88281 --- /dev/null +++ b/docs/modules/trpo.rst @@ -0,0 +1,151 @@ +.. _trpo: + +.. automodule:: sb3_contrib.trpo + +TRPO +==== + +`Trust Region Policy Optimization (TRPO) <https://arxiv.org/abs/1502.05477>`_ +is an iterative approach for optimizing policies with guaranteed monotonic improvement. + +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + MultiInputPolicy + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1502.05477 +- OpenAI Spinning Up guide: https://spinningup.openai.com/en/latest/algorithms/trpo.html + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== Space Action Observation ============= ====== =========== Discrete ✔️ ✔️ Box ✔️ ✔️ MultiDiscrete ✔️ ✔️ MultiBinary ✔️ ✔️ Dict ❌ ✔️ ============= ====== =========== + + Example ------- + +.. code-block:: python + + import gym + import numpy as np + + from sb3_contrib import TRPO + + env = gym.make("Pendulum-v0") + + model = TRPO("MlpPolicy", env, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("trpo_pendulum") + + del model # remove to demonstrate saving and loading + + model = TRPO.load("trpo_pendulum") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Results on the MuJoCo benchmark (1M steps on ``-v3`` envs with MuJoCo v2.1.0) using 3 seeds. +The complete learning curves are available in the `associated PR `_. + + +===================== ============ Environments TRPO ===================== ============ HalfCheetah 1803 +/- 46 +Ant 3554 +/- 591 +Hopper 3372 +/- 215 +Walker2d 4502 +/- 234 +Swimmer 359 +/- 2 ===================== ============ + + How to replicate the results? ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone RL-Zoo and check out the branch ``feat/trpo``: + +.. code-block:: bash + + git clone https://github.com/cyprienc/rl-baselines3-zoo + cd rl-baselines3-zoo/ + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo trpo --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results + python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO + + +Parameters +---------- + +.. autoclass:: TRPO + :members: + :inherited-members: + +.. _trpo_policies: + +TRPO Policies +------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. 
autoclass:: stable_baselines3.common.policies.ActorCriticPolicy + :members: + :noindex: + +.. autoclass:: CnnPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy + :members: + :noindex: + +.. autoclass:: MultiInputPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy + :members: + :noindex: diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index c90336af..790eaaa7 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -3,6 +3,7 @@ from sb3_contrib.ppo_mask import MaskablePPO from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC +from sb3_contrib.trpo import TRPO # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 4a9e522d..daa6863b 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Callable, Optional, Sequence import torch as th +from torch import nn def quantile_huber_loss( @@ -67,3 +68,96 @@ def quantile_huber_loss( else: loss = loss.mean() return loss + + +def conjugate_gradient_solver( + matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor], + b, + max_iter=10, + residual_tol=1e-10, +) -> th.Tensor: + """ + Finds an approximate solution to a set of linear equations Ax = b + + Sources: + - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py + - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122 + + Reference: + - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6 + + :param matrix_vector_dot_fn: + a function that right multiplies a matrix A by a vector v + :param b: + the right hand term in the set of linear equations Ax = b + :param max_iter: + the maximum number of iterations (default is 10) + :param residual_tol: + residual tolerance for early stopping of the solving (default is 1e-10) + :return x: + the approximate solution to the system of equations defined by `matrix_vector_dot_fn` + and b + """ + + # The vector is not initialized at 0 because of the instability issues when the gradient becomes small. + # A small random gaussian noise is used for the initialization. + x = 1e-4 * th.randn_like(b) + residual = b - matrix_vector_dot_fn(x) + # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared) + residual_squared_norm = th.matmul(residual, residual) + + if residual_squared_norm < residual_tol: + # If the gradient becomes extremely small + # The denominator in alpha will become zero + # Leading to a division by zero + return x + + p = residual.clone() + + for i in range(max_iter): + # A @ p (matrix vector multiplication) + A_dot_p = matrix_vector_dot_fn(p) + + alpha = residual_squared_norm / p.dot(A_dot_p) + x += alpha * p + + if i == max_iter - 1: + return x + + residual -= alpha * A_dot_p + new_residual_squared_norm = th.matmul(residual, residual) + + if new_residual_squared_norm < residual_tol: + return x + + beta = new_residual_squared_norm / residual_squared_norm + residual_squared_norm = new_residual_squared_norm + p = residual + beta * p + + +def flat_grad( + output, + parameters: Sequence[nn.parameter.Parameter], + create_graph: bool = False, + retain_graph: bool = False, +) -> th.Tensor: + """ + Returns the gradients of the passed sequence of parameters into a flat gradient. + Order of parameters is preserved. 
+ + :param output: functional output to compute the gradient for + :param parameters: sequence of ``Parameter`` + :param retain_graph: – If ``False``, the graph used to compute the grad will be freed. + Defaults to the value of ``create_graph``. + :param create_graph: – If ``True``, graph of the derivative will be constructed, + allowing to compute higher order derivative products. Default: ``False``. + :return: Tensor containing the flattened gradients + """ + grads = th.autograd.grad( + output, + parameters, + create_graph=create_graph, + retain_graph=retain_graph, + allow_unused=True, + ) + return th.cat([th.ravel(grad) for grad in grads if grad is not None]) diff --git a/sb3_contrib/trpo/__init__.py b/sb3_contrib/trpo/__init__.py new file mode 100644 index 00000000..7465a9d9 --- /dev/null +++ b/sb3_contrib/trpo/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_contrib.trpo.trpo import TRPO diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py new file mode 100644 index 00000000..27cde537 --- /dev/null +++ b/sb3_contrib/trpo/policies.py @@ -0,0 +1,16 @@ +# This file is here just to define MlpPolicy/CnnPolicy +# that work for TRPO +from stable_baselines3.common.policies import ( + ActorCriticCnnPolicy, + ActorCriticPolicy, + MultiInputActorCriticPolicy, + register_policy, +) + +MlpPolicy = ActorCriticPolicy +CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy + +register_policy("MlpPolicy", ActorCriticPolicy) +register_policy("CnnPolicy", ActorCriticCnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py new file mode 100644 index 00000000..9c461031 --- /dev/null +++ b/sb3_contrib/trpo/trpo.py @@ -0,0 +1,421 @@ +import copy +import warnings +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import numpy as np +import torch as th +from gym import spaces +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule +from stable_baselines3.common.utils import explained_variance +from torch import nn +from torch.distributions import kl_divergence +from torch.nn import functional as F + +from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad + + +class TRPO(OnPolicyAlgorithm): + """ + Trust Region Policy Optimization (TRPO) + + Paper: https://arxiv.org/abs/1502.05477 + Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) + and Stable Baselines (TRPO from https://github.com/hill-a/stable-baselines) + + Introduction to TRPO: https://spinningup.openai.com/en/latest/algorithms/trpo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate for the value function, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. 
rollout buffer size is n_steps * n_envs where n_envs is the number of environment copies running in parallel) + NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) + See https://github.com/pytorch/pytorch/issues/29372 + :param batch_size: Minibatch size for the value function + :param gamma: Discount factor + :param cg_max_steps: maximum number of steps in the Conjugate Gradient algorithm + for computing the Hessian vector product + :param cg_damping: damping in the Hessian vector product computation + :param line_search_shrinking_factor: step-size reduction factor for the line-search + (i.e., ``theta_new = theta + alpha^i * step``) + :param line_search_max_iter: maximum number of iterations + for the backtracking line-search + :param n_critic_updates: number of critic updates per policy update + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param normalize_advantage: Whether or not to normalize the advantage + :param target_kl: Target Kullback-Leibler divergence between updates. + Should be small for stability. Values like 0.01, 0.05. + :param sub_sampling_factor: Sub-sample the batch to make computation faster + see p. 40-42 of John Schulman's thesis http://joschu.net/docs/thesis.pdf + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. 
+ :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 1e-3, + n_steps: int = 2048, + batch_size: int = 128, + gamma: float = 0.99, + cg_max_steps: int = 15, + cg_damping: float = 0.1, + line_search_shrinking_factor: float = 0.8, + line_search_max_iter: int = 10, + n_critic_updates: int = 10, + gae_lambda: float = 0.95, + use_sde: bool = False, + sde_sample_freq: int = -1, + normalize_advantage: bool = True, + target_kl: float = 0.01, + sub_sampling_factor: int = 1, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(TRPO, self).__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=0.0, # entropy bonus is not used by TRPO + vf_coef=0.0, # value function is optimized separately + max_grad_norm=0.0, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + policy_base=ActorCriticPolicy, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + self.normalize_advantage = normalize_advantage + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + if normalize_advantage: + assert buffer_size > 1, ( + "`n_steps * n_envs` must be greater than 1. " + f"Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + ) + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + self.batch_size = batch_size + # Conjugate gradients parameters + self.cg_max_steps = cg_max_steps + self.cg_damping = cg_damping + # Backtracking line search parameters + self.line_search_shrinking_factor = line_search_shrinking_factor + self.line_search_max_iter = line_search_max_iter + self.target_kl = target_kl + self.n_critic_updates = n_critic_updates + self.sub_sampling_factor = sub_sampling_factor + + if _init_setup_model: + self._setup_model() + + def _compute_actor_grad( + self, kl_div: th.Tensor, policy_objective: th.Tensor + ) -> Tuple[List[nn.Parameter], th.Tensor, th.Tensor, List[Tuple[int, ...]]]: + """ + Compute actor gradients for kl div and surrogate objectives. 
+ + :param kl_div: The KL divergence objective + :param policy_objective: The surrogate objective ("classic" policy gradient) + :return: List of actor params, the policy objective gradients, the KL divergence gradients and the gradient shapes. + """ + # This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence + # The policy objective is also called surrogate objective + policy_objective_gradients = [] + # Contains the gradients of the KL divergence + grad_kl = [] + # Contains the shape of the gradients of the KL divergence w.r.t. each parameter + # This way the flattened gradient can be reshaped back into the original shapes and applied to + # the parameters + grad_shape = [] + # Contains the parameters which have non-zero KL divergence gradients + # The list is used during the line-search to apply the step to each parameter + actor_params = [] + + for name, param in self.policy.named_parameters(): + # Skip parameters related to value function based on name + # this works for built-in policies only (not custom ones) + if "value" in name: + continue + + # For each parameter we compute the gradient of the KL divergence w.r.t. that parameter + kl_param_grad, *_ = th.autograd.grad( + kl_div, + param, + create_graph=True, + retain_graph=True, + allow_unused=True, + only_inputs=True, + ) + # If the gradient is not zero (not None), we store the parameter in the actor_params list + # and add the gradient and its shape to grad_kl and grad_shape respectively + if kl_param_grad is not None: + # If the parameter impacts the KL divergence (i.e. the policy) + # we compute the gradient of the policy objective w.r.t. the parameter + # this avoids computing the gradient if it's not going to be used in the conjugate gradient step + policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True) + + grad_shape.append(kl_param_grad.shape) + grad_kl.append(kl_param_grad.view(-1)) + policy_objective_gradients.append(policy_objective_grad.view(-1)) + actor_params.append(param) + + # Gradients are concatenated before the conjugate gradient step + policy_objective_gradients = th.cat(policy_objective_gradients) + grad_kl = th.cat(grad_kl) + return actor_params, policy_objective_gradients, grad_kl, grad_shape + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. 
+ """ + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + + policy_objective_values = [] + kl_divergences = [] + line_search_results = [] + value_losses = [] + + # This will only loop once (get all data in one go) + for rollout_data in self.rollout_buffer.get(batch_size=None): + + # Optional: sub-sample data for faster computation + if self.sub_sampling_factor > 1: + rollout_data = RolloutBufferSamples( + rollout_data.observations[:: self.sub_sampling_factor], + rollout_data.actions[:: self.sub_sampling_factor], + None, # old values, not used here + rollout_data.old_log_prob[:: self.sub_sampling_factor], + rollout_data.advantages[:: self.sub_sampling_factor], + None, # returns, not used here + ) + + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + # batch_size is only used for the value function + self.policy.reset_noise(actions.shape[0]) + + with th.no_grad(): + # Note: is copy enough, no need for deepcopy? + # If using gSDE and deepcopy, we need to use `old_distribution.distribution` + # directly to avoid PyTorch errors. + old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) + + distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) + + advantages = rollout_data.advantages + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # surrogate policy objective + policy_objective = (advantages * ratio).mean() + + # KL divergence + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() + + # Surrogate & KL gradient + self.policy.optimizer.zero_grad() + + actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad(kl_div, policy_objective) + + # Hessian-vector dot product function used in the conjugate gradient step + hessian_vector_product_fn = partial(self.hessian_vector_product, actor_params, grad_kl) + + # Computing search direction + search_direction = conjugate_gradient_solver( + hessian_vector_product_fn, + policy_objective_gradients, + max_iter=self.cg_max_steps, + ) + + # Maximal step length + line_search_max_step_size = 2 * self.target_kl + line_search_max_step_size /= th.matmul( + search_direction, hessian_vector_product_fn(search_direction, retain_graph=False) + ) + line_search_max_step_size = th.sqrt(line_search_max_step_size) + + line_search_backtrack_coeff = 1.0 + original_actor_params = [param.detach().clone() for param in actor_params] + + is_line_search_success = False + with th.no_grad(): + # Line-search (backtracking) + for _ in range(self.line_search_max_iter): + + start_idx = 0 + # Applying the scaled step direction + for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): + n_params = param.numel() + param.data = ( + original_param.data + + line_search_backtrack_coeff + * line_search_max_step_size + * search_direction[start_idx : (start_idx + n_params)].view(shape) + ) + start_idx += n_params + + # Recomputing the policy log-probabilities + 
distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) + + # New policy objective + ratio = th.exp(log_prob - rollout_data.old_log_prob) + new_policy_objective = (advantages * ratio).mean() + + # New KL-divergence + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() + + # Constraint criteria: + # we need to improve the surrogate policy objective + # while being close enough (in term of kl div) to the old policy + if (kl_div < self.target_kl) and (new_policy_objective > policy_objective): + is_line_search_success = True + break + + # Reducing step size if line-search wasn't successful + line_search_backtrack_coeff *= self.line_search_shrinking_factor + + line_search_results.append(is_line_search_success) + + if not is_line_search_success: + # If the line-search wasn't successful we revert to the original parameters + for param, original_param in zip(actor_params, original_actor_params): + param.data = original_param.data.clone() + + policy_objective_values.append(policy_objective.item()) + kl_divergences.append(0) + else: + policy_objective_values.append(new_policy_objective.item()) + kl_divergences.append(kl_div.item()) + + # Critic update + for _ in range(self.n_critic_updates): + for rollout_data in self.rollout_buffer.get(self.batch_size): + values_pred = self.policy.predict_values(rollout_data.observations) + value_loss = F.mse_loss(rollout_data.returns, values_pred.flatten()) + value_losses.append(value_loss.item()) + + self.policy.optimizer.zero_grad() + value_loss.backward() + # Removing gradients of parameters shared with the actor + # otherwise it defeats the purposes of the KL constraint + for param in actor_params: + param.grad = None + self.policy.optimizer.step() + + self._n_updates += 1 + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) + + # Logs + self.logger.record("train/policy_objective", np.mean(policy_objective_values)) + self.logger.record("train/value_loss", np.mean(value_losses)) + self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) + self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/is_line_search_success", np.mean(line_search_results)) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + + def hessian_vector_product( + self, params: List[nn.Parameter], grad_kl: th.Tensor, vector: th.Tensor, retain_graph: bool = True + ) -> th.Tensor: + """ + Computes the matrix-vector product with the Fisher information matrix. 
+ + :param params: list of parameters used to compute the Hessian + :param grad_kl: flattened gradient of the KL divergence between the old and new policy + :param vector: vector with which to compute the Hessian-vector dot product + :param retain_graph: if True, the graph will be kept after computing the Hessian + :return: Hessian-vector dot product (with damping) + """ + jacobian_vector_product = (grad_kl * vector).sum() + return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "TRPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OnPolicyAlgorithm: + + return super(TRPO, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e6eaed88..f6258073 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.1a6 +1.3.1a7 diff --git a/setup.cfg b/setup.cfg index cf162f7b..dc9cf363 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ per-file-ignores = ./sb3_contrib/ppo_mask/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 + ./sb3_contrib/trpo/__init__.py:F401 ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 ./sb3_contrib/common/envs/__init__.py:F401 diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 6c277856..e570aab6 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,10 +8,10 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 86f5b5e0..38757ab3 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -6,7 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO class DummyDictEnv(gym.Env): @@ -78,7 +78,7 @@ def render(self, mode="human"): pass -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) def test_consistency(model_class): """ Make sure that dict obs with vector only vs using flatten obs is equivalent. 
@@ -94,7 +94,7 @@ def test_consistency(model_class): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, ) @@ -124,7 +124,7 @@ def test_consistency(model_class): assert np.allclose(action_1, action_2) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_spaces(model_class, channel_last): """ @@ -138,11 +138,11 @@ def test_dict_spaces(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -169,7 +169,7 @@ def test_dict_spaces(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_vec_framestack(model_class, channel_last): """ @@ -187,11 +187,11 @@ def test_dict_vec_framestack(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -218,7 +218,7 @@ def test_dict_vec_framestack(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) def test_vec_normalize(model_class): """ Additional tests to check observation space support @@ -230,11 +230,11 @@ def test_vec_normalize(model_class): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], ), ) else: diff --git a/tests/test_run.py b/tests/test_run.py index b53641cc..ab2d311e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -2,7 +2,7 @@ import pytest from stable_baselines3.common.env_util import make_vec_env -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @@ -60,6 +60,28 @@ def test_qrdqn(): model.learn(total_timesteps=500, eval_freq=250) +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) +def test_trpo(env_id): + model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) + model.learn(total_timesteps=500) + + +def test_trpo_params(): + # Test with gSDE and subsampling + model = TRPO( + "MlpPolicy", + "Pendulum-v0", + n_steps=64, + batch_size=32, + use_sde=True, + sub_sampling_factor=4, + seed=0, + policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]), + verbose=1, + ) + model.learn(total_timesteps=500) + + @pytest.mark.parametrize("model_class", [TQC, QRDQN]) def test_offpolicy_multi_env(model_class): if model_class in [TQC]: diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d2ee3a21..10202e76 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,9 +12,9 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO -MODEL_LIST = [TQC, QRDQN] 
+MODEL_LIST = [TQC, QRDQN, TRPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: @@ -277,6 +277,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str): learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) + else: + kwargs = dict( + n_steps=128, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) # Reduce number of quantiles for faster tests diff --git a/tests/test_utils.py b/tests/test_utils.py index 97b740f4..c434dd23 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,10 @@ import numpy as np import pytest import torch as th +from stable_baselines3.common.utils import set_random_seed -from sb3_contrib.common.utils import quantile_huber_loss +from sb3_contrib import TRPO +from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad, quantile_huber_loss def test_quantile_huber_loss(): @@ -17,3 +19,46 @@ def test_quantile_huber_loss(): quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4)) with pytest.raises(ValueError): quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4)) + + +def test_cg(): + # Test that conjugate gradient can actually solve + # Ax = b when the exact solution x is known + set_random_seed(4) + A = th.ones(3, 3) + # Symmetric matrix + A[0, 1] = 2 + A[1, 0] = 2 + x = th.ones(3) + th.rand(3) + b = A @ x + + def matrix_vector_dot_func(vector): + return A @ vector + + x_approx = conjugate_gradient_solver(matrix_vector_dot_func, b, max_iter=5, residual_tol=1e-10) + assert th.allclose(x_approx, x) + + +def test_flat_grad(): + n_parameters = 12 # 3 * (2 * 2) + x = th.nn.Parameter(th.ones(2, 2, requires_grad=True)) + y = (x ** 2).sum() + flat_grad_out = flat_grad(y, [x, x, x]) + assert len(flat_grad_out.shape) == 1 + # dy/dx = 2 + assert th.allclose(flat_grad_out, th.ones(n_parameters) * 2) + + +def test_trpo_warnings(): + """Test that TRPO warns and errors correctly on + problematic rollout buffer sizes""" + + # Only 1 step: advantage normalization will return NaN + with pytest.raises(AssertionError): + TRPO("MlpPolicy", "Pendulum-v0", n_steps=1) + # One step but no advantage normalization: ok + TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1) + + # Truncated mini-batch + with pytest.warns(UserWarning): + TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8)
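The patch adds two standalone utilities, ``conjugate_gradient_solver`` and ``flat_grad``, and wires them together inside ``TRPO.train()`` through ``hessian_vector_product``. The following is a minimal, self-contained sketch (not part of the patch) of how these pieces compose into the Fisher-vector-product solve: the toy linear Gaussian policy, the fixed standard deviation, the ``cg_damping`` value, the ``fisher_vector_product`` helper name and the random stand-in for the policy-objective gradient are all assumptions made for the illustration; only the two imported utilities and the double-backprop pattern come from this PR.

.. code-block:: python

    import torch as th
    from torch import nn
    from torch.distributions import Normal, kl_divergence

    from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad

    # Toy Gaussian "policy" (assumption for the demo): mean from a linear layer, fixed std
    policy = nn.Linear(4, 2)
    obs = th.randn(8, 4)
    params = list(policy.parameters())

    dist = Normal(policy(obs), 1.0)
    with th.no_grad():
        # Frozen copy of the current distribution, playing the role of the "old" policy
        old_dist = Normal(policy(obs), 1.0)

    # Mean KL(new || old): its value is zero here, but its Hessian is the Fisher matrix
    kl_div = kl_divergence(dist, old_dist).mean()

    # Flattened KL gradient, keeping the graph so it can be differentiated a second time
    grad_kl = flat_grad(kl_div, params, create_graph=True, retain_graph=True)

    cg_damping = 0.1  # assumption: mirrors TRPO's default ``cg_damping``

    def fisher_vector_product(vector: th.Tensor) -> th.Tensor:
        # Hessian-vector product by double backprop, as in ``TRPO.hessian_vector_product``
        jacobian_vector_product = (grad_kl * vector).sum()
        return flat_grad(jacobian_vector_product, params, retain_graph=True) + cg_damping * vector

    # Stand-in for the flattened policy-objective gradient computed in ``train()``
    policy_objective_gradients = th.randn_like(grad_kl)

    # Approximate natural-gradient direction: solve F x = g with conjugate gradients
    search_direction = conjugate_gradient_solver(
        fisher_vector_product, policy_objective_gradients, max_iter=10
    )
    print(search_direction.shape)  # one entry per flattened policy parameter

In ``TRPO.train()`` this direction is then rescaled to respect the KL constraint and applied through the backtracking line search.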