From f779a9fd72a98edcf255e2778f677f5f23b5b4bd Mon Sep 17 00:00:00 2001 From: Cyprien C Date: Sat, 7 Aug 2021 17:58:06 +0100 Subject: [PATCH 01/28] Feat: adding TRPO algorithm (WIP) WIP - Trust Region Policy Algorithm Currently the Hessian vector product is not working (see inline comments for more detail) --- sb3_contrib/common/utils.py | 66 +++++++- sb3_contrib/trpo/__init__.py | 0 sb3_contrib/trpo/policies.py | 16 ++ sb3_contrib/trpo/trpo.py | 315 +++++++++++++++++++++++++++++++++++ 4 files changed, 396 insertions(+), 1 deletion(-) create mode 100644 sb3_contrib/trpo/__init__.py create mode 100644 sb3_contrib/trpo/policies.py create mode 100644 sb3_contrib/trpo/trpo.py diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 4a9e522d..44892863 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -1,6 +1,7 @@ -from typing import Optional +from typing import Optional, Sequence, Callable import torch as th +from torch import nn def quantile_huber_loss( @@ -67,3 +68,66 @@ def quantile_huber_loss( else: loss = loss.mean() return loss + + +# TODO: write regression tests +def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=15) -> th.Tensor: + """ + Finds an approximate solution to a set of linear equations Ax = b + + Source: https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py + + :param Avp_fun : callable + a function that right multiplies a matrix A by a vector v + :param b : torch.FloatTensor + the right hand term in the set of linear equations Ax = b + :param max_iter : int + the maximum number of iterations (default is 10) + :return x : torch.FloatTensor + the approximate solution to the system of equations defined by Avp_fun + and b + """ + + x = th.zeros_like(b) + r = b.clone() + p = b.clone() + + for i in range(max_iter): + Avp = Avp_fun(p) + + r_dot = th.matmul(r, r) + alpha = r_dot / th.matmul(p, Avp) + x += alpha * p + + if i == max_iter - 1: + return x + + r_new = r - alpha * Avp + beta = th.matmul(r_new, r_new) / r_dot + r = r_new + p = r + beta * p + + +# TODO: test +def flat_grad( + output, + parameters: Sequence[nn.parameter.Parameter], + create_graph: bool = False, + retain_graph: bool = False, +) -> th.Tensor: + """ + Returns the gradients of the passed sequence of parameters into a flat gradient. + Order of parameters is preserved. + + :param output: functional output to compute the gradient for + :param parameters: sequence of `Parameter` + :param retain_graph – If ``False``, the graph used to compute the grad will be freed. + Defaults to the value of ``create_graph``. + :param create_graph – If ``True``, graph of the derivative will be constructed, + allowing to compute higher order derivative products. Default: ``False``. 
+ :return: Tensor containing the flattened gradients + """ + grads = th.autograd.grad( + output, parameters, create_graph=create_graph, retain_graph=retain_graph, allow_unused=True + ) + return th.cat([grad.view(-1) for grad in grads if grad is not None]) diff --git a/sb3_contrib/trpo/__init__.py b/sb3_contrib/trpo/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py new file mode 100644 index 00000000..7427cfc4 --- /dev/null +++ b/sb3_contrib/trpo/policies.py @@ -0,0 +1,16 @@ +# This file is here just to define MlpPolicy/CnnPolicy +# that work for PPO +from stable_baselines3.common.policies import ( + ActorCriticCnnPolicy, + ActorCriticPolicy, + MultiInputActorCriticPolicy, + register_policy, +) + +MlpPolicy = ActorCriticPolicy +CnnPolicy = ActorCriticCnnPolicy +MultiInputPolicy = MultiInputActorCriticPolicy + +register_policy("MlpPolicy", ActorCriticPolicy) +register_policy("CnnPolicy", ActorCriticCnnPolicy) +register_policy("MultiInputPolicy", MultiInputPolicy) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py new file mode 100644 index 00000000..a6a6c761 --- /dev/null +++ b/sb3_contrib/trpo/trpo.py @@ -0,0 +1,315 @@ +import warnings +from typing import Any, Dict, Optional, Type, Union + +import numpy as np +import torch +import torch as th +from gym import spaces +from torch.nn import functional as F + +from sb3_contrib.common.utils import flat_grad, cg_solver +from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.utils import explained_variance + + +class TRPO(OnPolicyAlgorithm): + """ + Trust Region Policy Optimization (TRPO) + + Paper: https://arxiv.org/abs/1502.05477 + Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) + and Stable Baselines (TRPO from https://github.com/hill-a/stable-baselines) + + Introduction to TRPO: https://spinningup.openai.com/en/latest/algorithms/trpo.html + + :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) + :param env: The environment to learn from (if registered in Gym, can be str) + :param learning_rate: The learning rate, it can be a function + of the current progress remaining (from 1 to 0) + :param n_steps: The number of steps to run for each environment per update + (i.e. 
rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) + NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) + See https://github.com/pytorch/pytorch/issues/29372 + :param batch_size: Minibatch size + :param n_epochs: Number of epoch when optimizing the surrogate loss + :param gamma: Discount factor + :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator + :param ent_coef: Entropy coefficient for the loss calculation + :param vf_coef: Value function coefficient for the loss calculation + :param max_grad_norm: The maximum value for the gradient clipping + :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) + instead of action noise exploration (default: False) + :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE + Default: -1 (only sample at the beginning of the rollout) + :param target_kl: Limit the KL divergence between updates, + because the clipping is not enough to prevent large update + see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) + By default, there is no limit on the kl div. + :param tensorboard_log: the log location for tensorboard (if None, no logging) + :param create_eval_env: Whether to create a second environment that will be + used for evaluating the agent periodically. (Only available when passing string for the environment) + :param policy_kwargs: additional arguments to be passed to the policy on creation + :param verbose: the verbosity level: 0 no output, 1 info, 2 debug + :param seed: Seed for the pseudo random generators + :param device: Device (cpu, cuda, ...) on which the code should be run. + Setting it to auto, the code will be run on the GPU if possible. + :param _init_setup_model: Whether or not to build the network at the creation of the instance + """ + + def __init__( + self, + policy: Union[str, Type[ActorCriticPolicy]], + env: Union[GymEnv, str], + learning_rate: Union[float, Schedule] = 3e-4, + n_steps: int = 2048, + batch_size: Optional[int] = 64, + n_epochs: int = 10, + gamma: float = 0.99, + gae_lambda: float = 0.95, + ent_coef: float = 0.0, + vf_coef: float = 0.5, + max_grad_norm: float = 0.5, + use_sde: bool = False, + sde_sample_freq: int = -1, + target_kl: float = 0.01, + tensorboard_log: Optional[str] = None, + create_eval_env: bool = False, + policy_kwargs: Optional[Dict[str, Any]] = None, + verbose: int = 0, + seed: Optional[int] = None, + device: Union[th.device, str] = "auto", + _init_setup_model: bool = True, + ): + + super(TRPO, self).__init__( + policy, + env, + learning_rate=learning_rate, + n_steps=n_steps, + gamma=gamma, + gae_lambda=gae_lambda, + ent_coef=ent_coef, + vf_coef=vf_coef, + max_grad_norm=max_grad_norm, + use_sde=use_sde, + sde_sample_freq=sde_sample_freq, + tensorboard_log=tensorboard_log, + policy_kwargs=policy_kwargs, + verbose=verbose, + device=device, + create_eval_env=create_eval_env, + seed=seed, + _init_setup_model=False, + supported_action_spaces=( + spaces.Box, + spaces.Discrete, + spaces.MultiDiscrete, + spaces.MultiBinary, + ), + ) + + # Sanity check, otherwise it will lead to noisy gradient and NaN + # because of the advantage normalization + assert ( + batch_size > 1 + ), "`batch_size` must be greater than 1. 
See https://github.com/DLR-RM/stable-baselines3/issues/440" + + if self.env is not None: + # Check that `n_steps * n_envs > 1` to avoid NaN + # when doing advantage normalization + buffer_size = self.env.num_envs * self.n_steps + assert ( + buffer_size > 1 + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + # Check that the rollout buffer size is a multiple of the mini-batch size + untruncated_batches = buffer_size // batch_size + if buffer_size % batch_size > 0: + warnings.warn( + f"You have specified a mini-batch size of {batch_size}," + f" but because the `RolloutBuffer` is of size `n_steps * n_envs = {buffer_size}`," + f" after every {untruncated_batches} untruncated mini-batches," + f" there will be a truncated mini-batch of size {buffer_size % batch_size}\n" + f"We recommend using a `batch_size` that is a factor of `n_steps * n_envs`.\n" + f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" + ) + self.batch_size = batch_size + self.n_epochs = n_epochs + self.target_kl = target_kl + + if _init_setup_model: + self._setup_model() + + def train(self) -> None: + """ + Update policy using the currently gathered rollout buffer. + """ + # Update optimizer learning rate + self._update_learning_rate(self.policy.optimizer) + + po_values = [] + kl_divergences = [] + + continue_training = True + + # train for n_epochs epochs + for epoch in range(self.n_epochs): + # Do a complete pass on the rollout buffer + for rollout_data in self.rollout_buffer.get(self.batch_size): + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() + + # Re-sample the noise matrix because the log_std has changed + # TODO: investigate why there is no issue with the gradient + # if that line is commented (as in SAC) + if self.use_sde: + self.policy.reset_noise(self.batch_size) + + values, log_prob, entropy = self.policy.evaluate_actions( + rollout_data.observations, actions + ) + values_pred = values.flatten() + + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) + + # surrogate policy objective + policy_obj = (values_pred.detach() * ratio).mean() + + # Logging + po_values.append(policy_obj.item()) + + # KL divergence + kl_div = F.kl_div( + log_prob, + rollout_data.old_log_prob, + log_target=True, + reduction="batchmean", + ) + + # Logging + kl_divergences.append(kl_div.item()) + + # Surrogate & KL gradient + self.policy.optimizer.zero_grad() + + # This is necessary because not all the parameters in the policy have gradients w.r.t. 
the KL divergence + g = [] + grad_kl = [] + grad_shape = [] + params = [] + for param in self.policy.parameters(): + kl_param_grad, *_ = torch.autograd.grad( + kl_div, + param, + create_graph=True, + retain_graph=True, + allow_unused=True, + only_inputs=True, + ) + if kl_param_grad is not None: + g_grad, *_ = torch.autograd.grad( + policy_obj, param, retain_graph=True, only_inputs=True + ) + + grad_shape.append(kl_param_grad.shape) + grad_kl.append(kl_param_grad.view(-1)) + g.append(g_grad.view(-1)) + params.append(param) + + g = torch.cat(g) + grad_kl = torch.cat(grad_kl) + + def Hpv(v, retain_graph=True): + jvp = (grad_kl * v).sum() + return flat_grad(jvp, params, retain_graph=retain_graph) + + s = cg_solver(Hpv, g) + + beta = 2 * self.target_kl + beta /= torch.matmul(s, Hpv(s, retain_graph=False)) + # TODO: investigate + # This assert shouldn't raise because s^T H s should not be negative + # Yet it does, it means Hpv is not returning H.v + # Could the code above do something wrong to the graph - making the Hessian vector product inaccurate? + assert beta >= 0 + beta = torch.sqrt(beta) + + # TODO: define a variable + alpha = 0.99 + orig_params = [param.detach().clone() for param in params] + + line_search_success = False + for i in range(10): + + j = 0 + for param, shape in zip(params, grad_shape): + k = param.numel() + param.data += alpha * beta * s[j:(j + k)].view(shape) + j += k + + with torch.no_grad(): + _, log_prob, _ = self.policy.evaluate_actions( + rollout_data.observations, actions + ) + kl_div = F.kl_div( + log_prob, + rollout_data.old_log_prob, + log_target=True, + reduction="batchmean", + ) + + if kl_div < self.target_kl: + line_search_success = True + break + + for param, orig_param in zip(params, orig_params): + param.data = orig_param.data.clone() + + alpha *= alpha + + if not continue_training: + break + + self._n_updates += self.n_epochs + explained_var = explained_variance( + self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten() + ) + + # Logs + # TODO: add extra logs + self.logger.record("train/policy_objective_value", np.mean(po_values)) + self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) + self.logger.record("train/explained_variance", explained_var) + if hasattr(self.policy, "log_std"): + self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) + + self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + + def learn( + self, + total_timesteps: int, + callback: MaybeCallback = None, + log_interval: int = 1, + eval_env: Optional[GymEnv] = None, + eval_freq: int = -1, + n_eval_episodes: int = 5, + tb_log_name: str = "TRPO", + eval_log_path: Optional[str] = None, + reset_num_timesteps: bool = True, + ) -> OnPolicyAlgorithm: + + return super(TRPO, self).learn( + total_timesteps=total_timesteps, + callback=callback, + log_interval=log_interval, + eval_env=eval_env, + eval_freq=eval_freq, + n_eval_episodes=n_eval_episodes, + tb_log_name=tb_log_name, + eval_log_path=eval_log_path, + reset_num_timesteps=reset_num_timesteps, + ) From 98bc5b2e9d57a655601ed06b4680a04c6a1d36f4 Mon Sep 17 00:00:00 2001 From: Cyprien C Date: Mon, 9 Aug 2021 17:27:15 +0100 Subject: [PATCH 02/28] Feat: adding TRPO algorithm (WIP) Adding no_grad block for the line search Additional assert in the conjugate solver to help debugging --- sb3_contrib/common/utils.py | 5 +++- sb3_contrib/trpo/trpo.py | 56 ++++++++++++++++++------------------- 2 files changed, 32 insertions(+), 29 deletions(-) diff --git 
a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 44892863..bff68707 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -96,7 +96,10 @@ def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=15) -> th.T Avp = Avp_fun(p) r_dot = th.matmul(r, r) - alpha = r_dot / th.matmul(p, Avp) + pAp = th.matmul(p, Avp) + # This shouldn't raise if the matrix in the matrix in Avp_fun is positive-definite + assert pAp >= 0 + alpha = r_dot / pAp x += alpha * p if i == max_iter - 1: diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index a6a6c761..62af0f99 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -225,7 +225,7 @@ def train(self) -> None: def Hpv(v, retain_graph=True): jvp = (grad_kl * v).sum() - return flat_grad(jvp, params, retain_graph=retain_graph) + return flat_grad(jvp, params, retain_graph=retain_graph).detach() s = cg_solver(Hpv, g) @@ -243,33 +243,33 @@ def Hpv(v, retain_graph=True): orig_params = [param.detach().clone() for param in params] line_search_success = False - for i in range(10): - - j = 0 - for param, shape in zip(params, grad_shape): - k = param.numel() - param.data += alpha * beta * s[j:(j + k)].view(shape) - j += k - - with torch.no_grad(): - _, log_prob, _ = self.policy.evaluate_actions( - rollout_data.observations, actions - ) - kl_div = F.kl_div( - log_prob, - rollout_data.old_log_prob, - log_target=True, - reduction="batchmean", - ) - - if kl_div < self.target_kl: - line_search_success = True - break - - for param, orig_param in zip(params, orig_params): - param.data = orig_param.data.clone() - - alpha *= alpha + with torch.no_grad(): + for i in range(10): + + j = 0 + for param, shape in zip(params, grad_shape): + k = param.numel() + param.data += alpha * beta * s[j:(j + k)].view(shape) + j += k + + _, log_prob, _ = self.policy.evaluate_actions( + rollout_data.observations, actions + ) + kl_div = F.kl_div( + log_prob, + rollout_data.old_log_prob, + log_target=True, + reduction="batchmean", + ) + + if kl_div < self.target_kl: + line_search_success = True + break + + for param, orig_param in zip(params, orig_params): + param.data = orig_param.data.clone() + + alpha *= alpha if not continue_training: break From 97ece67cc9b8322a801c39eabcb7e0915c70604a Mon Sep 17 00:00:00 2001 From: Cyprien C Date: Tue, 17 Aug 2021 09:02:16 +0100 Subject: [PATCH 03/28] Feat: adding TRPO algorithm (WIP) - Adding ActorCriticPolicy.get_distribution - Using the Distribution object to compute the KL divergence - Checking for objective improvement in the line search - Moving magic numbers to instance variables --- sb3_contrib/common/policies.py | 961 +++++++++++++++++++++++++++++++++ sb3_contrib/common/utils.py | 4 +- sb3_contrib/trpo/trpo.py | 104 ++-- 3 files changed, 1026 insertions(+), 43 deletions(-) create mode 100644 sb3_contrib/common/policies.py diff --git a/sb3_contrib/common/policies.py b/sb3_contrib/common/policies.py new file mode 100644 index 00000000..6a884797 --- /dev/null +++ b/sb3_contrib/common/policies.py @@ -0,0 +1,961 @@ +"""Policies: abstract base class and concrete implementations.""" + +import collections +import copy +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Dict, List, Optional, Tuple, Type, Union + +import gym +import numpy as np +import torch as th +from torch import nn + +from stable_baselines3.common.distributions import ( + BernoulliDistribution, + CategoricalDistribution, + DiagGaussianDistribution, + Distribution, + 
MultiCategoricalDistribution, + StateDependentNoiseDistribution, + make_proba_distribution, +) +from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs +from stable_baselines3.common.torch_layers import ( + BaseFeaturesExtractor, + CombinedExtractor, + FlattenExtractor, + MlpExtractor, + NatureCNN, + create_mlp, +) +from stable_baselines3.common.type_aliases import Schedule +from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor + + +class BaseModel(nn.Module, ABC): + """ + The base model object: makes predictions in response to observations. + + In the case of policies, the prediction is an action. In the case of critics, it is the + estimated value of the observation. + + :param observation_space: The observation space of the environment + :param action_space: The action space of the environment + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + features_extractor: Optional[nn.Module] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(BaseModel, self).__init__() + + if optimizer_kwargs is None: + optimizer_kwargs = {} + + if features_extractor_kwargs is None: + features_extractor_kwargs = {} + + self.observation_space = observation_space + self.action_space = action_space + self.features_extractor = features_extractor + self.normalize_images = normalize_images + + self.optimizer_class = optimizer_class + self.optimizer_kwargs = optimizer_kwargs + self.optimizer = None # type: Optional[th.optim.Optimizer] + + self.features_extractor_class = features_extractor_class + self.features_extractor_kwargs = features_extractor_kwargs + + @abstractmethod + def forward(self, *args, **kwargs): + pass + + def _update_features_extractor( + self, + net_kwargs: Dict[str, Any], + features_extractor: Optional[BaseFeaturesExtractor] = None, + ) -> Dict[str, Any]: + """ + Update the network keyword arguments and create a new features extractor object if needed. + If a ``features_extractor`` object is passed, then it will be shared. + + :param net_kwargs: the base network keyword arguments, without the ones + related to features extractor + :param features_extractor: a features extractor object. + If None, a new object will be created. 
+ :return: The updated keyword arguments + """ + net_kwargs = net_kwargs.copy() + if features_extractor is None: + # The features extractor is not shared, create a new one + features_extractor = self.make_features_extractor() + net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) + return net_kwargs + + def make_features_extractor(self) -> BaseFeaturesExtractor: + """Helper method to create a features extractor.""" + return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + + def extract_features(self, obs: th.Tensor) -> th.Tensor: + """ + Preprocess the observation if needed and extract features. + + :param obs: + :return: + """ + assert self.features_extractor is not None, "No features extractor was set" + preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) + return self.features_extractor(preprocessed_obs) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + """ + Get data that need to be saved in order to re-create the model when loading it from disk. + + :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. + """ + return dict( + observation_space=self.observation_space, + action_space=self.action_space, + # Passed to the constructor by child class + # squash_output=self.squash_output, + # features_extractor=self.features_extractor + normalize_images=self.normalize_images, + ) + + @property + def device(self) -> th.device: + """Infer which device this policy lives on by inspecting its parameters. + If it has no parameters, the 'cpu' device is used as a fallback. + + :return:""" + for param in self.parameters(): + return param.device + return get_device("cpu") + + def save(self, path: str) -> None: + """ + Save model to a given location. + + :param path: + """ + th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) + + @classmethod + def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": + """ + Load model from path. + + :param path: + :param device: Device on which the policy should be loaded. + :return: + """ + device = get_device(device) + saved_variables = th.load(path, map_location=device) + # Create policy object + model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable + # Load weights + model.load_state_dict(saved_variables["state_dict"]) + model.to(device) + return model + + def load_from_vector(self, vector: np.ndarray) -> None: + """ + Load parameters from a 1D vector. + + :param vector: + """ + th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) + + def parameters_to_vector(self) -> np.ndarray: + """ + Convert the parameters to a 1D vector. + + :return: + """ + return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() + + +class BasePolicy(BaseModel): + """The base policy object. + + Parameters are mostly the same as `BaseModel`; additions are documented below. + + :param args: positional arguments passed through to `BaseModel`. + :param kwargs: keyword arguments passed through to `BaseModel`. + :param squash_output: For continuous actions, whether the output is squashed + or not using a ``tanh()`` function. 
+ """ + + def __init__(self, *args, squash_output: bool = False, **kwargs): + super(BasePolicy, self).__init__(*args, **kwargs) + self._squash_output = squash_output + + @staticmethod + def _dummy_schedule(progress_remaining: float) -> float: + """(float) Useful for pickling policy.""" + del progress_remaining + return 0.0 + + @property + def squash_output(self) -> bool: + """(bool) Getter for squash_output.""" + return self._squash_output + + @staticmethod + def init_weights(module: nn.Module, gain: float = 1) -> None: + """ + Orthogonal initialization (used in PPO and A2C) + """ + if isinstance(module, (nn.Linear, nn.Conv2d)): + nn.init.orthogonal_(module.weight, gain=gain) + if module.bias is not None: + module.bias.data.fill_(0.0) + + @abstractmethod + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. + + By default provides a dummy implementation -- not all BasePolicy classes + implement this, e.g. if they are a Critic in an Actor-Critic method. + + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + + def predict( + self, + observation: Union[np.ndarray, Dict[str, np.ndarray]], + state: Optional[np.ndarray] = None, + mask: Optional[np.ndarray] = None, + deterministic: bool = False, + ) -> Tuple[np.ndarray, Optional[np.ndarray]]: + """ + Get the policy action and state from an observation (and optional state). + Includes sugar-coating to handle different observations (e.g. normalizing images). + + :param observation: the input observation + :param state: The last states (can be None, used in recurrent policies) + :param mask: The last masks (can be None, used in recurrent policies) + :param deterministic: Whether or not to return deterministic actions. 
+ :return: the model's action and the next state + (used in recurrent policies) + """ + # TODO (GH/1): add support for RNN policies + # if state is None: + # state = self.initial_state + # if mask is None: + # mask = [False for _ in range(self.n_envs)] + + vectorized_env = False + if isinstance(observation, dict): + # need to copy the dict as the dict in VecFrameStack will become a torch tensor + observation = copy.deepcopy(observation) + for key, obs in observation.items(): + obs_space = self.observation_space.spaces[key] + if is_image_space(obs_space): + obs_ = maybe_transpose(obs, obs_space) + else: + obs_ = np.array(obs) + vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) + # Add batch dimension if needed + observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape) + + elif is_image_space(self.observation_space): + # Handle the different cases for images + # as PyTorch use channel first format + observation = maybe_transpose(observation, self.observation_space) + + else: + observation = np.array(observation) + + if not isinstance(observation, dict): + # Dict obs need to be handled separately + vectorized_env = is_vectorized_observation(observation, self.observation_space) + # Add batch dimension if needed + observation = observation.reshape((-1,) + self.observation_space.shape) + + observation = obs_as_tensor(observation, self.device) + + with th.no_grad(): + actions = self._predict(observation, deterministic=deterministic) + # Convert to numpy + actions = actions.cpu().numpy() + + if isinstance(self.action_space, gym.spaces.Box): + if self.squash_output: + # Rescale to proper domain when using squashing + actions = self.unscale_action(actions) + else: + # Actions could be on arbitrary scale, so clip the actions to avoid + # out of bound error (e.g. if sampling from a Gaussian distribution) + actions = np.clip(actions, self.action_space.low, self.action_space.high) + + if not vectorized_env: + if state is not None: + raise ValueError("Error: The environment must be vectorized when using recurrent policies.") + actions = actions[0] + + return actions, state + + def scale_action(self, action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [low, high] to [-1, 1] + (no need for symmetric action space) + + :param action: Action to scale + :return: Scaled action + """ + low, high = self.action_space.low, self.action_space.high + return 2.0 * ((action - low) / (high - low)) - 1.0 + + def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: + """ + Rescale the action from [-1, 1] to [low, high] + (no need for symmetric action space) + + :param scaled_action: Action to un-scale + """ + low, high = self.action_space.low, self.action_space.high + return low + (0.5 * (scaled_action + 1.0) * (high - low)) + + +class ActorCriticPolicy(BasePolicy): + """ + Policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. 
+ :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + + if optimizer_kwargs is None: + optimizer_kwargs = {} + # Small values to avoid NaN in Adam optimizer + if optimizer_class == th.optim.Adam: + optimizer_kwargs["eps"] = 1e-5 + + super(ActorCriticPolicy, self).__init__( + observation_space, + action_space, + features_extractor_class, + features_extractor_kwargs, + optimizer_class=optimizer_class, + optimizer_kwargs=optimizer_kwargs, + squash_output=squash_output, + ) + + # Default network architecture, from stable-baselines + if net_arch is None: + if features_extractor_class == NatureCNN: + net_arch = [] + else: + net_arch = [dict(pi=[64, 64], vf=[64, 64])] + + self.net_arch = net_arch + self.activation_fn = activation_fn + self.ortho_init = ortho_init + + self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) + self.features_dim = self.features_extractor.features_dim + + self.normalize_images = normalize_images + self.log_std_init = log_std_init + dist_kwargs = None + # Keyword arguments for gSDE distribution + if use_sde: + dist_kwargs = { + "full_std": full_std, + "squash_output": squash_output, + "use_expln": use_expln, + "learn_features": sde_net_arch is not None, + } + + self.sde_features_extractor = None + self.sde_net_arch = sde_net_arch + self.use_sde = use_sde + self.dist_kwargs = dist_kwargs + + # Action distribution + self.action_dist = make_proba_distribution(action_space, 
use_sde=use_sde, dist_kwargs=dist_kwargs) + + self._build(lr_schedule) + + def _get_constructor_parameters(self) -> Dict[str, Any]: + data = super()._get_constructor_parameters() + + default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) + + data.update( + dict( + net_arch=self.net_arch, + activation_fn=self.activation_fn, + use_sde=self.use_sde, + log_std_init=self.log_std_init, + squash_output=default_none_kwargs["squash_output"], + full_std=default_none_kwargs["full_std"], + sde_net_arch=self.sde_net_arch, + use_expln=default_none_kwargs["use_expln"], + lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone + ortho_init=self.ortho_init, + optimizer_class=self.optimizer_class, + optimizer_kwargs=self.optimizer_kwargs, + features_extractor_class=self.features_extractor_class, + features_extractor_kwargs=self.features_extractor_kwargs, + ) + ) + return data + + def reset_noise(self, n_envs: int = 1) -> None: + """ + Sample new weights for the exploration matrix. + + :param n_envs: + """ + assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" + self.action_dist.sample_weights(self.log_std, batch_size=n_envs) + + def _build_mlp_extractor(self) -> None: + """ + Create the policy and value networks. + Part of the layers can be shared. + """ + # Note: If net_arch is None and some features extractor is used, + # net_arch here is an empty list and mlp_extractor does not + # really contain any layers (acts like an identity module). + self.mlp_extractor = MlpExtractor( + self.features_dim, + net_arch=self.net_arch, + activation_fn=self.activation_fn, + device=self.device, + ) + + def _build(self, lr_schedule: Schedule) -> None: + """ + Create the networks and the optimizer. + + :param lr_schedule: Learning rate schedule + lr_schedule(1) is the initial learning rate + """ + self._build_mlp_extractor() + + latent_dim_pi = self.mlp_extractor.latent_dim_pi + + # Separate features extractor for gSDE + if self.sde_net_arch is not None: + self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( + self.features_dim, self.sde_net_arch, self.activation_fn + ) + + if isinstance(self.action_dist, DiagGaussianDistribution): + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim + self.action_net, self.log_std = self.action_dist.proba_distribution_net( + latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init + ) + elif isinstance(self.action_dist, CategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + elif isinstance(self.action_dist, BernoulliDistribution): + self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) + else: + raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") + + self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) + # Init weights: use orthogonal initialization + # with small initial weight for the output + if self.ortho_init: + # TODO: check for features_extractor + # Values from stable-baselines. 
+ # features_extractor/mlp values are + # originally from openai/baselines (default gains/init_scales). + module_gains = { + self.features_extractor: np.sqrt(2), + self.mlp_extractor: np.sqrt(2), + self.action_net: 0.01, + self.value_net: 1, + } + for module, gain in module_gains.items(): + module.apply(partial(self.init_weights, gain=gain)) + + # Setup optimizer with initial learning rate + self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) + + def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Forward pass in all the networks (actor and critic) + + :param obs: Observation + :param deterministic: Whether to sample or use deterministic actions + :return: action, value and log probability of the action + """ + latent_pi, latent_vf, latent_sde = self._get_latent(obs) + # Evaluate the values for the given observations + values = self.value_net(latent_vf) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde) + actions = distribution.get_actions(deterministic=deterministic) + log_prob = distribution.log_prob(actions) + return actions, values, log_prob + + def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Get the latent code (i.e., activations of the last layer of each network) + for the different networks. + + :param obs: Observation + :return: Latent codes + for the actor, the value function and for gSDE function + """ + # Preprocess the observation if needed + features = self.extract_features(obs) + latent_pi, latent_vf = self.mlp_extractor(features) + + # Features for sde + latent_sde = latent_pi + if self.sde_features_extractor is not None: + latent_sde = self.sde_features_extractor(features) + return latent_pi, latent_vf, latent_sde + + def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution: + """ + Retrieve action distribution given the latent codes. + + :param latent_pi: Latent code for the actor + :param latent_sde: Latent code for the gSDE exploration function + :return: Action distribution + """ + mean_actions = self.action_net(latent_pi) + + if isinstance(self.action_dist, DiagGaussianDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std) + elif isinstance(self.action_dist, CategoricalDistribution): + # Here mean_actions are the logits before the softmax + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, MultiCategoricalDistribution): + # Here mean_actions are the flattened logits + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, BernoulliDistribution): + # Here mean_actions are the logits (before rounding to get the binary actions) + return self.action_dist.proba_distribution(action_logits=mean_actions) + elif isinstance(self.action_dist, StateDependentNoiseDistribution): + return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) + else: + raise ValueError("Invalid action distribution") + + def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: + """ + Get the action according to the policy for a given observation. 
+ + :param observation: + :param deterministic: Whether to use stochastic or deterministic actions + :return: Taken action according to the policy + """ + latent_pi, _, latent_sde = self._get_latent(observation) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) + return distribution.get_actions(deterministic=deterministic) + + def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: + """ + Evaluate actions according to the current policy, + given the observations. + + :param obs: + :param actions: + :return: estimated value, log likelihood of taking those actions + and entropy of the action distribution. + """ + latent_pi, latent_vf, latent_sde = self._get_latent(obs) + distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) + log_prob = distribution.log_prob(actions) + values = self.value_net(latent_vf) + return values, log_prob, distribution.entropy() + + def get_distribution(self) -> Distribution: + """ + Get the current action distribution + :return: Action distribution + """ + return self.action_dist + + +class ActorCriticCnnPolicy(ActorCriticPolicy): + """ + CNN policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Features extractor to use. + :param features_extractor_kwargs: Keyword arguments + to pass to the features extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(ActorCriticCnnPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class MultiInputActorCriticPolicy(ActorCriticPolicy): + """ + MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). + Used by A2C, PPO and the likes. + + :param observation_space: Observation space (Tuple) + :param action_space: Action space + :param lr_schedule: Learning rate schedule (could be constant) + :param net_arch: The specification of the policy and value networks. + :param activation_fn: Activation function + :param ortho_init: Whether to use or not orthogonal initialization + :param use_sde: Whether to use State Dependent Exploration or not + :param log_std_init: Initial value for the log standard deviation + :param full_std: Whether to use (n_features x n_actions) parameters + for the std instead of only (n_features,) when using gSDE + :param sde_net_arch: Network architecture for extracting features + when using gSDE. If None, the latent features from the policy will be used. + Pass an empty list to use the states as features. + :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure + a positive standard deviation (cf paper). It allows to keep variance + above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. + :param squash_output: Whether to squash the output using a tanh function, + this allows to ensure boundaries when using gSDE. + :param features_extractor_class: Uses the CombinedExtractor + :param features_extractor_kwargs: Keyword arguments + to pass to the feature extractor. 
+ :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param optimizer_class: The optimizer to use, + ``th.optim.Adam`` by default + :param optimizer_kwargs: Additional keyword arguments, + excluding the learning rate, to pass to the optimizer + """ + + def __init__( + self, + observation_space: gym.spaces.Dict, + action_space: gym.spaces.Space, + lr_schedule: Schedule, + net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, + activation_fn: Type[nn.Module] = nn.Tanh, + ortho_init: bool = True, + use_sde: bool = False, + log_std_init: float = 0.0, + full_std: bool = True, + sde_net_arch: Optional[List[int]] = None, + use_expln: bool = False, + squash_output: bool = False, + features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, + features_extractor_kwargs: Optional[Dict[str, Any]] = None, + normalize_images: bool = True, + optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + ): + super(MultiInputActorCriticPolicy, self).__init__( + observation_space, + action_space, + lr_schedule, + net_arch, + activation_fn, + ortho_init, + use_sde, + log_std_init, + full_std, + sde_net_arch, + use_expln, + squash_output, + features_extractor_class, + features_extractor_kwargs, + normalize_images, + optimizer_class, + optimizer_kwargs, + ) + + +class ContinuousCritic(BaseModel): + """ + Critic network(s) for DDPG/SAC/TD3. + It represents the action-state value function (Q-value function). + Compared to A2C/PPO critics, this one represents the Q-value + and takes the continuous action as input. It is concatenated with the state + and then fed to the network which outputs a single value: Q(s, a). + For more recent algorithms like SAC/TD3, multiple networks + are created to give different estimates. + + By default, it creates two critic networks used to reduce overestimation + thanks to clipped Q-learning (cf TD3 paper). + + :param observation_space: Obervation space + :param action_space: Action space + :param net_arch: Network architecture + :param features_extractor: Network to extract features + (a CNN when using images, a nn.Flatten() layer otherwise) + :param features_dim: Number of features + :param activation_fn: Activation function + :param normalize_images: Whether to normalize images or not, + dividing by 255.0 (True by default) + :param n_critics: Number of critic networks to create. 
+ :param share_features_extractor: Whether the features extractor is shared or not + between the actor and the critic (this saves computation time) + """ + + def __init__( + self, + observation_space: gym.spaces.Space, + action_space: gym.spaces.Space, + net_arch: List[int], + features_extractor: nn.Module, + features_dim: int, + activation_fn: Type[nn.Module] = nn.ReLU, + normalize_images: bool = True, + n_critics: int = 2, + share_features_extractor: bool = True, + ): + super().__init__( + observation_space, + action_space, + features_extractor=features_extractor, + normalize_images=normalize_images, + ) + + action_dim = get_action_dim(self.action_space) + + self.share_features_extractor = share_features_extractor + self.n_critics = n_critics + self.q_networks = [] + for idx in range(n_critics): + q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) + q_net = nn.Sequential(*q_net) + self.add_module(f"qf{idx}", q_net) + self.q_networks.append(q_net) + + def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: + # Learn the features extractor using the policy loss only + # when the features_extractor is shared with the actor + with th.set_grad_enabled(not self.share_features_extractor): + features = self.extract_features(obs) + qvalue_input = th.cat([features, actions], dim=1) + return tuple(q_net(qvalue_input) for q_net in self.q_networks) + + def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: + """ + Only predict the Q-value using the first network. + This allows to reduce computation when all the estimates are not needed + (e.g. when updating the policy in TD3). + """ + with th.no_grad(): + features = self.extract_features(obs) + return self.q_networks[0](th.cat([features, actions], dim=1)) + + +def create_sde_features_extractor( + features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module] +) -> Tuple[nn.Sequential, int]: + """ + Create the neural network that will be used to extract features + for the gSDE exploration function. + + :param features_dim: + :param sde_net_arch: + :param activation_fn: + :return: + """ + # Special case: when using states as features (i.e. sde_net_arch is an empty list) + # don't use any activation function + sde_activation = activation_fn if len(sde_net_arch) > 0 else None + latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) + latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim + sde_features_extractor = nn.Sequential(*latent_sde_net) + return sde_features_extractor, latent_sde_dim + + +_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] + + +def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: + """ + Returns the registered policy from the base type and name. + See `register_policy` for registering policies and explanation. + + :param base_policy_type: the base policy class + :param name: the policy name + :return: the policy + """ + if base_policy_type not in _policy_registry: + raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") + if name not in _policy_registry[base_policy_type]: + raise KeyError( + f"Error: unknown policy type {name}," + f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!" 
+ ) + return _policy_registry[base_policy_type][name] + + +def register_policy(name: str, policy: Type[BasePolicy]) -> None: + """ + Register a policy, so it can be called using its name. + e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...). + + The goal here is to standardize policy naming, e.g. + all algorithms can call upon "MlpPolicy" or "CnnPolicy", + and they receive respective policies that work for them. + Consider following: + + OnlinePolicy + -- OnlineMlpPolicy ("MlpPolicy") + -- OnlineCnnPolicy ("CnnPolicy") + OfflinePolicy + -- OfflineMlpPolicy ("MlpPolicy") + -- OfflineCnnPolicy ("CnnPolicy") + + Two policies have name "MlpPolicy" and two have "CnnPolicy". + In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) + is given and used to select and return the correct policy. + + :param name: the policy name + :param policy: the policy class + """ + sub_class = None + for cls in BasePolicy.__subclasses__(): + if issubclass(policy, cls): + sub_class = cls + break + if sub_class is None: + raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") + + if sub_class not in _policy_registry: + _policy_registry[sub_class] = {} + if name in _policy_registry[sub_class]: + # Check if the registered policy is same + # we try to register. If not so, + # do not override and complain. + if _policy_registry[sub_class][name] != policy: + raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.") + _policy_registry[sub_class][name] = policy diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index bff68707..a7fa6035 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -71,7 +71,7 @@ def quantile_huber_loss( # TODO: write regression tests -def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=15) -> th.Tensor: +def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10) -> th.Tensor: """ Finds an approximate solution to a set of linear equations Ax = b @@ -97,8 +97,6 @@ def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=15) -> th.T r_dot = th.matmul(r, r) pAp = th.matmul(p, Avp) - # This shouldn't raise if the matrix in the matrix in Avp_fun is positive-definite - assert pAp >= 0 alpha = r_dot / pAp x += alpha * p diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 62af0f99..a89b62e3 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -1,3 +1,4 @@ +import copy import warnings from typing import Any, Dict, Optional, Type, Union @@ -5,11 +6,11 @@ import torch import torch as th from gym import spaces -from torch.nn import functional as F +from torch.distributions import kl_divergence +from sb3_contrib.common.policies import ActorCriticPolicy from sb3_contrib.common.utils import flat_grad, cg_solver from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm -from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -35,6 +36,11 @@ class TRPO(OnPolicyAlgorithm): :param batch_size: Minibatch size :param n_epochs: Number of epoch when optimizing the surrogate loss :param gamma: Discount factor + :param cg_max_steps: maximum number of steps in the Conjugate Gradient algoritgm + for computing the Hessian vector product + :param cg_damping: damping in the Hessian vector product computation + :param ls_alpha: 
step-size reduction factor for the line-search (i.e. theta_new = theta + alpha^i * step) + :param ls_steps: maximum number of steps in the line-search :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation @@ -67,6 +73,10 @@ def __init__( batch_size: Optional[int] = 64, n_epochs: int = 10, gamma: float = 0.99, + cg_max_steps: int = 10, + cg_damping: float = 0.1, + ls_alpha: float = 0.99, + ls_steps: int = 10, gae_lambda: float = 0.95, ent_coef: float = 0.0, vf_coef: float = 0.5, @@ -136,6 +146,10 @@ def __init__( ) self.batch_size = batch_size self.n_epochs = n_epochs + self.cg_max_steps = cg_max_steps + self.cg_damping = cg_damping + self.ls_alpha = ls_alpha + self.ls_steps = ls_steps self.target_kl = target_kl if _init_setup_model: @@ -150,6 +164,7 @@ def train(self) -> None: po_values = [] kl_divergences = [] + line_search_results = [] continue_training = True @@ -168,30 +183,30 @@ def train(self) -> None: if self.use_sde: self.policy.reset_noise(self.batch_size) - values, log_prob, entropy = self.policy.evaluate_actions( + with torch.no_grad(): + _ = self.policy.evaluate_actions(rollout_data.observations, actions) + old_distribution = copy.copy(self.policy.get_distribution()) + + values, log_prob, _ = self.policy.evaluate_actions( rollout_data.observations, actions ) - values_pred = values.flatten() + distribution = self.policy.get_distribution() + + advantages = rollout_data.advantages + advantages = (advantages - advantages.mean()) / ( + rollout_data.advantages.std() + 1e-8 + ) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) # surrogate policy objective - policy_obj = (values_pred.detach() * ratio).mean() - - # Logging - po_values.append(policy_obj.item()) + policy_obj = (advantages * ratio).mean() # KL divergence - kl_div = F.kl_div( - log_prob, - rollout_data.old_log_prob, - log_target=True, - reduction="batchmean", - ) - - # Logging - kl_divergences.append(kl_div.item()) + kl_div = kl_divergence( + distribution.distribution, old_distribution.distribution + ).mean() # Surrogate & KL gradient self.policy.optimizer.zero_grad() @@ -225,44 +240,39 @@ def train(self) -> None: def Hpv(v, retain_graph=True): jvp = (grad_kl * v).sum() - return flat_grad(jvp, params, retain_graph=retain_graph).detach() + return flat_grad(jvp, params, retain_graph=retain_graph) + self.cg_damping * v - s = cg_solver(Hpv, g) + s = cg_solver(Hpv, g, max_iter=self.cg_max_steps) beta = 2 * self.target_kl beta /= torch.matmul(s, Hpv(s, retain_graph=False)) - # TODO: investigate - # This assert shouldn't raise because s^T H s should not be negative - # Yet it does, it means Hpv is not returning H.v - # Could the code above do something wrong to the graph - making the Hessian vector product inaccurate? 
- assert beta >= 0 beta = torch.sqrt(beta) - # TODO: define a variable - alpha = 0.99 + alpha = self.ls_alpha orig_params = [param.detach().clone() for param in params] line_search_success = False with torch.no_grad(): - for i in range(10): + for i in range(self.ls_steps): j = 0 for param, shape in zip(params, grad_shape): k = param.numel() - param.data += alpha * beta * s[j:(j + k)].view(shape) + param.data += alpha * beta * s[j : (j + k)].view(shape) j += k - _, log_prob, _ = self.policy.evaluate_actions( - rollout_data.observations, actions - ) - kl_div = F.kl_div( - log_prob, - rollout_data.old_log_prob, - log_target=True, - reduction="batchmean", - ) - - if kl_div < self.target_kl: + values, log_prob, _ = self.policy.evaluate_actions( + rollout_data.observations, actions + ) + + ratio = th.exp(log_prob - rollout_data.old_log_prob) + new_policy_obj = (advantages * ratio).mean() + + kl_div = kl_divergence( + distribution.distribution, old_distribution.distribution + ).mean() + + if (kl_div < self.target_kl) and (new_policy_obj > policy_obj): line_search_success = True break @@ -271,6 +281,20 @@ def Hpv(v, retain_graph=True): alpha *= alpha + line_search_results.append(line_search_success) + + if not line_search_success: + for param, orig_param in zip(params, orig_params): + param.data = orig_param.data.clone() + + po_values.append(policy_obj.item()) + kl_divergences.append(0) + else: + po_values.append(new_policy_obj.item()) + kl_divergences.append(kl_div.item()) + + # TODO: Critic training? + if not continue_training: break @@ -280,10 +304,10 @@ def Hpv(v, retain_graph=True): ) # Logs - # TODO: add extra logs self.logger.record("train/policy_objective_value", np.mean(po_values)) self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) self.logger.record("train/explained_variance", explained_var) + self.logger.record("train/line_search_success", np.mean(line_search_results)) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) From 799b14045b39bf6dfd025c2428a657556aaa5149 Mon Sep 17 00:00:00 2001 From: Cyprien C Date: Tue, 17 Aug 2021 18:37:30 +0100 Subject: [PATCH 04/28] Feat: adding TRPO algorithm (WIP) Improving numerical stability of the conjugate gradient algorithm Critic updates --- sb3_contrib/common/utils.py | 35 +++++++++++++++++++++++++---------- sb3_contrib/trpo/trpo.py | 33 +++++++++++++++++++++++++++++---- 2 files changed, 54 insertions(+), 14 deletions(-) diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index a7fa6035..89a1963d 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -71,7 +71,7 @@ def quantile_huber_loss( # TODO: write regression tests -def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10) -> th.Tensor: +def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10, residual_tol=1e-10) -> th.Tensor: """ Finds an approximate solution to a set of linear equations Ax = b @@ -83,29 +83,44 @@ def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10) -> th.T the right hand term in the set of linear equations Ax = b :param max_iter : int the maximum number of iterations (default is 10) + :param residual_tol: float + residual tolerance for early stopping of the solving (default is 1e-10) :return x : torch.FloatTensor the approximate solution to the system of equations defined by Avp_fun and b """ - x = th.zeros_like(b) - r = b.clone() - p = b.clone() + # The vector is not initialized at 0 because of the 
instability issues when the gradient becomes small. + # A small random gaussian noise is used for the initialization. + x = 1e-4 * th.randn_like(b) + r = b - Avp_fun(x) + r_dot = th.matmul(r, r) + + if r_dot < residual_tol: + # If the gradient becomes extremely small + # The denominator in alpha will become zero + # Leading to a division by zero + return x + + p = r.clone() for i in range(max_iter): Avp = Avp_fun(p) - r_dot = th.matmul(r, r) - pAp = th.matmul(p, Avp) - alpha = r_dot / pAp + alpha = r_dot / p.dot(Avp) x += alpha * p if i == max_iter - 1: return x - r_new = r - alpha * Avp - beta = th.matmul(r_new, r_new) / r_dot - r = r_new + r -= alpha * Avp + new_r_dot = th.matmul(r, r) + + if new_r_dot < residual_tol: + return x + + beta = new_r_dot / r_dot + r_dot = new_r_dot p = r + beta * p diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index a89b62e3..5ae8188f 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -7,6 +7,7 @@ import torch as th from gym import spaces from torch.distributions import kl_divergence +from torch.nn import functional as F from sb3_contrib.common.policies import ActorCriticPolicy from sb3_contrib.common.utils import flat_grad, cg_solver @@ -41,6 +42,7 @@ class TRPO(OnPolicyAlgorithm): :param cg_damping: damping in the Hessian vector product computation :param ls_alpha: step-size reduction factor for the line-search (i.e. theta_new = theta + alpha^i * step) :param ls_steps: maximum number of steps in the line-search + :param n_critic_updates: number of critic updates per policy updates :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param ent_coef: Entropy coefficient for the loss calculation :param vf_coef: Value function coefficient for the loss calculation @@ -74,9 +76,10 @@ def __init__( n_epochs: int = 10, gamma: float = 0.99, cg_max_steps: int = 10, - cg_damping: float = 0.1, + cg_damping: float = 1e-3, ls_alpha: float = 0.99, ls_steps: int = 10, + n_critic_updates: int = 5, gae_lambda: float = 0.95, ent_coef: float = 0.0, vf_coef: float = 0.5, @@ -151,6 +154,7 @@ def __init__( self.ls_alpha = ls_alpha self.ls_steps = ls_steps self.target_kl = target_kl + self.n_critic_updates = n_critic_updates if _init_setup_model: self._setup_model() @@ -165,6 +169,7 @@ def train(self) -> None: po_values = [] kl_divergences = [] line_search_results = [] + value_losses = [] continue_training = True @@ -187,7 +192,7 @@ def train(self) -> None: _ = self.policy.evaluate_actions(rollout_data.observations, actions) old_distribution = copy.copy(self.policy.get_distribution()) - values, log_prob, _ = self.policy.evaluate_actions( + _, log_prob, _ = self.policy.evaluate_actions( rollout_data.observations, actions ) distribution = self.policy.get_distribution() @@ -216,6 +221,7 @@ def train(self) -> None: grad_kl = [] grad_shape = [] params = [] + value_only_params = [] for param in self.policy.parameters(): kl_param_grad, *_ = torch.autograd.grad( kl_div, @@ -234,6 +240,8 @@ def train(self) -> None: grad_kl.append(kl_param_grad.view(-1)) g.append(g_grad.view(-1)) params.append(param) + else: + value_only_params.append(param) g = torch.cat(g) grad_kl = torch.cat(grad_kl) @@ -261,7 +269,7 @@ def Hpv(v, retain_graph=True): param.data += alpha * beta * s[j : (j + k)].view(shape) j += k - values, log_prob, _ = self.policy.evaluate_actions( + _, log_prob, _ = self.policy.evaluate_actions( rollout_data.observations, actions ) @@ -293,7 +301,23 @@ def Hpv(v, retain_graph=True): 
po_values.append(new_policy_obj.item()) kl_divergences.append(kl_div.item()) - # TODO: Critic training? + for _ in range(self.n_critic_updates): + values, _, _ = self.policy.evaluate_actions( + rollout_data.observations, actions + ) + values_pred = values.flatten() + value_loss = F.mse_loss(rollout_data.returns, values_pred) + value_losses.append(value_loss.item()) + + self.policy.optimizer.zero_grad() + value_loss.backward() + # Removing gradients of parameters shared with the actor + # otherwise it defeats the purposes of the KL constraint + for param in params: + param.grad = None + # Clip grad norm + th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) + self.policy.optimizer.step() if not continue_training: break @@ -305,6 +329,7 @@ def Hpv(v, retain_graph=True): # Logs self.logger.record("train/policy_objective_value", np.mean(po_values)) + self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) self.logger.record("train/explained_variance", explained_var) self.logger.record("train/line_search_success", np.mean(line_search_results)) From dc734625b39f4730fe360a8c4dd3832a8759e848 Mon Sep 17 00:00:00 2001 From: Cyprien C Date: Thu, 19 Aug 2021 08:37:51 +0100 Subject: [PATCH 05/28] Feat: adding TRPO algorithm (WIP) Changes around the alpha of the line search Adding TRPO to __init__ files --- sb3_contrib/__init__.py | 2 ++ sb3_contrib/trpo/__init__.py | 2 ++ sb3_contrib/trpo/trpo.py | 13 +++++-------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 8f253e12..42b972d1 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -2,6 +2,8 @@ from sb3_contrib.qrdqn import QRDQN from sb3_contrib.tqc import TQC +from sb3_contrib.trpo import TRPO + # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") diff --git a/sb3_contrib/trpo/__init__.py b/sb3_contrib/trpo/__init__.py index e69de29b..7465a9d9 100644 --- a/sb3_contrib/trpo/__init__.py +++ b/sb3_contrib/trpo/__init__.py @@ -0,0 +1,2 @@ +from sb3_contrib.trpo.policies import CnnPolicy, MlpPolicy, MultiInputPolicy +from sb3_contrib.trpo.trpo import TRPO diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 5ae8188f..e2c333ad 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -77,7 +77,7 @@ def __init__( gamma: float = 0.99, cg_max_steps: int = 10, cg_damping: float = 1e-3, - ls_alpha: float = 0.99, + ls_alpha: float = 0.8, ls_steps: int = 10, n_critic_updates: int = 5, gae_lambda: float = 0.95, @@ -256,7 +256,7 @@ def Hpv(v, retain_graph=True): beta /= torch.matmul(s, Hpv(s, retain_graph=False)) beta = torch.sqrt(beta) - alpha = self.ls_alpha + alpha = 1 orig_params = [param.detach().clone() for param in params] line_search_success = False @@ -264,9 +264,9 @@ def Hpv(v, retain_graph=True): for i in range(self.ls_steps): j = 0 - for param, shape in zip(params, grad_shape): + for param, orig_param, shape in zip(params, orig_params, grad_shape): k = param.numel() - param.data += alpha * beta * s[j : (j + k)].view(shape) + param.data = orig_param.data + alpha * beta * s[j : (j + k)].view(shape) j += k _, log_prob, _ = self.policy.evaluate_actions( @@ -284,10 +284,7 @@ def Hpv(v, retain_graph=True): line_search_success = True break - for param, orig_param in zip(params, orig_params): - param.data = orig_param.data.clone() - - alpha *= alpha + alpha *= self.ls_alpha 
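# ---------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of the patch] Starting from
# alpha = 1 and shrinking it geometrically, the candidate update at attempt i is
# theta_i = theta_orig + ls_alpha**i * beta * s, which is what the rewritten
# loop above now computes from orig_params. A minimal backtracking line search
# in the same spirit (all names local to this sketch):
def backtracking(theta, step, improves, within_constraint, ls_alpha=0.8, ls_steps=10):
    """Return the first shrunken step that improves the objective and satisfies the constraint."""
    alpha = 1.0
    for _ in range(ls_steps):
        candidate = theta + alpha * step
        if improves(candidate) and within_constraint(candidate):
            return candidate, True       # success: keep the candidate parameters
        alpha *= ls_alpha                # geometric backtracking, as in the loop above
    return theta, False                  # failure: revert to the original parameters

# e.g. backtracking(0.0, 1.0,
#                   improves=lambda t: (t - 0.3) ** 2 < 0.09,
#                   within_constraint=lambda t: abs(t) < 0.5)
# returns (0.4096, True) after four shrinkage steps.
# ---------------------------------------------------------------------------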
line_search_results.append(line_search_success) From 9b8a22274f5bb876b7a3b83d1e22dcdff17cd580 Mon Sep 17 00:00:00 2001 From: Cyprien Date: Sat, 11 Sep 2021 11:07:28 +0100 Subject: [PATCH 06/28] feat: TRPO - addressing PR comments - renaming cg_solver to conjugate_gradient_solver and renaming parameter Avp_fun to matrix_vector_dot_func + docstring - extra comments + better variable names in trpo.py - defining a method for the hessian vector product instead of an inline function - fix registering correct policies for TRPO and using correct policy base in constructor --- sb3_contrib/common/policies.py | 482 +-------------------------------- sb3_contrib/common/utils.py | 27 +- sb3_contrib/trpo/policies.py | 11 +- sb3_contrib/trpo/trpo.py | 125 +++++---- 4 files changed, 102 insertions(+), 543 deletions(-) diff --git a/sb3_contrib/common/policies.py b/sb3_contrib/common/policies.py index 6a884797..7de70799 100644 --- a/sb3_contrib/common/policies.py +++ b/sb3_contrib/common/policies.py @@ -1,8 +1,6 @@ """Policies: abstract base class and concrete implementations.""" import collections -import copy -from abc import ABC, abstractmethod from functools import partial from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -20,325 +18,15 @@ StateDependentNoiseDistribution, make_proba_distribution, ) -from stable_baselines3.common.preprocessing import get_action_dim, is_image_space, maybe_transpose, preprocess_obs +from stable_baselines3.common.policies import BasePolicy, create_sde_features_extractor from stable_baselines3.common.torch_layers import ( BaseFeaturesExtractor, - CombinedExtractor, FlattenExtractor, MlpExtractor, NatureCNN, - create_mlp, + CombinedExtractor, ) from stable_baselines3.common.type_aliases import Schedule -from stable_baselines3.common.utils import get_device, is_vectorized_observation, obs_as_tensor - - -class BaseModel(nn.Module, ABC): - """ - The base model object: makes predictions in response to observations. - - In the case of policies, the prediction is an action. In the case of critics, it is the - estimated value of the observation. - - :param observation_space: The observation space of the environment - :param action_space: The action space of the environment - :param features_extractor_class: Features extractor to use. - :param features_extractor_kwargs: Keyword arguments - to pass to the features extractor. 
- :param features_extractor: Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param optimizer_class: The optimizer to use, - ``th.optim.Adam`` by default - :param optimizer_kwargs: Additional keyword arguments, - excluding the learning rate, to pass to the optimizer - """ - - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - features_extractor: Optional[nn.Module] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - ): - super(BaseModel, self).__init__() - - if optimizer_kwargs is None: - optimizer_kwargs = {} - - if features_extractor_kwargs is None: - features_extractor_kwargs = {} - - self.observation_space = observation_space - self.action_space = action_space - self.features_extractor = features_extractor - self.normalize_images = normalize_images - - self.optimizer_class = optimizer_class - self.optimizer_kwargs = optimizer_kwargs - self.optimizer = None # type: Optional[th.optim.Optimizer] - - self.features_extractor_class = features_extractor_class - self.features_extractor_kwargs = features_extractor_kwargs - - @abstractmethod - def forward(self, *args, **kwargs): - pass - - def _update_features_extractor( - self, - net_kwargs: Dict[str, Any], - features_extractor: Optional[BaseFeaturesExtractor] = None, - ) -> Dict[str, Any]: - """ - Update the network keyword arguments and create a new features extractor object if needed. - If a ``features_extractor`` object is passed, then it will be shared. - - :param net_kwargs: the base network keyword arguments, without the ones - related to features extractor - :param features_extractor: a features extractor object. - If None, a new object will be created. - :return: The updated keyword arguments - """ - net_kwargs = net_kwargs.copy() - if features_extractor is None: - # The features extractor is not shared, create a new one - features_extractor = self.make_features_extractor() - net_kwargs.update(dict(features_extractor=features_extractor, features_dim=features_extractor.features_dim)) - return net_kwargs - - def make_features_extractor(self) -> BaseFeaturesExtractor: - """Helper method to create a features extractor.""" - return self.features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - - def extract_features(self, obs: th.Tensor) -> th.Tensor: - """ - Preprocess the observation if needed and extract features. - - :param obs: - :return: - """ - assert self.features_extractor is not None, "No features extractor was set" - preprocessed_obs = preprocess_obs(obs, self.observation_space, normalize_images=self.normalize_images) - return self.features_extractor(preprocessed_obs) - - def _get_constructor_parameters(self) -> Dict[str, Any]: - """ - Get data that need to be saved in order to re-create the model when loading it from disk. - - :return: The dictionary to pass to the as kwargs constructor when reconstruction this model. 
- """ - return dict( - observation_space=self.observation_space, - action_space=self.action_space, - # Passed to the constructor by child class - # squash_output=self.squash_output, - # features_extractor=self.features_extractor - normalize_images=self.normalize_images, - ) - - @property - def device(self) -> th.device: - """Infer which device this policy lives on by inspecting its parameters. - If it has no parameters, the 'cpu' device is used as a fallback. - - :return:""" - for param in self.parameters(): - return param.device - return get_device("cpu") - - def save(self, path: str) -> None: - """ - Save model to a given location. - - :param path: - """ - th.save({"state_dict": self.state_dict(), "data": self._get_constructor_parameters()}, path) - - @classmethod - def load(cls, path: str, device: Union[th.device, str] = "auto") -> "BaseModel": - """ - Load model from path. - - :param path: - :param device: Device on which the policy should be loaded. - :return: - """ - device = get_device(device) - saved_variables = th.load(path, map_location=device) - # Create policy object - model = cls(**saved_variables["data"]) # pytype: disable=not-instantiable - # Load weights - model.load_state_dict(saved_variables["state_dict"]) - model.to(device) - return model - - def load_from_vector(self, vector: np.ndarray) -> None: - """ - Load parameters from a 1D vector. - - :param vector: - """ - th.nn.utils.vector_to_parameters(th.FloatTensor(vector).to(self.device), self.parameters()) - - def parameters_to_vector(self) -> np.ndarray: - """ - Convert the parameters to a 1D vector. - - :return: - """ - return th.nn.utils.parameters_to_vector(self.parameters()).detach().cpu().numpy() - - -class BasePolicy(BaseModel): - """The base policy object. - - Parameters are mostly the same as `BaseModel`; additions are documented below. - - :param args: positional arguments passed through to `BaseModel`. - :param kwargs: keyword arguments passed through to `BaseModel`. - :param squash_output: For continuous actions, whether the output is squashed - or not using a ``tanh()`` function. - """ - - def __init__(self, *args, squash_output: bool = False, **kwargs): - super(BasePolicy, self).__init__(*args, **kwargs) - self._squash_output = squash_output - - @staticmethod - def _dummy_schedule(progress_remaining: float) -> float: - """(float) Useful for pickling policy.""" - del progress_remaining - return 0.0 - - @property - def squash_output(self) -> bool: - """(bool) Getter for squash_output.""" - return self._squash_output - - @staticmethod - def init_weights(module: nn.Module, gain: float = 1) -> None: - """ - Orthogonal initialization (used in PPO and A2C) - """ - if isinstance(module, (nn.Linear, nn.Conv2d)): - nn.init.orthogonal_(module.weight, gain=gain) - if module.bias is not None: - module.bias.data.fill_(0.0) - - @abstractmethod - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: - """ - Get the action according to the policy for a given observation. - - By default provides a dummy implementation -- not all BasePolicy classes - implement this, e.g. if they are a Critic in an Actor-Critic method. 
- - :param observation: - :param deterministic: Whether to use stochastic or deterministic actions - :return: Taken action according to the policy - """ - - def predict( - self, - observation: Union[np.ndarray, Dict[str, np.ndarray]], - state: Optional[np.ndarray] = None, - mask: Optional[np.ndarray] = None, - deterministic: bool = False, - ) -> Tuple[np.ndarray, Optional[np.ndarray]]: - """ - Get the policy action and state from an observation (and optional state). - Includes sugar-coating to handle different observations (e.g. normalizing images). - - :param observation: the input observation - :param state: The last states (can be None, used in recurrent policies) - :param mask: The last masks (can be None, used in recurrent policies) - :param deterministic: Whether or not to return deterministic actions. - :return: the model's action and the next state - (used in recurrent policies) - """ - # TODO (GH/1): add support for RNN policies - # if state is None: - # state = self.initial_state - # if mask is None: - # mask = [False for _ in range(self.n_envs)] - - vectorized_env = False - if isinstance(observation, dict): - # need to copy the dict as the dict in VecFrameStack will become a torch tensor - observation = copy.deepcopy(observation) - for key, obs in observation.items(): - obs_space = self.observation_space.spaces[key] - if is_image_space(obs_space): - obs_ = maybe_transpose(obs, obs_space) - else: - obs_ = np.array(obs) - vectorized_env = vectorized_env or is_vectorized_observation(obs_, obs_space) - # Add batch dimension if needed - observation[key] = obs_.reshape((-1,) + self.observation_space[key].shape) - - elif is_image_space(self.observation_space): - # Handle the different cases for images - # as PyTorch use channel first format - observation = maybe_transpose(observation, self.observation_space) - - else: - observation = np.array(observation) - - if not isinstance(observation, dict): - # Dict obs need to be handled separately - vectorized_env = is_vectorized_observation(observation, self.observation_space) - # Add batch dimension if needed - observation = observation.reshape((-1,) + self.observation_space.shape) - - observation = obs_as_tensor(observation, self.device) - - with th.no_grad(): - actions = self._predict(observation, deterministic=deterministic) - # Convert to numpy - actions = actions.cpu().numpy() - - if isinstance(self.action_space, gym.spaces.Box): - if self.squash_output: - # Rescale to proper domain when using squashing - actions = self.unscale_action(actions) - else: - # Actions could be on arbitrary scale, so clip the actions to avoid - # out of bound error (e.g. 
if sampling from a Gaussian distribution) - actions = np.clip(actions, self.action_space.low, self.action_space.high) - - if not vectorized_env: - if state is not None: - raise ValueError("Error: The environment must be vectorized when using recurrent policies.") - actions = actions[0] - - return actions, state - - def scale_action(self, action: np.ndarray) -> np.ndarray: - """ - Rescale the action from [low, high] to [-1, 1] - (no need for symmetric action space) - - :param action: Action to scale - :return: Scaled action - """ - low, high = self.action_space.low, self.action_space.high - return 2.0 * ((action - low) / (high - low)) - 1.0 - - def unscale_action(self, scaled_action: np.ndarray) -> np.ndarray: - """ - Rescale the action from [-1, 1] to [low, high] - (no need for symmetric action space) - - :param scaled_action: Action to un-scale - """ - low, high = self.action_space.low, self.action_space.high - return low + (0.5 * (scaled_action + 1.0) * (high - low)) class ActorCriticPolicy(BasePolicy): @@ -521,7 +209,9 @@ def _build(self, lr_schedule: Schedule) -> None: elif isinstance(self.action_dist, StateDependentNoiseDistribution): latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim self.action_net, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=latent_dim_pi, latent_sde_dim=latent_sde_dim, log_std_init=self.log_std_init + latent_dim=latent_dim_pi, + latent_sde_dim=latent_sde_dim, + log_std_init=self.log_std_init, ) elif isinstance(self.action_dist, CategoricalDistribution): self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) @@ -797,165 +487,3 @@ def __init__( optimizer_class, optimizer_kwargs, ) - - -class ContinuousCritic(BaseModel): - """ - Critic network(s) for DDPG/SAC/TD3. - It represents the action-state value function (Q-value function). - Compared to A2C/PPO critics, this one represents the Q-value - and takes the continuous action as input. It is concatenated with the state - and then fed to the network which outputs a single value: Q(s, a). - For more recent algorithms like SAC/TD3, multiple networks - are created to give different estimates. - - By default, it creates two critic networks used to reduce overestimation - thanks to clipped Q-learning (cf TD3 paper). - - :param observation_space: Obervation space - :param action_space: Action space - :param net_arch: Network architecture - :param features_extractor: Network to extract features - (a CNN when using images, a nn.Flatten() layer otherwise) - :param features_dim: Number of features - :param activation_fn: Activation function - :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param n_critics: Number of critic networks to create. 
- :param share_features_extractor: Whether the features extractor is shared or not - between the actor and the critic (this saves computation time) - """ - - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - net_arch: List[int], - features_extractor: nn.Module, - features_dim: int, - activation_fn: Type[nn.Module] = nn.ReLU, - normalize_images: bool = True, - n_critics: int = 2, - share_features_extractor: bool = True, - ): - super().__init__( - observation_space, - action_space, - features_extractor=features_extractor, - normalize_images=normalize_images, - ) - - action_dim = get_action_dim(self.action_space) - - self.share_features_extractor = share_features_extractor - self.n_critics = n_critics - self.q_networks = [] - for idx in range(n_critics): - q_net = create_mlp(features_dim + action_dim, 1, net_arch, activation_fn) - q_net = nn.Sequential(*q_net) - self.add_module(f"qf{idx}", q_net) - self.q_networks.append(q_net) - - def forward(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, ...]: - # Learn the features extractor using the policy loss only - # when the features_extractor is shared with the actor - with th.set_grad_enabled(not self.share_features_extractor): - features = self.extract_features(obs) - qvalue_input = th.cat([features, actions], dim=1) - return tuple(q_net(qvalue_input) for q_net in self.q_networks) - - def q1_forward(self, obs: th.Tensor, actions: th.Tensor) -> th.Tensor: - """ - Only predict the Q-value using the first network. - This allows to reduce computation when all the estimates are not needed - (e.g. when updating the policy in TD3). - """ - with th.no_grad(): - features = self.extract_features(obs) - return self.q_networks[0](th.cat([features, actions], dim=1)) - - -def create_sde_features_extractor( - features_dim: int, sde_net_arch: List[int], activation_fn: Type[nn.Module] -) -> Tuple[nn.Sequential, int]: - """ - Create the neural network that will be used to extract features - for the gSDE exploration function. - - :param features_dim: - :param sde_net_arch: - :param activation_fn: - :return: - """ - # Special case: when using states as features (i.e. sde_net_arch is an empty list) - # don't use any activation function - sde_activation = activation_fn if len(sde_net_arch) > 0 else None - latent_sde_net = create_mlp(features_dim, -1, sde_net_arch, activation_fn=sde_activation, squash_output=False) - latent_sde_dim = sde_net_arch[-1] if len(sde_net_arch) > 0 else features_dim - sde_features_extractor = nn.Sequential(*latent_sde_net) - return sde_features_extractor, latent_sde_dim - - -_policy_registry = dict() # type: Dict[Type[BasePolicy], Dict[str, Type[BasePolicy]]] - - -def get_policy_from_name(base_policy_type: Type[BasePolicy], name: str) -> Type[BasePolicy]: - """ - Returns the registered policy from the base type and name. - See `register_policy` for registering policies and explanation. - - :param base_policy_type: the base policy class - :param name: the policy name - :return: the policy - """ - if base_policy_type not in _policy_registry: - raise KeyError(f"Error: the policy type {base_policy_type} is not registered!") - if name not in _policy_registry[base_policy_type]: - raise KeyError( - f"Error: unknown policy type {name}," - f"the only registed policy type are: {list(_policy_registry[base_policy_type].keys())}!" 
- ) - return _policy_registry[base_policy_type][name] - - -def register_policy(name: str, policy: Type[BasePolicy]) -> None: - """ - Register a policy, so it can be called using its name. - e.g. SAC('MlpPolicy', ...) instead of SAC(MlpPolicy, ...). - - The goal here is to standardize policy naming, e.g. - all algorithms can call upon "MlpPolicy" or "CnnPolicy", - and they receive respective policies that work for them. - Consider following: - - OnlinePolicy - -- OnlineMlpPolicy ("MlpPolicy") - -- OnlineCnnPolicy ("CnnPolicy") - OfflinePolicy - -- OfflineMlpPolicy ("MlpPolicy") - -- OfflineCnnPolicy ("CnnPolicy") - - Two policies have name "MlpPolicy" and two have "CnnPolicy". - In `get_policy_from_name`, the parent class (e.g. OnlinePolicy) - is given and used to select and return the correct policy. - - :param name: the policy name - :param policy: the policy class - """ - sub_class = None - for cls in BasePolicy.__subclasses__(): - if issubclass(policy, cls): - sub_class = cls - break - if sub_class is None: - raise ValueError(f"Error: the policy {policy} is not of any known subclasses of BasePolicy!") - - if sub_class not in _policy_registry: - _policy_registry[sub_class] = {} - if name in _policy_registry[sub_class]: - # Check if the registered policy is same - # we try to register. If not so, - # do not override and complain. - if _policy_registry[sub_class][name] != policy: - raise ValueError(f"Error: the name {name} is already registered for a different policy, will not override.") - _policy_registry[sub_class][name] = policy diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 89a1963d..2f4a9694 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -71,21 +71,26 @@ def quantile_huber_loss( # TODO: write regression tests -def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10, residual_tol=1e-10) -> th.Tensor: +def conjugate_gradient_solver( + matrix_vector_dot_func: Callable[[th.Tensor], th.Tensor], + b, + max_iter=10, + residual_tol=1e-10, +) -> th.Tensor: """ Finds an approximate solution to a set of linear equations Ax = b Source: https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py - :param Avp_fun : callable + :param matrix_vector_dot_func: a function that right multiplies a matrix A by a vector v - :param b : torch.FloatTensor + :param b: the right hand term in the set of linear equations Ax = b - :param max_iter : int + :param max_iter: the maximum number of iterations (default is 10) - :param residual_tol: float + :param residual_tol: residual tolerance for early stopping of the solving (default is 1e-10) - :return x : torch.FloatTensor + :return x: the approximate solution to the system of equations defined by Avp_fun and b """ @@ -93,7 +98,7 @@ def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10, residua # The vector is not initialized at 0 because of the instability issues when the gradient becomes small. # A small random gaussian noise is used for the initialization. 
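# ---------------------------------------------------------------------------
# [Editor's note: illustrative sketch, not part of the patch] A reference check
# for the textbook conjugate-gradient recurrence this function implements
# (shown here without the noisy initialisation and residual early-exit), on the
# classic 2 x 2 example A = [[4, 1], [1, 3]], b = [1, 2], whose exact solution
# is [1/11, 7/11]. Plain CG reaches it in n = 2 iterations; all names are local
# to this sketch.
import torch as th

A = th.tensor([[4.0, 1.0], [1.0, 3.0]])
b = th.tensor([1.0, 2.0])
x = th.zeros_like(b)
r = b - A @ x
p = r.clone()
r_dot = r @ r
for _ in range(2):                       # at most n iterations for an n x n SPD system
    Ap = A @ p
    alpha = r_dot / (p @ Ap)
    x = x + alpha * p
    r = r - alpha * Ap
    new_r_dot = r @ r
    p = r + (new_r_dot / r_dot) * p
    r_dot = new_r_dot
assert th.allclose(x, th.tensor([1 / 11, 7 / 11]), atol=1e-5)
# ---------------------------------------------------------------------------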
x = 1e-4 * th.randn_like(b) - r = b - Avp_fun(x) + r = b - matrix_vector_dot_func(x) r_dot = th.matmul(r, r) if r_dot < residual_tol: @@ -105,7 +110,7 @@ def cg_solver(Avp_fun: Callable[[th.Tensor], th.Tensor], b, max_iter=10, residua p = r.clone() for i in range(max_iter): - Avp = Avp_fun(p) + Avp = matrix_vector_dot_func(p) alpha = r_dot / p.dot(Avp) x += alpha * p @@ -144,6 +149,10 @@ def flat_grad( :return: Tensor containing the flattened gradients """ grads = th.autograd.grad( - output, parameters, create_graph=create_graph, retain_graph=retain_graph, allow_unused=True + output, + parameters, + create_graph=create_graph, + retain_graph=retain_graph, + allow_unused=True, ) return th.cat([grad.view(-1) for grad in grads if grad is not None]) diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py index 7427cfc4..7aef8718 100644 --- a/sb3_contrib/trpo/policies.py +++ b/sb3_contrib/trpo/policies.py @@ -1,11 +1,8 @@ # This file is here just to define MlpPolicy/CnnPolicy -# that work for PPO -from stable_baselines3.common.policies import ( - ActorCriticCnnPolicy, - ActorCriticPolicy, - MultiInputActorCriticPolicy, - register_policy, -) +# that work for TRPO +from sb3_contrib.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy +from stable_baselines3.common.policies import register_policy + MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index e2c333ad..8b56f191 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -1,16 +1,17 @@ import copy import warnings -from typing import Any, Dict, Optional, Type, Union +from functools import partial +from typing import Any, Dict, Optional, Type, Union, List import numpy as np -import torch import torch as th from gym import spaces +from torch import nn from torch.distributions import kl_divergence from torch.nn import functional as F from sb3_contrib.common.policies import ActorCriticPolicy -from sb3_contrib.common.utils import flat_grad, cg_solver +from sb3_contrib.common.utils import flat_grad, conjugate_gradient_solver from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -108,6 +109,7 @@ def __init__( max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, + policy_base=ActorCriticPolicy, tensorboard_log=tensorboard_log, policy_kwargs=policy_kwargs, verbose=verbose, @@ -188,19 +190,15 @@ def train(self) -> None: if self.use_sde: self.policy.reset_noise(self.batch_size) - with torch.no_grad(): + with th.no_grad(): _ = self.policy.evaluate_actions(rollout_data.observations, actions) old_distribution = copy.copy(self.policy.get_distribution()) - _, log_prob, _ = self.policy.evaluate_actions( - rollout_data.observations, actions - ) + _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) distribution = self.policy.get_distribution() advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / ( - rollout_data.advantages.std() + 1e-8 - ) + advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) @@ -209,21 +207,26 @@ def train(self) -> None: policy_obj = (advantages * ratio).mean() # KL divergence - kl_div 
= kl_divergence( - distribution.distribution, old_distribution.distribution - ).mean() + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() # Surrogate & KL gradient self.policy.optimizer.zero_grad() # This is necessary because not all the parameters in the policy have gradients w.r.t. the KL divergence - g = [] + policy_obj_gradient = [] + # Contains the gradients of the KL divergence grad_kl = [] + # Contains the shape of the gradients of the KL divergence w.r.t each parameter + # This way the flattened gradient can be reshaped back into the original shapes and applied to + # the parameters grad_shape = [] + # Contains the parameters which have non-zeros KL divergence gradients + # The list is used during the line-search to apply the step to each parameters params = [] - value_only_params = [] + for param in self.policy.parameters(): - kl_param_grad, *_ = torch.autograd.grad( + # For each parameter we compute the gradient of the KL divergence w.r.t to that parameter + kl_param_grad, *_ = th.autograd.grad( kl_div, param, create_graph=True, @@ -231,64 +234,75 @@ def train(self) -> None: allow_unused=True, only_inputs=True, ) + # If the gradient is not zero (not None), we store the parameter in the params list + # and add the gradient and its shape to grad_kl and grad_shape respectively if kl_param_grad is not None: - g_grad, *_ = torch.autograd.grad( - policy_obj, param, retain_graph=True, only_inputs=True - ) + # If the parameter impacts the KL divergence (i.e. the policy) + # we compute the gradient of the policy objective w.r.t to the parameter + # this avoids computing the gradient if it's not going to be used in the conjugate gradient step + g_grad, *_ = th.autograd.grad(policy_obj, param, retain_graph=True, only_inputs=True) grad_shape.append(kl_param_grad.shape) grad_kl.append(kl_param_grad.view(-1)) - g.append(g_grad.view(-1)) + policy_obj_gradient.append(g_grad.view(-1)) params.append(param) - else: - value_only_params.append(param) - g = torch.cat(g) - grad_kl = torch.cat(grad_kl) + # Gradients are concatenated before the conjugate gradient step + policy_obj_gradient = th.cat(policy_obj_gradient) + grad_kl = th.cat(grad_kl) - def Hpv(v, retain_graph=True): - jvp = (grad_kl * v).sum() - return flat_grad(jvp, params, retain_graph=retain_graph) + self.cg_damping * v + # Hessian-vector dot product function used in the conjugate gradient step + hvp = partial(self.hessian_vector_product, params, grad_kl) - s = cg_solver(Hpv, g, max_iter=self.cg_max_steps) + # Computing search direction + search_direction = conjugate_gradient_solver( + hvp, + policy_obj_gradient, + max_iter=self.cg_max_steps, + ) + # Maximal step length beta = 2 * self.target_kl - beta /= torch.matmul(s, Hpv(s, retain_graph=False)) - beta = torch.sqrt(beta) + beta /= th.matmul(search_direction, hvp(search_direction, retain_graph=False)) + beta = th.sqrt(beta) alpha = 1 orig_params = [param.detach().clone() for param in params] - line_search_success = False - with torch.no_grad(): - for i in range(self.ls_steps): + is_line_search_success = False + with th.no_grad(): + # Line-search + for _ in range(self.ls_steps): j = 0 + # Applying the scaled step direction for param, orig_param, shape in zip(params, orig_params, grad_shape): k = param.numel() - param.data = orig_param.data + alpha * beta * s[j : (j + k)].view(shape) + param.data = orig_param.data + alpha * beta * search_direction[j : (j + k)].view(shape) j += k - _, log_prob, _ = self.policy.evaluate_actions( - 
rollout_data.observations, actions - ) + # Recomputing the policy log-probabilities + _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) + # New policy objective ratio = th.exp(log_prob - rollout_data.old_log_prob) new_policy_obj = (advantages * ratio).mean() - kl_div = kl_divergence( - distribution.distribution, old_distribution.distribution - ).mean() + # New KL-divergence + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() + # Constraint criteria if (kl_div < self.target_kl) and (new_policy_obj > policy_obj): - line_search_success = True + is_line_search_success = True break + # Reducing step size if line-search wasn't successful alpha *= self.ls_alpha - line_search_results.append(line_search_success) + line_search_results.append(is_line_search_success) - if not line_search_success: + if not is_line_search_success: + # If the line-search wasn't successful we revert to the original parameters for param, orig_param in zip(params, orig_params): param.data = orig_param.data.clone() @@ -298,10 +312,9 @@ def Hpv(v, retain_graph=True): po_values.append(new_policy_obj.item()) kl_divergences.append(kl_div.item()) + # Critic updates for _ in range(self.n_critic_updates): - values, _, _ = self.policy.evaluate_actions( - rollout_data.observations, actions - ) + values, _, _ = self.policy.evaluate_actions(rollout_data.observations, actions) values_pred = values.flatten() value_loss = F.mse_loss(rollout_data.returns, values_pred) value_losses.append(value_loss.item()) @@ -320,21 +333,33 @@ def Hpv(v, retain_graph=True): break self._n_updates += self.n_epochs - explained_var = explained_variance( - self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten() - ) + explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs self.logger.record("train/policy_objective_value", np.mean(po_values)) self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) self.logger.record("train/explained_variance", explained_var) - self.logger.record("train/line_search_success", np.mean(line_search_results)) + self.logger.record("train/is_line_search_success", np.mean(line_search_results)) if hasattr(self.policy, "log_std"): self.logger.record("train/std", th.exp(self.policy.log_std).mean().item()) self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") + def hessian_vector_product( + self, params: List[nn.Parameter], grad_kl: th.Tensor, v: th.Tensor, retain_graph: bool = True + ) -> th.Tensor: + """ + Computes the matrix-vector product with the Fisher information matrix + :param params: list of parameters used to compute the Hessian + :param grad_kl: flattened gradient of the KL divergence between the old and new policy + :param v: vector to compute the dot product the hessian-vector dot product with + :param retain_graph: if True, the graph will be kept after computing the Hessian + :return: Hessian-vector dot product + """ + jvp = (grad_kl * v).sum() + return flat_grad(jvp, params, retain_graph=retain_graph) + self.cg_damping * v + def learn( self, total_timesteps: int, From 869dce9c86f5019c8dd45370d6ad5c62ac8c5677 Mon Sep 17 00:00:00 2001 From: Cyprien Date: Sat, 11 Sep 2021 11:11:51 +0100 Subject: [PATCH 07/28] refactor: TRPO - policier - refactoring sb3_contrib.common.policies to reuse as much code as possible from sb3 --- sb3_contrib/common/policies.py | 475 
+-------------------------------- 1 file changed, 7 insertions(+), 468 deletions(-) diff --git a/sb3_contrib/common/policies.py b/sb3_contrib/common/policies.py index 7de70799..75012f5d 100644 --- a/sb3_contrib/common/policies.py +++ b/sb3_contrib/common/policies.py @@ -1,336 +1,15 @@ """Policies: abstract base class and concrete implementations.""" -import collections -from functools import partial -from typing import Any, Dict, List, Optional, Tuple, Type, Union +from stable_baselines3.common.distributions import Distribution +from stable_baselines3.common.policies import ActorCriticPolicy as _ActorCriticPolicy -import gym -import numpy as np -import torch as th -from torch import nn -from stable_baselines3.common.distributions import ( - BernoulliDistribution, - CategoricalDistribution, - DiagGaussianDistribution, - Distribution, - MultiCategoricalDistribution, - StateDependentNoiseDistribution, - make_proba_distribution, -) -from stable_baselines3.common.policies import BasePolicy, create_sde_features_extractor -from stable_baselines3.common.torch_layers import ( - BaseFeaturesExtractor, - FlattenExtractor, - MlpExtractor, - NatureCNN, - CombinedExtractor, -) -from stable_baselines3.common.type_aliases import Schedule - - -class ActorCriticPolicy(BasePolicy): +class ActorCriticPolicy(_ActorCriticPolicy): """ Policy class for actor-critic algorithms (has both policy and value prediction). Used by A2C, PPO and the likes. - - :param observation_space: Observation space - :param action_space: Action space - :param lr_schedule: Learning rate schedule (could be constant) - :param net_arch: The specification of the policy and value networks. - :param activation_fn: Activation function - :param ortho_init: Whether to use or not orthogonal initialization - :param use_sde: Whether to use State Dependent Exploration or not - :param log_std_init: Initial value for the log standard deviation - :param full_std: Whether to use (n_features x n_actions) parameters - for the std instead of only (n_features,) when using gSDE - :param sde_net_arch: Network architecture for extracting features - when using gSDE. If None, the latent features from the policy will be used. - Pass an empty list to use the states as features. - :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure - a positive standard deviation (cf paper). It allows to keep variance - above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: Whether to squash the output using a tanh function, - this allows to ensure boundaries when using gSDE. - :param features_extractor_class: Features extractor to use. - :param features_extractor_kwargs: Keyword arguments - to pass to the features extractor. 
- :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param optimizer_class: The optimizer to use, - ``th.optim.Adam`` by default - :param optimizer_kwargs: Additional keyword arguments, - excluding the learning rate, to pass to the optimizer """ - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = FlattenExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - ): - - if optimizer_kwargs is None: - optimizer_kwargs = {} - # Small values to avoid NaN in Adam optimizer - if optimizer_class == th.optim.Adam: - optimizer_kwargs["eps"] = 1e-5 - - super(ActorCriticPolicy, self).__init__( - observation_space, - action_space, - features_extractor_class, - features_extractor_kwargs, - optimizer_class=optimizer_class, - optimizer_kwargs=optimizer_kwargs, - squash_output=squash_output, - ) - - # Default network architecture, from stable-baselines - if net_arch is None: - if features_extractor_class == NatureCNN: - net_arch = [] - else: - net_arch = [dict(pi=[64, 64], vf=[64, 64])] - - self.net_arch = net_arch - self.activation_fn = activation_fn - self.ortho_init = ortho_init - - self.features_extractor = features_extractor_class(self.observation_space, **self.features_extractor_kwargs) - self.features_dim = self.features_extractor.features_dim - - self.normalize_images = normalize_images - self.log_std_init = log_std_init - dist_kwargs = None - # Keyword arguments for gSDE distribution - if use_sde: - dist_kwargs = { - "full_std": full_std, - "squash_output": squash_output, - "use_expln": use_expln, - "learn_features": sde_net_arch is not None, - } - - self.sde_features_extractor = None - self.sde_net_arch = sde_net_arch - self.use_sde = use_sde - self.dist_kwargs = dist_kwargs - - # Action distribution - self.action_dist = make_proba_distribution(action_space, use_sde=use_sde, dist_kwargs=dist_kwargs) - - self._build(lr_schedule) - - def _get_constructor_parameters(self) -> Dict[str, Any]: - data = super()._get_constructor_parameters() - - default_none_kwargs = self.dist_kwargs or collections.defaultdict(lambda: None) - - data.update( - dict( - net_arch=self.net_arch, - activation_fn=self.activation_fn, - use_sde=self.use_sde, - log_std_init=self.log_std_init, - squash_output=default_none_kwargs["squash_output"], - full_std=default_none_kwargs["full_std"], - sde_net_arch=self.sde_net_arch, - use_expln=default_none_kwargs["use_expln"], - lr_schedule=self._dummy_schedule, # dummy lr schedule, not needed for loading policy alone - ortho_init=self.ortho_init, - optimizer_class=self.optimizer_class, - optimizer_kwargs=self.optimizer_kwargs, - features_extractor_class=self.features_extractor_class, - features_extractor_kwargs=self.features_extractor_kwargs, - ) - ) - return data - - def reset_noise(self, n_envs: int = 1) -> None: - """ - Sample new weights for the exploration matrix. 
- - :param n_envs: - """ - assert isinstance(self.action_dist, StateDependentNoiseDistribution), "reset_noise() is only available when using gSDE" - self.action_dist.sample_weights(self.log_std, batch_size=n_envs) - - def _build_mlp_extractor(self) -> None: - """ - Create the policy and value networks. - Part of the layers can be shared. - """ - # Note: If net_arch is None and some features extractor is used, - # net_arch here is an empty list and mlp_extractor does not - # really contain any layers (acts like an identity module). - self.mlp_extractor = MlpExtractor( - self.features_dim, - net_arch=self.net_arch, - activation_fn=self.activation_fn, - device=self.device, - ) - - def _build(self, lr_schedule: Schedule) -> None: - """ - Create the networks and the optimizer. - - :param lr_schedule: Learning rate schedule - lr_schedule(1) is the initial learning rate - """ - self._build_mlp_extractor() - - latent_dim_pi = self.mlp_extractor.latent_dim_pi - - # Separate features extractor for gSDE - if self.sde_net_arch is not None: - self.sde_features_extractor, latent_sde_dim = create_sde_features_extractor( - self.features_dim, self.sde_net_arch, self.activation_fn - ) - - if isinstance(self.action_dist, DiagGaussianDistribution): - self.action_net, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=latent_dim_pi, log_std_init=self.log_std_init - ) - elif isinstance(self.action_dist, StateDependentNoiseDistribution): - latent_sde_dim = latent_dim_pi if self.sde_net_arch is None else latent_sde_dim - self.action_net, self.log_std = self.action_dist.proba_distribution_net( - latent_dim=latent_dim_pi, - latent_sde_dim=latent_sde_dim, - log_std_init=self.log_std_init, - ) - elif isinstance(self.action_dist, CategoricalDistribution): - self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) - elif isinstance(self.action_dist, MultiCategoricalDistribution): - self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) - elif isinstance(self.action_dist, BernoulliDistribution): - self.action_net = self.action_dist.proba_distribution_net(latent_dim=latent_dim_pi) - else: - raise NotImplementedError(f"Unsupported distribution '{self.action_dist}'.") - - self.value_net = nn.Linear(self.mlp_extractor.latent_dim_vf, 1) - # Init weights: use orthogonal initialization - # with small initial weight for the output - if self.ortho_init: - # TODO: check for features_extractor - # Values from stable-baselines. - # features_extractor/mlp values are - # originally from openai/baselines (default gains/init_scales). 
- module_gains = { - self.features_extractor: np.sqrt(2), - self.mlp_extractor: np.sqrt(2), - self.action_net: 0.01, - self.value_net: 1, - } - for module, gain in module_gains.items(): - module.apply(partial(self.init_weights, gain=gain)) - - # Setup optimizer with initial learning rate - self.optimizer = self.optimizer_class(self.parameters(), lr=lr_schedule(1), **self.optimizer_kwargs) - - def forward(self, obs: th.Tensor, deterministic: bool = False) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: - """ - Forward pass in all the networks (actor and critic) - - :param obs: Observation - :param deterministic: Whether to sample or use deterministic actions - :return: action, value and log probability of the action - """ - latent_pi, latent_vf, latent_sde = self._get_latent(obs) - # Evaluate the values for the given observations - values = self.value_net(latent_vf) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde=latent_sde) - actions = distribution.get_actions(deterministic=deterministic) - log_prob = distribution.log_prob(actions) - return actions, values, log_prob - - def _get_latent(self, obs: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: - """ - Get the latent code (i.e., activations of the last layer of each network) - for the different networks. - - :param obs: Observation - :return: Latent codes - for the actor, the value function and for gSDE function - """ - # Preprocess the observation if needed - features = self.extract_features(obs) - latent_pi, latent_vf = self.mlp_extractor(features) - - # Features for sde - latent_sde = latent_pi - if self.sde_features_extractor is not None: - latent_sde = self.sde_features_extractor(features) - return latent_pi, latent_vf, latent_sde - - def _get_action_dist_from_latent(self, latent_pi: th.Tensor, latent_sde: Optional[th.Tensor] = None) -> Distribution: - """ - Retrieve action distribution given the latent codes. - - :param latent_pi: Latent code for the actor - :param latent_sde: Latent code for the gSDE exploration function - :return: Action distribution - """ - mean_actions = self.action_net(latent_pi) - - if isinstance(self.action_dist, DiagGaussianDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std) - elif isinstance(self.action_dist, CategoricalDistribution): - # Here mean_actions are the logits before the softmax - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, MultiCategoricalDistribution): - # Here mean_actions are the flattened logits - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, BernoulliDistribution): - # Here mean_actions are the logits (before rounding to get the binary actions) - return self.action_dist.proba_distribution(action_logits=mean_actions) - elif isinstance(self.action_dist, StateDependentNoiseDistribution): - return self.action_dist.proba_distribution(mean_actions, self.log_std, latent_sde) - else: - raise ValueError("Invalid action distribution") - - def _predict(self, observation: th.Tensor, deterministic: bool = False) -> th.Tensor: - """ - Get the action according to the policy for a given observation. 
- - :param observation: - :param deterministic: Whether to use stochastic or deterministic actions - :return: Taken action according to the policy - """ - latent_pi, _, latent_sde = self._get_latent(observation) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) - return distribution.get_actions(deterministic=deterministic) - - def evaluate_actions(self, obs: th.Tensor, actions: th.Tensor) -> Tuple[th.Tensor, th.Tensor, th.Tensor]: - """ - Evaluate actions according to the current policy, - given the observations. - - :param obs: - :param actions: - :return: estimated value, log likelihood of taking those actions - and entropy of the action distribution. - """ - latent_pi, latent_vf, latent_sde = self._get_latent(obs) - distribution = self._get_action_dist_from_latent(latent_pi, latent_sde) - log_prob = distribution.log_prob(actions) - values = self.value_net(latent_vf) - return values, log_prob, distribution.entropy() - def get_distribution(self) -> Distribution: """ Get the current action distribution @@ -339,151 +18,11 @@ def get_distribution(self) -> Distribution: return self.action_dist +# This is just to propagate get_distribution class ActorCriticCnnPolicy(ActorCriticPolicy): - """ - CNN policy class for actor-critic algorithms (has both policy and value prediction). - Used by A2C, PPO and the likes. - - :param observation_space: Observation space - :param action_space: Action space - :param lr_schedule: Learning rate schedule (could be constant) - :param net_arch: The specification of the policy and value networks. - :param activation_fn: Activation function - :param ortho_init: Whether to use or not orthogonal initialization - :param use_sde: Whether to use State Dependent Exploration or not - :param log_std_init: Initial value for the log standard deviation - :param full_std: Whether to use (n_features x n_actions) parameters - for the std instead of only (n_features,) when using gSDE - :param sde_net_arch: Network architecture for extracting features - when using gSDE. If None, the latent features from the policy will be used. - Pass an empty list to use the states as features. - :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure - a positive standard deviation (cf paper). It allows to keep variance - above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: Whether to squash the output using a tanh function, - this allows to ensure boundaries when using gSDE. - :param features_extractor_class: Features extractor to use. - :param features_extractor_kwargs: Keyword arguments - to pass to the features extractor. 
- :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param optimizer_class: The optimizer to use, - ``th.optim.Adam`` by default - :param optimizer_kwargs: Additional keyword arguments, - excluding the learning rate, to pass to the optimizer - """ - - def __init__( - self, - observation_space: gym.spaces.Space, - action_space: gym.spaces.Space, - lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = NatureCNN, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - ): - super(ActorCriticCnnPolicy, self).__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - sde_net_arch, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs, - ) + pass +# This is just to propagate get_distribution class MultiInputActorCriticPolicy(ActorCriticPolicy): - """ - MultiInputActorClass policy class for actor-critic algorithms (has both policy and value prediction). - Used by A2C, PPO and the likes. - - :param observation_space: Observation space (Tuple) - :param action_space: Action space - :param lr_schedule: Learning rate schedule (could be constant) - :param net_arch: The specification of the policy and value networks. - :param activation_fn: Activation function - :param ortho_init: Whether to use or not orthogonal initialization - :param use_sde: Whether to use State Dependent Exploration or not - :param log_std_init: Initial value for the log standard deviation - :param full_std: Whether to use (n_features x n_actions) parameters - for the std instead of only (n_features,) when using gSDE - :param sde_net_arch: Network architecture for extracting features - when using gSDE. If None, the latent features from the policy will be used. - Pass an empty list to use the states as features. - :param use_expln: Use ``expln()`` function instead of ``exp()`` to ensure - a positive standard deviation (cf paper). It allows to keep variance - above zero and prevent it from growing too fast. In practice, ``exp()`` is usually enough. - :param squash_output: Whether to squash the output using a tanh function, - this allows to ensure boundaries when using gSDE. - :param features_extractor_class: Uses the CombinedExtractor - :param features_extractor_kwargs: Keyword arguments - to pass to the feature extractor. 
- :param normalize_images: Whether to normalize images or not, - dividing by 255.0 (True by default) - :param optimizer_class: The optimizer to use, - ``th.optim.Adam`` by default - :param optimizer_kwargs: Additional keyword arguments, - excluding the learning rate, to pass to the optimizer - """ - - def __init__( - self, - observation_space: gym.spaces.Dict, - action_space: gym.spaces.Space, - lr_schedule: Schedule, - net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None, - activation_fn: Type[nn.Module] = nn.Tanh, - ortho_init: bool = True, - use_sde: bool = False, - log_std_init: float = 0.0, - full_std: bool = True, - sde_net_arch: Optional[List[int]] = None, - use_expln: bool = False, - squash_output: bool = False, - features_extractor_class: Type[BaseFeaturesExtractor] = CombinedExtractor, - features_extractor_kwargs: Optional[Dict[str, Any]] = None, - normalize_images: bool = True, - optimizer_class: Type[th.optim.Optimizer] = th.optim.Adam, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - ): - super(MultiInputActorCriticPolicy, self).__init__( - observation_space, - action_space, - lr_schedule, - net_arch, - activation_fn, - ortho_init, - use_sde, - log_std_init, - full_std, - sde_net_arch, - use_expln, - squash_output, - features_extractor_class, - features_extractor_kwargs, - normalize_images, - optimizer_class, - optimizer_kwargs, - ) + pass From 347dcc076ac62cd897bc3eb233cd6060d021c3ea Mon Sep 17 00:00:00 2001 From: Cyprien Date: Sat, 11 Sep 2021 14:00:22 +0100 Subject: [PATCH 08/28] feat: using updated ActorCriticPolicy from SB3 - get_distribution will be added directly to the SB3 version of ActorCriticPolicy, this commit reflects this --- sb3_contrib/common/policies.py | 28 ---------------------------- sb3_contrib/trpo/policies.py | 9 ++++++--- sb3_contrib/trpo/trpo.py | 7 +++---- 3 files changed, 9 insertions(+), 35 deletions(-) delete mode 100644 sb3_contrib/common/policies.py diff --git a/sb3_contrib/common/policies.py b/sb3_contrib/common/policies.py deleted file mode 100644 index 75012f5d..00000000 --- a/sb3_contrib/common/policies.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Policies: abstract base class and concrete implementations.""" - -from stable_baselines3.common.distributions import Distribution -from stable_baselines3.common.policies import ActorCriticPolicy as _ActorCriticPolicy - - -class ActorCriticPolicy(_ActorCriticPolicy): - """ - Policy class for actor-critic algorithms (has both policy and value prediction). - Used by A2C, PPO and the likes. 
- """ - - def get_distribution(self) -> Distribution: - """ - Get the current action distribution - :return: Action distribution - """ - return self.action_dist - - -# This is just to propagate get_distribution -class ActorCriticCnnPolicy(ActorCriticPolicy): - pass - - -# This is just to propagate get_distribution -class MultiInputActorCriticPolicy(ActorCriticPolicy): - pass diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py index 7aef8718..bc133a74 100644 --- a/sb3_contrib/trpo/policies.py +++ b/sb3_contrib/trpo/policies.py @@ -1,8 +1,11 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for TRPO -from sb3_contrib.common.policies import ActorCriticPolicy, ActorCriticCnnPolicy, MultiInputActorCriticPolicy -from stable_baselines3.common.policies import register_policy - +from stable_baselines3.common.policies import ( + register_policy, + ActorCriticPolicy, + ActorCriticCnnPolicy, + MultiInputActorCriticPolicy, +) MlpPolicy = ActorCriticPolicy CnnPolicy = ActorCriticCnnPolicy diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 8b56f191..0e683d5e 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -10,9 +10,9 @@ from torch.distributions import kl_divergence from torch.nn import functional as F -from sb3_contrib.common.policies import ActorCriticPolicy from sb3_contrib.common.utils import flat_grad, conjugate_gradient_solver from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm +from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance @@ -191,11 +191,10 @@ def train(self) -> None: self.policy.reset_noise(self.batch_size) with th.no_grad(): - _ = self.policy.evaluate_actions(rollout_data.observations, actions) - old_distribution = copy.copy(self.policy.get_distribution()) + old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) - distribution = self.policy.get_distribution() + distribution = self.policy.action_dist advantages = rollout_data.advantages advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) From 35d7256d462bb44cd82246194902564fbd78b0e3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Sep 2021 10:38:06 +0200 Subject: [PATCH 09/28] Bump version for `get_distribution` support --- docs/misc/changelog.rst | 29 ++++++++++++++++++++++++++++- sb3_contrib/version.txt | 2 +- setup.py | 2 +- 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 07a42a0a..7bf9e3ee 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,6 +3,28 @@ Changelog ========== +Release 1.2.1a0 () +------------------------------- + +**Add TRPO** + +Breaking Changes: +^^^^^^^^^^^^^^^^^ +- Upgraded to Stable-Baselines3 >= 1.2.1a0 + +Bug Fixes: +^^^^^^^^^^ + +Deprecations: +^^^^^^^^^^^^^ + +Others: +^^^^^^^ + + +Documentation: +^^^^^^^^^^^^^^ + Release 1.2.0 (2021-09-08) ------------------------------- @@ -13,6 +35,11 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.2.0 + +New Features: +^^^^^^^^^^^^^ +- Added ``TRPO`` (@cyprienc) + Bug Fixes: ^^^^^^^^^^ - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright) @@ -156,4 +183,4 @@ Stable-Baselines3 is 
currently maintained by `Antonin Raffin`_ (aka `@araffin`_) Contributors: ------------- -@ku2482 @guyk1971 @minhlong94 @ayeright +@ku2482 @guyk1971 @minhlong94 @ayeright @cyprienc diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index 26aaba0e..348e216f 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.2.0 +1.2.1a0 diff --git a/setup.py b/setup.py index 1b75a55c..2ebec276 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.2.0", + "stable_baselines3>=1.2.1a0", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", From 9cfcb540c9525c57789c17b2174c56c6384a1ba5 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Sep 2021 10:38:27 +0200 Subject: [PATCH 10/28] Add basic test --- tests/test_run.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/test_run.py b/tests/test_run.py index 195d0114..5d1e47ef 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -1,6 +1,6 @@ import pytest -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO @pytest.mark.parametrize("ent_coef", ["auto", 0.01, "auto_0.01"]) @@ -56,3 +56,9 @@ def test_qrdqn(): create_eval_env=True, ) model.learn(total_timesteps=500, eval_freq=250) + + +@pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) +def test_trpo(env_id): + model = TRPO("MlpPolicy", env_id, n_steps=64, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) + model.learn(total_timesteps=500) From 974174ac762634fcb8d4ae737b79b5e79a930c2a Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Sep 2021 10:38:34 +0200 Subject: [PATCH 11/28] Reformat --- sb3_contrib/__init__.py | 1 - sb3_contrib/common/utils.py | 2 +- sb3_contrib/trpo/policies.py | 4 ++-- sb3_contrib/trpo/trpo.py | 12 ++++++------ 4 files changed, 9 insertions(+), 10 deletions(-) diff --git a/sb3_contrib/__init__.py b/sb3_contrib/__init__.py index 42b972d1..86291531 100644 --- a/sb3_contrib/__init__.py +++ b/sb3_contrib/__init__.py @@ -4,7 +4,6 @@ from sb3_contrib.tqc import TQC from sb3_contrib.trpo import TRPO - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file, "r") as file_handler: diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 2f4a9694..e07c190e 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, Sequence, Callable +from typing import Callable, Optional, Sequence import torch as th from torch import nn diff --git a/sb3_contrib/trpo/policies.py b/sb3_contrib/trpo/policies.py index bc133a74..27cde537 100644 --- a/sb3_contrib/trpo/policies.py +++ b/sb3_contrib/trpo/policies.py @@ -1,10 +1,10 @@ # This file is here just to define MlpPolicy/CnnPolicy # that work for TRPO from stable_baselines3.common.policies import ( - register_policy, - ActorCriticPolicy, ActorCriticCnnPolicy, + ActorCriticPolicy, MultiInputActorCriticPolicy, + register_policy, ) MlpPolicy = ActorCriticPolicy diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 0e683d5e..ad10077e 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -1,20 +1,20 @@ import copy import warnings from functools import partial -from typing import Any, Dict, Optional, Type, Union, List +from typing 
import Any, Dict, List, Optional, Type, Union import numpy as np import torch as th from gym import spaces -from torch import nn -from torch.distributions import kl_divergence -from torch.nn import functional as F - -from sb3_contrib.common.utils import flat_grad, conjugate_gradient_solver from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule from stable_baselines3.common.utils import explained_variance +from torch import nn +from torch.distributions import kl_divergence +from torch.nn import functional as F + +from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad class TRPO(OnPolicyAlgorithm): From b6bd44992c6531693a4de70b77fc4d757df5cd03 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 13 Sep 2021 10:46:42 +0200 Subject: [PATCH 12/28] [ci skip] Fix changelog --- docs/misc/changelog.rst | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 7bf9e3ee..4ca31f07 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -12,6 +12,10 @@ Breaking Changes: ^^^^^^^^^^^^^^^^^ - Upgraded to Stable-Baselines3 >= 1.2.1a0 +New Features: +^^^^^^^^^^^^^ +- Added ``TRPO`` (@cyprienc) + Bug Fixes: ^^^^^^^^^^ @@ -36,10 +40,6 @@ Breaking Changes: - Upgraded to Stable-Baselines3 >= 1.2.0 -New Features: -^^^^^^^^^^^^^ -- Added ``TRPO`` (@cyprienc) - Bug Fixes: ^^^^^^^^^^ - QR-DQN and TQC updated so that their policies are switched between train and eval mode at the correct time (@ayeright) From c88951c48221973131a1da6b3d9743555c317c95 Mon Sep 17 00:00:00 2001 From: Cyprien Date: Mon, 13 Sep 2021 15:37:40 +0100 Subject: [PATCH 13/28] fix: setting train mode for trpo --- sb3_contrib/trpo/trpo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index ad10077e..9bd01d41 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -165,6 +165,8 @@ def train(self) -> None: """ Update policy using the currently gathered rollout buffer. 
""" + # Switch to train mode (this affects batch norm / dropout) + self.policy.set_training_mode(True) # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) From 1f7e99db8b0e2ee612816b090a0574b0988d732d Mon Sep 17 00:00:00 2001 From: Cyprien Date: Mon, 13 Sep 2021 15:53:38 +0100 Subject: [PATCH 14/28] fix: batch_size type hint in trpo.py --- sb3_contrib/trpo/trpo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 9bd01d41..b1acd025 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -73,7 +73,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, - batch_size: Optional[int] = 64, + batch_size: int = 64, n_epochs: int = 10, gamma: float = 0.99, cg_max_steps: int = 10, From 6540371e4c0a097590ffe69243d334c4f6c87eff Mon Sep 17 00:00:00 2001 From: Cyprien Date: Wed, 15 Sep 2021 09:04:22 +0100 Subject: [PATCH 15/28] style: renaming variables + docstring in trpo.py --- sb3_contrib/common/utils.py | 19 ++++++++++++------- sb3_contrib/trpo/trpo.py | 16 ++++++---------- 2 files changed, 18 insertions(+), 17 deletions(-) diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index e07c190e..8748d25b 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -80,7 +80,12 @@ def conjugate_gradient_solver( """ Finds an approximate solution to a set of linear equations Ax = b - Source: https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py + Sources: + - https://github.com/ajlangley/trpo-pytorch/blob/master/conjugate_gradient.py + - https://github.com/joschu/modular_rl/blob/master/modular_rl/trpo.py#L122 + + Reference: + - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6 :param matrix_vector_dot_func: a function that right multiplies a matrix A by a vector v @@ -98,8 +103,8 @@ def conjugate_gradient_solver( # The vector is not initialized at 0 because of the instability issues when the gradient becomes small. # A small random gaussian noise is used for the initialization. x = 1e-4 * th.randn_like(b) - r = b - matrix_vector_dot_func(x) - r_dot = th.matmul(r, r) + residual = b - matrix_vector_dot_func(x) + r_dot = th.matmul(residual, residual) if r_dot < residual_tol: # If the gradient becomes extremely small @@ -107,7 +112,7 @@ def conjugate_gradient_solver( # Leading to a division by zero return x - p = r.clone() + p = residual.clone() for i in range(max_iter): Avp = matrix_vector_dot_func(p) @@ -118,15 +123,15 @@ def conjugate_gradient_solver( if i == max_iter - 1: return x - r -= alpha * Avp - new_r_dot = th.matmul(r, r) + residual -= alpha * Avp + new_r_dot = th.matmul(residual, residual) if new_r_dot < residual_tol: return x beta = new_r_dot / r_dot r_dot = new_r_dot - p = r + beta * p + p = residual + beta * p # TODO: test diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index b1acd025..8ae2c788 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -42,11 +42,9 @@ class TRPO(OnPolicyAlgorithm): for computing the Hessian vector product :param cg_damping: damping in the Hessian vector product computation :param ls_alpha: step-size reduction factor for the line-search (i.e. 
theta_new = theta + alpha^i * step) - :param ls_steps: maximum number of steps in the line-search + :param line_search_max_steps: maximum number of steps in the line-search :param n_critic_updates: number of critic updates per policy updates :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - :param ent_coef: Entropy coefficient for the loss calculation - :param vf_coef: Value function coefficient for the loss calculation :param max_grad_norm: The maximum value for the gradient clipping :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) @@ -79,11 +77,9 @@ def __init__( cg_max_steps: int = 10, cg_damping: float = 1e-3, ls_alpha: float = 0.8, - ls_steps: int = 10, + line_search_max_steps: int = 10, n_critic_updates: int = 5, gae_lambda: float = 0.95, - ent_coef: float = 0.0, - vf_coef: float = 0.5, max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, @@ -104,8 +100,8 @@ def __init__( n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, - ent_coef=ent_coef, - vf_coef=vf_coef, + ent_coef=0.0, + vf_coef=0.0, max_grad_norm=max_grad_norm, use_sde=use_sde, sde_sample_freq=sde_sample_freq, @@ -154,7 +150,7 @@ def __init__( self.cg_max_steps = cg_max_steps self.cg_damping = cg_damping self.ls_alpha = ls_alpha - self.ls_steps = ls_steps + self.line_search_max_steps = line_search_max_steps self.target_kl = target_kl self.n_critic_updates = n_critic_updates @@ -273,7 +269,7 @@ def train(self) -> None: is_line_search_success = False with th.no_grad(): # Line-search - for _ in range(self.ls_steps): + for _ in range(self.line_search_max_steps): j = 0 # Applying the scaled step direction From 8ecf40e16d906235175a5d9705963904ee823144 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Sep 2021 11:42:24 +0200 Subject: [PATCH 16/28] Rename + cleanup --- docs/misc/changelog.rst | 2 +- sb3_contrib/trpo/trpo.py | 115 ++++++++++++++++++++++----------------- setup.cfg | 1 + setup.py | 2 +- 4 files changed, 68 insertions(+), 52 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c66114cb..2dab71e8 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -10,7 +10,7 @@ Release 1.2.1a2 (WIP) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Upgraded to Stable-Baselines3 >= 1.2.1a0 +- Upgraded to Stable-Baselines3 >= 1.2.1a2 - Removed ``sde_net_arch`` New Features: diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 8ae2c788..6b00edb4 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -41,7 +41,8 @@ class TRPO(OnPolicyAlgorithm): :param cg_max_steps: maximum number of steps in the Conjugate Gradient algoritgm for computing the Hessian vector product :param cg_damping: damping in the Hessian vector product computation - :param ls_alpha: step-size reduction factor for the line-search (i.e. 
theta_new = theta + alpha^i * step) + :param line_search_shrinking_factor: step-size reduction factor for the line-search + (i.e., ``theta_new = theta + alpha^i * step``) :param line_search_max_steps: maximum number of steps in the line-search :param n_critic_updates: number of critic updates per policy updates :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator @@ -76,7 +77,7 @@ def __init__( gamma: float = 0.99, cg_max_steps: int = 10, cg_damping: float = 1e-3, - ls_alpha: float = 0.8, + line_search_shrinking_factor: float = 0.8, line_search_max_steps: int = 10, n_critic_updates: int = 5, gae_lambda: float = 0.95, @@ -149,7 +150,7 @@ def __init__( self.n_epochs = n_epochs self.cg_max_steps = cg_max_steps self.cg_damping = cg_damping - self.ls_alpha = ls_alpha + self.line_search_shrinking_factor = line_search_shrinking_factor self.line_search_max_steps = line_search_max_steps self.target_kl = target_kl self.n_critic_updates = n_critic_updates @@ -166,15 +167,16 @@ def train(self) -> None: # Update optimizer learning rate self._update_learning_rate(self.policy.optimizer) - po_values = [] + policy_objective_values = [] kl_divergences = [] line_search_results = [] value_losses = [] + # Note(antonin): this value is never changed continue_training = True # train for n_epochs epochs - for epoch in range(self.n_epochs): + for _ in range(self.n_epochs): # Do a complete pass on the rollout buffer for rollout_data in self.rollout_buffer.get(self.batch_size): actions = rollout_data.actions @@ -183,16 +185,14 @@ def train(self) -> None: actions = rollout_data.actions.long().flatten() # Re-sample the noise matrix because the log_std has changed - # TODO: investigate why there is no issue with the gradient - # if that line is commented (as in SAC) if self.use_sde: self.policy.reset_noise(self.batch_size) with th.no_grad(): - old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) + old_distribution = copy.deepcopy(self.policy.get_distribution(rollout_data.observations)) - _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) - distribution = self.policy.action_dist + distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) advantages = rollout_data.advantages advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) @@ -201,7 +201,7 @@ def train(self) -> None: ratio = th.exp(log_prob - rollout_data.old_log_prob) # surrogate policy objective - policy_obj = (advantages * ratio).mean() + policy_objective = (advantages * ratio).mean() # KL divergence kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() @@ -210,7 +210,8 @@ def train(self) -> None: self.policy.optimizer.zero_grad() # This is necessary because not all the parameters in the policy have gradients w.r.t. 
the KL divergence - policy_obj_gradient = [] + # The policy objective is also called surrogate objective + policy_objective_gradients = [] # Contains the gradients of the KL divergence grad_kl = [] # Contains the shape of the gradients of the KL divergence w.r.t each parameter @@ -219,9 +220,14 @@ def train(self) -> None: grad_shape = [] # Contains the parameters which have non-zeros KL divergence gradients # The list is used during the line-search to apply the step to each parameters - params = [] + actor_params = [] + + for name, param in self.policy.named_parameters(): + # Skip parameters related to value function based on name + # this work for built-in policies only (not custom ones) + if "value" in name: + continue - for param in self.policy.parameters(): # For each parameter we compute the gradient of the KL divergence w.r.t to that parameter kl_param_grad, *_ = th.autograd.grad( kl_div, @@ -231,96 +237,104 @@ def train(self) -> None: allow_unused=True, only_inputs=True, ) - # If the gradient is not zero (not None), we store the parameter in the params list + # If the gradient is not zero (not None), we store the parameter in the actor_params list # and add the gradient and its shape to grad_kl and grad_shape respectively if kl_param_grad is not None: # If the parameter impacts the KL divergence (i.e. the policy) # we compute the gradient of the policy objective w.r.t to the parameter # this avoids computing the gradient if it's not going to be used in the conjugate gradient step - g_grad, *_ = th.autograd.grad(policy_obj, param, retain_graph=True, only_inputs=True) + policy_objective_grad, *_ = th.autograd.grad( + policy_objective, param, retain_graph=True, only_inputs=True + ) grad_shape.append(kl_param_grad.shape) grad_kl.append(kl_param_grad.view(-1)) - policy_obj_gradient.append(g_grad.view(-1)) - params.append(param) + policy_objective_gradients.append(policy_objective_grad.view(-1)) + actor_params.append(param) # Gradients are concatenated before the conjugate gradient step - policy_obj_gradient = th.cat(policy_obj_gradient) + policy_objective_gradients = th.cat(policy_objective_gradients) grad_kl = th.cat(grad_kl) # Hessian-vector dot product function used in the conjugate gradient step - hvp = partial(self.hessian_vector_product, params, grad_kl) + hessian_vector_product = partial(self.hessian_vector_product, actor_params, grad_kl) # Computing search direction search_direction = conjugate_gradient_solver( - hvp, - policy_obj_gradient, + hessian_vector_product, + policy_objective_gradients, max_iter=self.cg_max_steps, ) # Maximal step length - beta = 2 * self.target_kl - beta /= th.matmul(search_direction, hvp(search_direction, retain_graph=False)) - beta = th.sqrt(beta) + line_search_max_step_size = 2 * self.target_kl + line_search_max_step_size /= th.matmul( + search_direction, hessian_vector_product(search_direction, retain_graph=False) + ) + line_search_max_step_size = th.sqrt(line_search_max_step_size) - alpha = 1 - orig_params = [param.detach().clone() for param in params] + line_search_backtrack_coeff = 1.0 + original_actor_params = [param.detach().clone() for param in actor_params] is_line_search_success = False with th.no_grad(): - # Line-search + # Line-search (backtracking) for _ in range(self.line_search_max_steps): - j = 0 + start_idx = 0 # Applying the scaled step direction - for param, orig_param, shape in zip(params, orig_params, grad_shape): - k = param.numel() - param.data = orig_param.data + alpha * beta * search_direction[j : (j + k)].view(shape) - j += k 
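
# A self-contained illustrative sketch (an aside, not part of the patch) of the
# flatten/unflatten round trip applied in this line-search: the flat search direction is
# sliced per parameter, reshaped with `.view(shape)`, and added onto `param.data`.
# The tiny linear layer and the 0.01 scale below are made-up stand-ins for the actor
# parameters and the step size.
import torch as th
from torch import nn

layer = nn.Linear(3, 2)
actor_params = list(layer.parameters())
grad_shape = [p.shape for p in actor_params]

# A flat vector laid out in the same order as the parameters (here: a scaled copy of them)
search_direction = 0.01 * th.cat([p.detach().view(-1) for p in actor_params])

start_idx = 0
with th.no_grad():
    for param, shape in zip(actor_params, grad_shape):
        n_params = param.numel()
        # Slice out this parameter's chunk and restore its original shape before applying it
        param.data += search_direction[start_idx : start_idx + n_params].view(shape)
        start_idx += n_params
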
+ for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): + n_params = param.numel() + param.data = ( + original_param.data + + line_search_backtrack_coeff + * line_search_max_step_size + * search_direction[start_idx : (start_idx + n_params)].view(shape) + ) + start_idx += n_params # Recomputing the policy log-probabilities _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) # New policy objective ratio = th.exp(log_prob - rollout_data.old_log_prob) - new_policy_obj = (advantages * ratio).mean() + new_policy_objective = (advantages * ratio).mean() # New KL-divergence kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() # Constraint criteria - if (kl_div < self.target_kl) and (new_policy_obj > policy_obj): + if (kl_div < self.target_kl) and (new_policy_objective > policy_objective): is_line_search_success = True break # Reducing step size if line-search wasn't successful - alpha *= self.ls_alpha + line_search_backtrack_coeff *= self.line_search_shrinking_factor line_search_results.append(is_line_search_success) if not is_line_search_success: # If the line-search wasn't successful we revert to the original parameters - for param, orig_param in zip(params, orig_params): - param.data = orig_param.data.clone() + for param, original_param in zip(actor_params, original_actor_params): + param.data = original_param.data.clone() - po_values.append(policy_obj.item()) + policy_objective_values.append(policy_objective.item()) kl_divergences.append(0) else: - po_values.append(new_policy_obj.item()) + policy_objective_values.append(new_policy_objective.item()) kl_divergences.append(kl_div.item()) # Critic updates for _ in range(self.n_critic_updates): - values, _, _ = self.policy.evaluate_actions(rollout_data.observations, actions) - values_pred = values.flatten() - value_loss = F.mse_loss(rollout_data.returns, values_pred) + values_pred = self.policy.predict_values(rollout_data.observations) + value_loss = F.mse_loss(rollout_data.returns, values_pred.flatten()) value_losses.append(value_loss.item()) self.policy.optimizer.zero_grad() value_loss.backward() # Removing gradients of parameters shared with the actor # otherwise it defeats the purposes of the KL constraint - for param in params: + for param in actor_params: param.grad = None # Clip grad norm th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) @@ -333,7 +347,7 @@ def train(self) -> None: explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs - self.logger.record("train/policy_objective_value", np.mean(po_values)) + self.logger.record("train/policy_objective", np.mean(policy_objective_values)) self.logger.record("train/value_loss", np.mean(value_losses)) self.logger.record("train/kl_divergence_loss", np.mean(kl_divergences)) self.logger.record("train/explained_variance", explained_var) @@ -344,18 +358,19 @@ def train(self) -> None: self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard") def hessian_vector_product( - self, params: List[nn.Parameter], grad_kl: th.Tensor, v: th.Tensor, retain_graph: bool = True + self, params: List[nn.Parameter], grad_kl: th.Tensor, vector: th.Tensor, retain_graph: bool = True ) -> th.Tensor: """ - Computes the matrix-vector product with the Fisher information matrix + Computes the matrix-vector product with the Fisher information matrix. 
+ :param params: list of parameters used to compute the Hessian :param grad_kl: flattened gradient of the KL divergence between the old and new policy - :param v: vector to compute the dot product the hessian-vector dot product with + :param vector: vector to compute the dot product the hessian-vector dot product with :param retain_graph: if True, the graph will be kept after computing the Hessian :return: Hessian-vector dot product """ - jvp = (grad_kl * v).sum() - return flat_grad(jvp, params, retain_graph=retain_graph) + self.cg_damping * v + jacobian_vector_product = (grad_kl * vector).sum() + return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector def learn( self, diff --git a/setup.cfg b/setup.cfg index cf162f7b..dc9cf363 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,6 +25,7 @@ per-file-ignores = ./sb3_contrib/ppo_mask/__init__.py:F401 ./sb3_contrib/qrdqn/__init__.py:F401 ./sb3_contrib/tqc/__init__.py:F401 + ./sb3_contrib/trpo/__init__.py:F401 ./sb3_contrib/common/vec_env/wrappers/__init__.py:F401 ./sb3_contrib/common/wrappers/__init__.py:F401 ./sb3_contrib/common/envs/__init__.py:F401 diff --git a/setup.py b/setup.py index 2ebec276..871e5bbd 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ packages=[package for package in find_packages() if package.startswith("sb3_contrib")], package_data={"sb3_contrib": ["py.typed", "version.txt"]}, install_requires=[ - "stable_baselines3>=1.2.1a0", + "stable_baselines3>=1.2.1a2", ], description="Contrib package of Stable Baselines3, experimental code.", author="Antonin Raffin", From 45f4ea6896ef68049ad9388e78cacb8ef76f5d19 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Sep 2021 11:58:46 +0200 Subject: [PATCH 17/28] Move grad computation to separate method --- sb3_contrib/trpo/trpo.py | 117 ++++++++++++++++++++++----------------- 1 file changed, 67 insertions(+), 50 deletions(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 6b00edb4..31bec33f 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -1,7 +1,7 @@ import copy import warnings from functools import partial -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Tuple, Type, Union import numpy as np import torch as th @@ -38,7 +38,7 @@ class TRPO(OnPolicyAlgorithm): :param batch_size: Minibatch size :param n_epochs: Number of epoch when optimizing the surrogate loss :param gamma: Discount factor - :param cg_max_steps: maximum number of steps in the Conjugate Gradient algoritgm + :param cg_max_steps: maximum number of steps in the Conjugate Gradient algorithm for computing the Hessian vector product :param cg_damping: damping in the Hessian vector product computation :param line_search_shrinking_factor: step-size reduction factor for the line-search @@ -158,6 +158,62 @@ def __init__( if _init_setup_model: self._setup_model() + def _compute_actor_grad( + self, kl_div: th.Tensor, policy_objective: th.Tensor + ) -> Tuple[List[nn.Parameter], th.Tensor, th.Tensor, List[Tuple[int, ...]]]: + """ + Compute actor gradients for kl div and surrogate objectives. + + :param kl_div: The KL div objective + :param policy_objective: The surrogate objective ("classic" policy gradient) + :return: List of actor params, gradients and gradients shape. + """ + # This is necessary because not all the parameters in the policy have gradients w.r.t. 
the KL divergence + # The policy objective is also called surrogate objective + policy_objective_gradients = [] + # Contains the gradients of the KL divergence + grad_kl = [] + # Contains the shape of the gradients of the KL divergence w.r.t each parameter + # This way the flattened gradient can be reshaped back into the original shapes and applied to + # the parameters + grad_shape = [] + # Contains the parameters which have non-zeros KL divergence gradients + # The list is used during the line-search to apply the step to each parameters + actor_params = [] + + for name, param in self.policy.named_parameters(): + # Skip parameters related to value function based on name + # this work for built-in policies only (not custom ones) + if "value" in name: + continue + + # For each parameter we compute the gradient of the KL divergence w.r.t to that parameter + kl_param_grad, *_ = th.autograd.grad( + kl_div, + param, + create_graph=True, + retain_graph=True, + allow_unused=True, + only_inputs=True, + ) + # If the gradient is not zero (not None), we store the parameter in the actor_params list + # and add the gradient and its shape to grad_kl and grad_shape respectively + if kl_param_grad is not None: + # If the parameter impacts the KL divergence (i.e. the policy) + # we compute the gradient of the policy objective w.r.t to the parameter + # this avoids computing the gradient if it's not going to be used in the conjugate gradient step + policy_objective_grad, *_ = th.autograd.grad(policy_objective, param, retain_graph=True, only_inputs=True) + + grad_shape.append(kl_param_grad.shape) + grad_kl.append(kl_param_grad.view(-1)) + policy_objective_gradients.append(policy_objective_grad.view(-1)) + actor_params.append(param) + + # Gradients are concatenated before the conjugate gradient step + policy_objective_gradients = th.cat(policy_objective_gradients) + grad_kl = th.cat(grad_kl) + return actor_params, policy_objective_gradients, grad_kl, grad_shape + def train(self) -> None: """ Update policy using the currently gathered rollout buffer. @@ -189,7 +245,10 @@ def train(self) -> None: self.policy.reset_noise(self.batch_size) with th.no_grad(): - old_distribution = copy.deepcopy(self.policy.get_distribution(rollout_data.observations)) + # Note: is copy enough, no need for deepcopy? + # If using gSDE and deepcopy, we need to use `old_distribution.distribution` + # directly to avoid PyTorch errors. + old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) distribution = self.policy.get_distribution(rollout_data.observations) log_prob = distribution.log_prob(actions) @@ -209,52 +268,9 @@ def train(self) -> None: # Surrogate & KL gradient self.policy.optimizer.zero_grad() - # This is necessary because not all the parameters in the policy have gradients w.r.t. 
the KL divergence - # The policy objective is also called surrogate objective - policy_objective_gradients = [] - # Contains the gradients of the KL divergence - grad_kl = [] - # Contains the shape of the gradients of the KL divergence w.r.t each parameter - # This way the flattened gradient can be reshaped back into the original shapes and applied to - # the parameters - grad_shape = [] - # Contains the parameters which have non-zeros KL divergence gradients - # The list is used during the line-search to apply the step to each parameters - actor_params = [] - - for name, param in self.policy.named_parameters(): - # Skip parameters related to value function based on name - # this work for built-in policies only (not custom ones) - if "value" in name: - continue - - # For each parameter we compute the gradient of the KL divergence w.r.t to that parameter - kl_param_grad, *_ = th.autograd.grad( - kl_div, - param, - create_graph=True, - retain_graph=True, - allow_unused=True, - only_inputs=True, - ) - # If the gradient is not zero (not None), we store the parameter in the actor_params list - # and add the gradient and its shape to grad_kl and grad_shape respectively - if kl_param_grad is not None: - # If the parameter impacts the KL divergence (i.e. the policy) - # we compute the gradient of the policy objective w.r.t to the parameter - # this avoids computing the gradient if it's not going to be used in the conjugate gradient step - policy_objective_grad, *_ = th.autograd.grad( - policy_objective, param, retain_graph=True, only_inputs=True - ) - - grad_shape.append(kl_param_grad.shape) - grad_kl.append(kl_param_grad.view(-1)) - policy_objective_gradients.append(policy_objective_grad.view(-1)) - actor_params.append(param) - - # Gradients are concatenated before the conjugate gradient step - policy_objective_gradients = th.cat(policy_objective_gradients) - grad_kl = th.cat(grad_kl) + actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad( + kl_div, policy_objective + ) # Hessian-vector dot product function used in the conjugate gradient step hessian_vector_product = partial(self.hessian_vector_product, actor_params, grad_kl) @@ -294,7 +310,8 @@ def train(self) -> None: start_idx += n_params # Recomputing the policy log-probabilities - _, log_prob, _ = self.policy.evaluate_actions(rollout_data.observations, actions) + distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) # New policy objective ratio = th.exp(log_prob - rollout_data.old_log_prob) From cc4b5ab8058fadd5cd3022dd90021ca31ffb1093 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Sep 2021 14:41:18 +0200 Subject: [PATCH 18/28] Remove grad norm clipping --- sb3_contrib/trpo/trpo.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 31bec33f..e35846dc 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -29,7 +29,7 @@ class TRPO(OnPolicyAlgorithm): :param policy: The policy model to use (MlpPolicy, CnnPolicy, ...) :param env: The environment to learn from (if registered in Gym, can be str) - :param learning_rate: The learning rate, it can be a function + :param learning_rate: The learning rate for the value function, it can be a function of the current progress remaining (from 1 to 0) :param n_steps: The number of steps to run for each environment per update (i.e. 
rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) @@ -46,7 +46,6 @@ class TRPO(OnPolicyAlgorithm): :param line_search_max_steps: maximum number of steps in the line-search :param n_critic_updates: number of critic updates per policy updates :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator - :param max_grad_norm: The maximum value for the gradient clipping :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE @@ -81,7 +80,6 @@ def __init__( line_search_max_steps: int = 10, n_critic_updates: int = 5, gae_lambda: float = 0.95, - max_grad_norm: float = 0.5, use_sde: bool = False, sde_sample_freq: int = -1, target_kl: float = 0.01, @@ -101,9 +99,9 @@ def __init__( n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, - ent_coef=0.0, + ent_coef=0.0, # TODO: add entropy bonus to surrogate objective vf_coef=0.0, - max_grad_norm=max_grad_norm, + max_grad_norm=0.0, use_sde=use_sde, sde_sample_freq=sde_sample_freq, policy_base=ActorCriticPolicy, @@ -353,8 +351,6 @@ def train(self) -> None: # otherwise it defeats the purposes of the KL constraint for param in actor_params: param.grad = None - # Clip grad norm - th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm) self.policy.optimizer.step() if not continue_training: From fc7a6c72db38d2c63d3e07bf41f9f818d6233a2d Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Sep 2021 15:27:29 +0200 Subject: [PATCH 19/28] Remove n epochs and add sub-sampling --- sb3_contrib/trpo/trpo.py | 256 +++++++++++++++++++-------------------- 1 file changed, 128 insertions(+), 128 deletions(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index e35846dc..13838ddf 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -8,7 +8,7 @@ from gym import spaces from stable_baselines3.common.on_policy_algorithm import OnPolicyAlgorithm from stable_baselines3.common.policies import ActorCriticPolicy -from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule +from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, RolloutBufferSamples, Schedule from stable_baselines3.common.utils import explained_variance from torch import nn from torch.distributions import kl_divergence @@ -35,8 +35,7 @@ class TRPO(OnPolicyAlgorithm): (i.e. rollout buffer size is n_steps * n_envs where n_envs is number of environment copies running in parallel) NOTE: n_steps * n_envs must be greater than 1 (because of the advantage normalization) See https://github.com/pytorch/pytorch/issues/29372 - :param batch_size: Minibatch size - :param n_epochs: Number of epoch when optimizing the surrogate loss + :param batch_size: Minibatch size for the value function :param gamma: Discount factor :param cg_max_steps: maximum number of steps in the Conjugate Gradient algorithm for computing the Hessian vector product @@ -54,6 +53,8 @@ class TRPO(OnPolicyAlgorithm): because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) By default, there is no limit on the kl div. 
+ :param sub_sampling_factor: Sub-sample the batch to make computation faster + see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf :param tensorboard_log: the log location for tensorboard (if None, no logging) :param create_eval_env: Whether to create a second environment that will be used for evaluating the agent periodically. (Only available when passing string for the environment) @@ -71,8 +72,7 @@ def __init__( env: Union[GymEnv, str], learning_rate: Union[float, Schedule] = 3e-4, n_steps: int = 2048, - batch_size: int = 64, - n_epochs: int = 10, + batch_size: int = 128, gamma: float = 0.99, cg_max_steps: int = 10, cg_damping: float = 1e-3, @@ -83,6 +83,7 @@ def __init__( use_sde: bool = False, sde_sample_freq: int = -1, target_kl: float = 0.01, + sub_sampling_factor: int = 1, tensorboard_log: Optional[str] = None, create_eval_env: bool = False, policy_kwargs: Optional[Dict[str, Any]] = None, @@ -122,10 +123,6 @@ def __init__( # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization - assert ( - batch_size > 1 - ), "`batch_size` must be greater than 1. See https://github.com/DLR-RM/stable-baselines3/issues/440" - if self.env is not None: # Check that `n_steps * n_envs > 1` to avoid NaN # when doing advantage normalization @@ -145,13 +142,13 @@ def __init__( f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" ) self.batch_size = batch_size - self.n_epochs = n_epochs self.cg_max_steps = cg_max_steps self.cg_damping = cg_damping self.line_search_shrinking_factor = line_search_shrinking_factor self.line_search_max_steps = line_search_max_steps self.target_kl = target_kl self.n_critic_updates = n_critic_updates + self.sub_sampling_factor = sub_sampling_factor if _init_setup_model: self._setup_model() @@ -226,137 +223,140 @@ def train(self) -> None: line_search_results = [] value_losses = [] - # Note(antonin): this value is never changed - continue_training = True + # This will only loop once (get all data in one go) + for rollout_data in self.rollout_buffer.get(batch_size=None): + + # Optional: sub-sample data for faster computation + if self.sub_sampling_factor > 1: + rollout_data = RolloutBufferSamples( + rollout_data.observations[:: self.sub_sampling_factor], + rollout_data.actions[:: self.sub_sampling_factor], + None, # old values, not used here + rollout_data.old_log_prob[:: self.sub_sampling_factor], + rollout_data.advantages[:: self.sub_sampling_factor], + None, # returns, not used here + ) - # train for n_epochs epochs - for _ in range(self.n_epochs): - # Do a complete pass on the rollout buffer - for rollout_data in self.rollout_buffer.get(self.batch_size): - actions = rollout_data.actions - if isinstance(self.action_space, spaces.Discrete): - # Convert discrete action from float to long - actions = rollout_data.actions.long().flatten() + actions = rollout_data.actions + if isinstance(self.action_space, spaces.Discrete): + # Convert discrete action from float to long + actions = rollout_data.actions.long().flatten() - # Re-sample the noise matrix because the log_std has changed - if self.use_sde: - self.policy.reset_noise(self.batch_size) + # Re-sample the noise matrix because the log_std has changed + if self.use_sde: + self.policy.reset_noise(self.batch_size) - with th.no_grad(): - # Note: is copy enough, no need for deepcopy? - # If using gSDE and deepcopy, we need to use `old_distribution.distribution` - # directly to avoid PyTorch errors. 
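
# An illustrative sketch (an aside, not part of the patch) of the KL term computed just
# below: a frozen snapshot of the "old" action distribution is compared against the
# current one with `torch.distributions.kl_divergence`. The two Normal distributions are
# stand-ins for the policy's diagonal-Gaussian action distribution (batch of 4
# observations, 2 action dimensions); the numbers are arbitrary.
import torch as th
from torch.distributions import Normal, kl_divergence

with th.no_grad():
    old_distribution = Normal(loc=th.zeros(4, 2), scale=th.ones(4, 2))

distribution = Normal(loc=0.1 * th.ones(4, 2), scale=0.9 * th.ones(4, 2))
# Mean KL over the batch: the scalar the trust-region constraint is checked against
kl_div = kl_divergence(distribution, old_distribution).mean()
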
- old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) + with th.no_grad(): + # Note: is copy enough, no need for deepcopy? + # If using gSDE and deepcopy, we need to use `old_distribution.distribution` + # directly to avoid PyTorch errors. + old_distribution = copy.copy(self.policy.get_distribution(rollout_data.observations)) - distribution = self.policy.get_distribution(rollout_data.observations) - log_prob = distribution.log_prob(actions) + distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) - advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) + advantages = rollout_data.advantages + advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) - # ratio between old and new policy, should be one at the first iteration - ratio = th.exp(log_prob - rollout_data.old_log_prob) + # ratio between old and new policy, should be one at the first iteration + ratio = th.exp(log_prob - rollout_data.old_log_prob) - # surrogate policy objective - policy_objective = (advantages * ratio).mean() + # surrogate policy objective + policy_objective = (advantages * ratio).mean() - # KL divergence - kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() + # KL divergence + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() - # Surrogate & KL gradient - self.policy.optimizer.zero_grad() + # Surrogate & KL gradient + self.policy.optimizer.zero_grad() - actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad( - kl_div, policy_objective - ) + actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad(kl_div, policy_objective) - # Hessian-vector dot product function used in the conjugate gradient step - hessian_vector_product = partial(self.hessian_vector_product, actor_params, grad_kl) + # Hessian-vector dot product function used in the conjugate gradient step + hessian_vector_product = partial(self.hessian_vector_product, actor_params, grad_kl) - # Computing search direction - search_direction = conjugate_gradient_solver( - hessian_vector_product, - policy_objective_gradients, - max_iter=self.cg_max_steps, - ) + # Computing search direction + search_direction = conjugate_gradient_solver( + hessian_vector_product, + policy_objective_gradients, + max_iter=self.cg_max_steps, + ) - # Maximal step length - line_search_max_step_size = 2 * self.target_kl - line_search_max_step_size /= th.matmul( - search_direction, hessian_vector_product(search_direction, retain_graph=False) - ) - line_search_max_step_size = th.sqrt(line_search_max_step_size) - - line_search_backtrack_coeff = 1.0 - original_actor_params = [param.detach().clone() for param in actor_params] - - is_line_search_success = False - with th.no_grad(): - # Line-search (backtracking) - for _ in range(self.line_search_max_steps): - - start_idx = 0 - # Applying the scaled step direction - for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): - n_params = param.numel() - param.data = ( - original_param.data - + line_search_backtrack_coeff - * line_search_max_step_size - * search_direction[start_idx : (start_idx + n_params)].view(shape) - ) - start_idx += n_params - - # Recomputing the policy log-probabilities - distribution = self.policy.get_distribution(rollout_data.observations) - log_prob = 
distribution.log_prob(actions) - - # New policy objective - ratio = th.exp(log_prob - rollout_data.old_log_prob) - new_policy_objective = (advantages * ratio).mean() - - # New KL-divergence - kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() - - # Constraint criteria - if (kl_div < self.target_kl) and (new_policy_objective > policy_objective): - is_line_search_success = True - break - - # Reducing step size if line-search wasn't successful - line_search_backtrack_coeff *= self.line_search_shrinking_factor - - line_search_results.append(is_line_search_success) - - if not is_line_search_success: - # If the line-search wasn't successful we revert to the original parameters - for param, original_param in zip(actor_params, original_actor_params): - param.data = original_param.data.clone() - - policy_objective_values.append(policy_objective.item()) - kl_divergences.append(0) - else: - policy_objective_values.append(new_policy_objective.item()) - kl_divergences.append(kl_div.item()) - - # Critic updates - for _ in range(self.n_critic_updates): - values_pred = self.policy.predict_values(rollout_data.observations) - value_loss = F.mse_loss(rollout_data.returns, values_pred.flatten()) - value_losses.append(value_loss.item()) - - self.policy.optimizer.zero_grad() - value_loss.backward() - # Removing gradients of parameters shared with the actor - # otherwise it defeats the purposes of the KL constraint - for param in actor_params: - param.grad = None - self.policy.optimizer.step() - - if not continue_training: - break - - self._n_updates += self.n_epochs + # Maximal step length + line_search_max_step_size = 2 * self.target_kl + line_search_max_step_size /= th.matmul( + search_direction, hessian_vector_product(search_direction, retain_graph=False) + ) + line_search_max_step_size = th.sqrt(line_search_max_step_size) + + line_search_backtrack_coeff = 1.0 + original_actor_params = [param.detach().clone() for param in actor_params] + + is_line_search_success = False + with th.no_grad(): + # Line-search (backtracking) + for _ in range(self.line_search_max_steps): + + start_idx = 0 + # Applying the scaled step direction + for param, original_param, shape in zip(actor_params, original_actor_params, grad_shape): + n_params = param.numel() + param.data = ( + original_param.data + + line_search_backtrack_coeff + * line_search_max_step_size + * search_direction[start_idx : (start_idx + n_params)].view(shape) + ) + start_idx += n_params + + # Recomputing the policy log-probabilities + distribution = self.policy.get_distribution(rollout_data.observations) + log_prob = distribution.log_prob(actions) + + # New policy objective + ratio = th.exp(log_prob - rollout_data.old_log_prob) + new_policy_objective = (advantages * ratio).mean() + + # New KL-divergence + kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() + + # Constraint criteria + if (kl_div < self.target_kl) and (new_policy_objective > policy_objective): + is_line_search_success = True + break + + # Reducing step size if line-search wasn't successful + line_search_backtrack_coeff *= self.line_search_shrinking_factor + + line_search_results.append(is_line_search_success) + + if not is_line_search_success: + # If the line-search wasn't successful we revert to the original parameters + for param, original_param in zip(actor_params, original_actor_params): + param.data = original_param.data.clone() + + policy_objective_values.append(policy_objective.item()) + kl_divergences.append(0) + 
else: + policy_objective_values.append(new_policy_objective.item()) + kl_divergences.append(kl_div.item()) + + # Critic update + for _ in range(self.n_critic_updates): + for rollout_data in self.rollout_buffer.get(self.batch_size): + values_pred = self.policy.predict_values(rollout_data.observations) + value_loss = F.mse_loss(rollout_data.returns, values_pred.flatten()) + value_losses.append(value_loss.item()) + + self.policy.optimizer.zero_grad() + value_loss.backward() + # Removing gradients of parameters shared with the actor + # otherwise it defeats the purposes of the KL constraint + for param in actor_params: + param.grad = None + self.policy.optimizer.step() + + self._n_updates += 1 explained_var = explained_variance(self.rollout_buffer.values.flatten(), self.rollout_buffer.returns.flatten()) # Logs From 66723ff2dffe55039610eb1111f0ae99322add4e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Sep 2021 17:09:35 +0200 Subject: [PATCH 20/28] Update defaults --- sb3_contrib/trpo/trpo.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 13838ddf..aca8f9f5 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -42,7 +42,8 @@ class TRPO(OnPolicyAlgorithm): :param cg_damping: damping in the Hessian vector product computation :param line_search_shrinking_factor: step-size reduction factor for the line-search (i.e., ``theta_new = theta + alpha^i * step``) - :param line_search_max_steps: maximum number of steps in the line-search + :param line_search_max_iter: maximum number of iteration + for the backtracking line-search :param n_critic_updates: number of critic updates per policy updates :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) @@ -70,15 +71,15 @@ def __init__( self, policy: Union[str, Type[ActorCriticPolicy]], env: Union[GymEnv, str], - learning_rate: Union[float, Schedule] = 3e-4, + learning_rate: Union[float, Schedule] = 1e-3, n_steps: int = 2048, batch_size: int = 128, gamma: float = 0.99, - cg_max_steps: int = 10, - cg_damping: float = 1e-3, + cg_max_steps: int = 15, + cg_damping: float = 0.1, line_search_shrinking_factor: float = 0.8, - line_search_max_steps: int = 10, - n_critic_updates: int = 5, + line_search_max_iter: int = 10, + n_critic_updates: int = 10, gae_lambda: float = 0.95, use_sde: bool = False, sde_sample_freq: int = -1, @@ -145,7 +146,7 @@ def __init__( self.cg_max_steps = cg_max_steps self.cg_damping = cg_damping self.line_search_shrinking_factor = line_search_shrinking_factor - self.line_search_max_steps = line_search_max_steps + self.line_search_max_iter = line_search_max_iter self.target_kl = target_kl self.n_critic_updates = n_critic_updates self.sub_sampling_factor = sub_sampling_factor @@ -295,7 +296,7 @@ def train(self) -> None: is_line_search_success = False with th.no_grad(): # Line-search (backtracking) - for _ in range(self.line_search_max_steps): + for _ in range(self.line_search_max_iter): start_idx = 0 # Applying the scaled step direction From e9833485e5fcdcbe1b83fe62c0fc4833b07648bd Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Dec 2021 16:03:31 +0100 Subject: [PATCH 21/28] Add Doc --- README.md | 1 + docs/guide/algos.rst | 3 +- docs/index.rst | 1 + docs/modules/trpo.rst | 151 +++++++++++++++++++++++++++++++++++++++ sb3_contrib/trpo/trpo.py | 2 +- 5 files changed, 156 insertions(+), 2 
deletions(-) create mode 100644 docs/modules/trpo.rst diff --git a/README.md b/README.md index 3e4c607e..815956d6 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,7 @@ See documentation for the full list of included features. - [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269) - [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044) - [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171) +- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477) **Gym Wrappers**: - [Time Feature Wrapper](https://arxiv.org/abs/1712.00378) diff --git a/docs/guide/algos.rst b/docs/guide/algos.rst index 879a84e1..81770b06 100644 --- a/docs/guide/algos.rst +++ b/docs/guide/algos.rst @@ -9,7 +9,8 @@ along with some useful characteristics: support for discrete/continuous actions, Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing ============ =========== ============ ================= =============== ================ TQC ✔️ ❌ ❌ ❌ ✔️ -QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ +TRPO ✔️ ✔️ ✔️ ✔️ ✔️ +QR-DQN ️❌ ️✔️ ❌ ❌ ✔️ ============ =========== ============ ================= =============== ================ diff --git a/docs/index.rst b/docs/index.rst index 8e37c71b..91e24ad0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :caption: RL Algorithms modules/tqc + modules/trpo modules/qrdqn modules/ppo_mask diff --git a/docs/modules/trpo.rst b/docs/modules/trpo.rst new file mode 100644 index 00000000..67829947 --- /dev/null +++ b/docs/modules/trpo.rst @@ -0,0 +1,151 @@ +.. _tqc: + +.. automodule:: sb3_contrib.trpo + +TRPO +==== + +`Trust Region Policy Optimization (TRPO) `_ +is an iterative approach for optimizing policies with guaranteed monotonic improvement. +.. rubric:: Available Policies + +.. autosummary:: + :nosignatures: + + MlpPolicy + CnnPolicy + MultiInputPolicy + + + +Notes +----- + +- Original paper: https://arxiv.org/abs/1502.05477 +- OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/ + + +Can I use? +---------- + +- Recurrent policies: ❌ +- Multi processing: ✔️ +- Gym spaces: + + +============= ====== =========== +Space Action Observation +============= ====== =========== +Discrete ✔️ ✔️ +Box ✔️ ✔️ +MultiDiscrete ✔️ ✔️ +MultiBinary ✔️ ✔️ +Dict ❌ ✔️ +============= ====== =========== + + +Example +------- + +.. code-block:: python + + import gym + import numpy as np + + from sb3_contrib import TRPO + + env = gym.make("Pendulum-v0") + + model = TRPO("MlpPolicy", env, verbose=1) + model.learn(total_timesteps=10000, log_interval=4) + model.save("trpo_pendulum") + + del model # remove to demonstrate saving and loading + + model = TRPO.load("trpo_pendulum") + + obs = env.reset() + while True: + action, _states = model.predict(obs, deterministic=True) + obs, reward, done, info = env.step(action) + env.render() + if done: + obs = env.reset() + + +Results +------- + +Result on the PyBullet benchmark (1M steps) using 3 seeds. +The complete learning curves are available in the `associated PR `_. + + +===================== ============ ============ +Environments PPO TRPO +===================== ============ ============ +HalfCheetah 1976 +/- 479 0000 +/- 157 +Ant 2364 +/- 120 0000 +/- 37 +Hopper 1567 +/- 339 0000 +/- 62 +Walker2D 1230 +/- 147 0000 +/- 94 +===================== ============ ============ + + +How to replicate the results? 
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Clone RL-Zoo and checkout the branch ``feat/trpo``: + +.. code-block:: bash + + git clone https://github.com/DLR-RM/rl-baselines3-zoo + cd rl-baselines3-zoo/ + git checkout feat/trpo + +Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): + +.. code-block:: bash + + python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + + +Plot the results: + +.. code-block:: bash + + python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/trpo_results + python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO + + +Parameters +---------- + +.. autoclass:: TRPO + :members: + :inherited-members: + +.. _trpo_policies: + +TRPO Policies +------------- + +.. autoclass:: MlpPolicy + :members: + :inherited-members: + +.. autoclass:: stable_baselines3.common.policies.ActorCriticPolicy + :members: + :noindex: + +.. autoclass:: CnnPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.ActorCriticCnnPolicy + :members: + :noindex: + +.. autoclass:: MultiInputPolicy + :members: + +.. autoclass:: stable_baselines3.common.policies.MultiInputActorCriticPolicy + :members: + :noindex: diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index aca8f9f5..042b1048 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -23,7 +23,7 @@ class TRPO(OnPolicyAlgorithm): Paper: https://arxiv.org/abs/1502.05477 Code: This implementation borrows code from OpenAI Spinning Up (https://github.com/openai/spinningup/) - and Stable Baselines (TRPO from https://github.com/hill-a/stable-baselines) + and Stable Baselines (TRPO from https://github.com/hill-a/stable-baselines) Introduction to TRPO: https://spinningup.openai.com/en/latest/algorithms/trpo.html From 439d79b2fa3564f6f11acf1af6802ccf92b07fa3 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Mon, 27 Dec 2021 16:36:59 +0100 Subject: [PATCH 22/28] Add more test and fixes for CNN --- sb3_contrib/common/utils.py | 2 +- sb3_contrib/trpo/trpo.py | 22 +++++++++++++++------- tests/test_cnn.py | 4 ++-- tests/test_dict_env.py | 24 ++++++++++++------------ tests/test_run.py | 18 +++++++++++++++++- tests/test_save_load.py | 9 +++++++-- 6 files changed, 54 insertions(+), 25 deletions(-) diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 8748d25b..8e9e3519 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -160,4 +160,4 @@ def flat_grad( retain_graph=retain_graph, allow_unused=True, ) - return th.cat([grad.view(-1) for grad in grads if grad is not None]) + return th.cat([th.ravel(grad) for grad in grads if grad is not None]) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 042b1048..ae0427ab 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -50,6 +50,7 @@ class TRPO(OnPolicyAlgorithm): instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) + :param normalize_advantage: Whether to normalize or not the advantage :param target_kl: Limit the KL divergence between updates, because the clipping is not enough to prevent large update see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) @@ -83,6 +84,7 @@ def __init__( gae_lambda: float = 0.95, use_sde: bool = False, sde_sample_freq: int = -1, + normalize_advantage: bool = True, target_kl: float = 0.01, 
sub_sampling_factor: int = 1, tensorboard_log: Optional[str] = None, @@ -101,8 +103,8 @@ def __init__( n_steps=n_steps, gamma=gamma, gae_lambda=gae_lambda, - ent_coef=0.0, # TODO: add entropy bonus to surrogate objective - vf_coef=0.0, + ent_coef=0.0, # entropy bonus is not used by TRPO + vf_coef=0.0, # Value function is optimized separately max_grad_norm=0.0, use_sde=use_sde, sde_sample_freq=sde_sample_freq, @@ -122,15 +124,17 @@ def __init__( ), ) + self.normalize_advantage = normalize_advantage # Sanity check, otherwise it will lead to noisy gradient and NaN # because of the advantage normalization if self.env is not None: # Check that `n_steps * n_envs > 1` to avoid NaN # when doing advantage normalization buffer_size = self.env.num_envs * self.n_steps - assert ( - buffer_size > 1 - ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + if normalize_advantage: + assert ( + buffer_size > 1 + ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" # Check that the rollout buffer size is a multiple of the mini-batch size untruncated_batches = buffer_size // batch_size if buffer_size % batch_size > 0: @@ -143,8 +147,10 @@ def __init__( f"Info: (n_steps={self.n_steps} and n_envs={self.env.num_envs})" ) self.batch_size = batch_size + # Conjugate gradients parameters self.cg_max_steps = cg_max_steps self.cg_damping = cg_damping + # Backtracking line search parameters self.line_search_shrinking_factor = line_search_shrinking_factor self.line_search_max_iter = line_search_max_iter self.target_kl = target_kl @@ -245,7 +251,8 @@ def train(self) -> None: # Re-sample the noise matrix because the log_std has changed if self.use_sde: - self.policy.reset_noise(self.batch_size) + # batch_size is only used for the value function + self.policy.reset_noise(actions.shape[0]) with th.no_grad(): # Note: is copy enough, no need for deepcopy? 
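Editor's note — the ``buffer_size > 1`` guard above exists because normalizing a single advantage sample divides by a NaN standard deviation, which is exactly what the normalization hunk that follows would otherwise produce. A minimal sketch of the failure mode (assuming PyTorch's default unbiased ``std``; not part of the patch):

.. code-block:: python

    import torch as th

    # buffer_size == 1: the unbiased std of a single element is NaN,
    # so the normalized advantage is NaN and poisons the policy update
    advantages = th.tensor([0.5])
    print((advantages - advantages.mean()) / (advantages.std() + 1e-8))  # tensor([nan])

    # buffer_size > 1: the normalization is well defined
    advantages = th.tensor([0.5, -0.2])
    print((advantages - advantages.mean()) / (advantages.std() + 1e-8))  # tensor([ 0.7071, -0.7071])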
@@ -257,7 +264,8 @@ def train(self) -> None: log_prob = distribution.log_prob(actions) advantages = rollout_data.advantages - advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) + if self.normalize_advantage: + advantages = (advantages - advantages.mean()) / (rollout_data.advantages.std() + 1e-8) # ratio between old and new policy, should be one at the first iteration ratio = th.exp(log_prob - rollout_data.old_log_prob) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 6c277856..e570aab6 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -8,10 +8,10 @@ from stable_baselines3.common.utils import zip_strict from stable_baselines3.common.vec_env import VecTransposeImage, is_vecenv_wrapped -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO -@pytest.mark.parametrize("model_class", [TQC, QRDQN]) +@pytest.mark.parametrize("model_class", [TQC, QRDQN, TRPO]) def test_cnn(tmp_path, model_class): SAVE_NAME = "cnn_model.zip" # Fake grayscale with frameskip diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 86f5b5e0..38757ab3 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -6,7 +6,7 @@ from stable_baselines3.common.evaluation import evaluate_policy from stable_baselines3.common.vec_env import DummyVecEnv, VecFrameStack, VecNormalize -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO class DummyDictEnv(gym.Env): @@ -78,7 +78,7 @@ def render(self, mode="human"): pass -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) def test_consistency(model_class): """ Make sure that dict obs with vector only vs using flatten obs is equivalent. @@ -94,7 +94,7 @@ def test_consistency(model_class): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, ) @@ -124,7 +124,7 @@ def test_consistency(model_class): assert np.allclose(action_1, action_2) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_spaces(model_class, channel_last): """ @@ -138,11 +138,11 @@ def test_dict_spaces(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -169,7 +169,7 @@ def test_dict_spaces(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) @pytest.mark.parametrize("channel_last", [False, True]) def test_dict_vec_framestack(model_class, channel_last): """ @@ -187,11 +187,11 @@ def test_dict_vec_framestack(model_class, channel_last): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], features_extractor_kwargs=dict(cnn_output_dim=32), ), ) @@ -218,7 +218,7 @@ def test_dict_vec_framestack(model_class, channel_last): evaluate_policy(model, env, n_eval_episodes=5, warn=False) -@pytest.mark.parametrize("model_class", [QRDQN, TQC]) +@pytest.mark.parametrize("model_class", [QRDQN, TQC, TRPO]) def test_vec_normalize(model_class): """ Additional tests to check observation space support 
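Editor's note — the test changes around this hunk switch TRPO to the dict form of ``net_arch`` so the actor (``pi``) and critic (``vf``) get separate heads. A hedged usage sketch of the same pattern (``Pendulum-v0`` as in the test suite; the hyperparameter values are illustrative, not tuned):

.. code-block:: python

    from sb3_contrib import TRPO

    # Separate pi (actor) and vf (critic) network sizes, mirroring the updated tests.
    # Small n_steps/batch_size keep the run short; n_steps must be a multiple of batch_size.
    model = TRPO(
        "MlpPolicy",
        "Pendulum-v0",
        n_steps=128,
        batch_size=64,
        policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]),
        verbose=1,
    )
    model.learn(total_timesteps=500)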
@@ -230,11 +230,11 @@ def test_vec_normalize(model_class): kwargs = {} n_steps = 256 - if model_class in {}: + if model_class in {TRPO}: kwargs = dict( n_steps=128, policy_kwargs=dict( - net_arch=[32], + net_arch=[dict(pi=[32], vf=[32])], ), ) else: diff --git a/tests/test_run.py b/tests/test_run.py index 56ad7d90..ab2d311e 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -62,7 +62,23 @@ def test_qrdqn(): @pytest.mark.parametrize("env_id", ["CartPole-v1", "Pendulum-v0"]) def test_trpo(env_id): - model = TRPO("MlpPolicy", env_id, n_steps=64, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) + model = TRPO("MlpPolicy", env_id, n_steps=128, seed=0, policy_kwargs=dict(net_arch=[16]), verbose=1) + model.learn(total_timesteps=500) + + +def test_trpo_params(): + # Test with gSDE and subsampling + model = TRPO( + "MlpPolicy", + "Pendulum-v0", + n_steps=64, + batch_size=32, + use_sde=True, + sub_sampling_factor=4, + seed=0, + policy_kwargs=dict(net_arch=[dict(pi=[32], vf=[32])]), + verbose=1, + ) model.learn(total_timesteps=500) diff --git a/tests/test_save_load.py b/tests/test_save_load.py index d2ee3a21..10202e76 100644 --- a/tests/test_save_load.py +++ b/tests/test_save_load.py @@ -12,9 +12,9 @@ from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv -from sb3_contrib import QRDQN, TQC +from sb3_contrib import QRDQN, TQC, TRPO -MODEL_LIST = [TQC, QRDQN] +MODEL_LIST = [TQC, QRDQN, TRPO] def select_env(model_class: BaseAlgorithm) -> gym.Env: @@ -277,6 +277,11 @@ def test_save_load_policy(tmp_path, model_class, policy_str): learning_starts=100, policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), ) + else: + kwargs = dict( + n_steps=128, + policy_kwargs=dict(features_extractor_kwargs=dict(features_dim=32)), + ) env = FakeImageEnv(screen_height=40, screen_width=40, n_channels=2, discrete=model_class == QRDQN) # Reduce number of quantiles for faster tests From d9483dcdf8b4ca8acbe1dc3152775a28ce30bac8 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 28 Dec 2021 16:09:56 +0100 Subject: [PATCH 23/28] Update doc + add benchmark --- docs/guide/examples.rst | 13 +++++++++++++ docs/modules/trpo.rst | 28 ++++++++++++++-------------- 2 files changed, 27 insertions(+), 14 deletions(-) diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index 7d707888..3af55613 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -44,3 +44,16 @@ Train a PPO with invalid action masking agent on a toy environment. model = MaskablePPO("MlpPolicy", env, verbose=1) model.learn(5000) model.save("qrdqn_cartpole") + + TRPO + ---- + + Train a Trust Region Policy Optimization (TRPO) agent on the Pendulum environment. + + .. code-block:: python + + from sb3_contrib import TRPO + + model = TRPO("MlpPolicy", "Pendulum-v0", gamma=0.9, verbose=1) + model.learn(total_timesteps=100_000, log_interval=4) + model.save("trpo_pendulum") diff --git a/docs/modules/trpo.rst b/docs/modules/trpo.rst index 67829947..4ff88281 100644 --- a/docs/modules/trpo.rst +++ b/docs/modules/trpo.rst @@ -7,6 +7,7 @@ TRPO `Trust Region Policy Optimization (TRPO) `_ is an iterative approach for optimizing policies with guaranteed monotonic improvement. + .. rubric:: Available Policies .. autosummary:: @@ -17,7 +18,6 @@ is an iterative approach for optimizing policies with guaranteed monotonic impro MultiInputPolicy - Notes ----- @@ -76,18 +76,19 @@ Example Results ------- -Result on the PyBullet benchmark (1M steps) using 3 seeds. 
+Result on the MuJoCo benchmark (1M steps on ``-v3`` envs with MuJoCo v2.1.0) using 3 seeds. The complete learning curves are available in the `associated PR `_. -===================== ============ ============ -Environments PPO TRPO -===================== ============ ============ -HalfCheetah 1976 +/- 479 0000 +/- 157 -Ant 2364 +/- 120 0000 +/- 37 -Hopper 1567 +/- 339 0000 +/- 62 -Walker2D 1230 +/- 147 0000 +/- 94 -===================== ============ ============ +===================== ============ +Environments TRPO +===================== ============ +HalfCheetah 1803 +/- 46 +Ant 3554 +/- 591 +Hopper 3372 +/- 215 +Walker2d 4502 +/- 234 +Swimmer 359 +/- 2 +===================== ============ How to replicate the results? @@ -97,22 +98,21 @@ Clone RL-Zoo and checkout the branch ``feat/trpo``: .. code-block:: bash - git clone https://github.com/DLR-RM/rl-baselines3-zoo + git clone https://github.com/cyprienc/rl-baselines3-zoo cd rl-baselines3-zoo/ - git checkout feat/trpo Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above): .. code-block:: bash - python train.py --algo tqc --env $ENV_ID --eval-episodes 10 --eval-freq 10000 + python train.py --algo tqc --env $ENV_ID --n-eval-envs 10 --eval-episodes 20 --eval-freq 50000 Plot the results: .. code-block:: bash - python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2D BipedalWalkerHardcore -f logs/ -o logs/trpo_results + python scripts/all_plots.py -a trpo -e HalfCheetah Ant Hopper Walker2d Swimmer -f logs/ -o logs/trpo_results python scripts/plot_from_file.py -i logs/trpo_results.pkl -latex -l TRPO From fff84e4c369ad40764312fa1adaa26446c88ec5f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 28 Dec 2021 16:49:35 +0100 Subject: [PATCH 24/28] Add tests + update doc --- docs/common/utils.rst | 7 +++++++ sb3_contrib/common/utils.py | 6 ++---- sb3_contrib/trpo/trpo.py | 15 ++++++++------- tests/test_utils.py | 31 ++++++++++++++++++++++++++++++- 4 files changed, 47 insertions(+), 12 deletions(-) create mode 100644 docs/common/utils.rst diff --git a/docs/common/utils.rst b/docs/common/utils.rst new file mode 100644 index 00000000..79fd9bf2 --- /dev/null +++ b/docs/common/utils.rst @@ -0,0 +1,7 @@ +.. _utils: + +Utils +===== + +.. automodule:: sb3_contrib.common.utils + :members: diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 8e9e3519..bb94f0b0 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -70,7 +70,6 @@ def quantile_huber_loss( return loss -# TODO: write regression tests def conjugate_gradient_solver( matrix_vector_dot_func: Callable[[th.Tensor], th.Tensor], b, @@ -96,7 +95,7 @@ def conjugate_gradient_solver( :param residual_tol: residual tolerance for early stopping of the solving (default is 1e-10) :return x: - the approximate solution to the system of equations defined by Avp_fun + the approximate solution to the system of equations defined by `matrix_vector_dot_func` and b """ @@ -134,7 +133,6 @@ def conjugate_gradient_solver( p = residual + beta * p -# TODO: test def flat_grad( output, parameters: Sequence[nn.parameter.Parameter], @@ -146,7 +144,7 @@ def flat_grad( Order of parameters is preserved. :param output: functional output to compute the gradient for - :param parameters: sequence of `Parameter` + :param parameters: sequence of ``Parameter`` :param retain_graph – If ``False``, the graph used to compute the grad will be freed. Defaults to the value of ``create_graph``. 
:param create_graph – If ``True``, graph of the derivative will be constructed, diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index ae0427ab..66f42dd0 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -132,9 +132,10 @@ def __init__( # when doing advantage normalization buffer_size = self.env.num_envs * self.n_steps if normalize_advantage: - assert ( - buffer_size > 1 - ), f"`n_steps * n_envs` must be greater than 1. Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + assert buffer_size > 1, ( + "`n_steps * n_envs` must be greater than 1. " + f"Currently n_steps={self.n_steps} and n_envs={self.env.num_envs}" + ) # Check that the rollout buffer size is a multiple of the mini-batch size untruncated_batches = buffer_size // batch_size if buffer_size % batch_size > 0: @@ -282,11 +283,11 @@ def train(self) -> None: actor_params, policy_objective_gradients, grad_kl, grad_shape = self._compute_actor_grad(kl_div, policy_objective) # Hessian-vector dot product function used in the conjugate gradient step - hessian_vector_product = partial(self.hessian_vector_product, actor_params, grad_kl) + hessian_vector_product_fn = partial(self.hessian_vector_product, actor_params, grad_kl) # Computing search direction search_direction = conjugate_gradient_solver( - hessian_vector_product, + hessian_vector_product_fn, policy_objective_gradients, max_iter=self.cg_max_steps, ) @@ -294,7 +295,7 @@ def train(self) -> None: # Maximal step length line_search_max_step_size = 2 * self.target_kl line_search_max_step_size /= th.matmul( - search_direction, hessian_vector_product(search_direction, retain_graph=False) + search_direction, hessian_vector_product_fn(search_direction, retain_graph=False) ) line_search_max_step_size = th.sqrt(line_search_max_step_size) @@ -389,7 +390,7 @@ def hessian_vector_product( :param grad_kl: flattened gradient of the KL divergence between the old and new policy :param vector: vector to compute the dot product the hessian-vector dot product with :param retain_graph: if True, the graph will be kept after computing the Hessian - :return: Hessian-vector dot product + :return: Hessian-vector dot product (with damping) """ jacobian_vector_product = (grad_kl * vector).sum() return flat_grad(jacobian_vector_product, params, retain_graph=retain_graph) + self.cg_damping * vector diff --git a/tests/test_utils.py b/tests/test_utils.py index 97b740f4..b7455b8a 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,8 +1,9 @@ import numpy as np import pytest import torch as th +from stable_baselines3.common.utils import set_random_seed -from sb3_contrib.common.utils import quantile_huber_loss +from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad, quantile_huber_loss def test_quantile_huber_loss(): @@ -17,3 +18,31 @@ def test_quantile_huber_loss(): quantile_huber_loss(th.zeros(4, 4), th.zeros(3, 4)) with pytest.raises(ValueError): quantile_huber_loss(th.zeros(4, 4, 4, 4), th.zeros(4, 4, 4, 4)) + + +def test_cg(): + # Test that conjugate gradient can actually solve + # Ax = b when the A^-1 is known + set_random_seed(4) + A = th.ones(3, 3) + # Symmetric matrix + A[0, 1] = 2 + A[1, 0] = 2 + x = th.ones(3) + th.rand(3) + b = A @ x + + def matrix_vector_dot_func(vector): + return A @ vector + + x_approx = conjugate_gradient_solver(matrix_vector_dot_func, b, max_iter=5, residual_tol=1e-10) + assert th.allclose(x_approx, x) + + +def test_flat_grad(): + n_parameters = 12 # 3 * (2 * 2) + x = th.nn.Parameter(th.ones(2, 2, 
requires_grad=True)) + y = (x ** 2).sum() + flat_grad_out = flat_grad(y, [x, x, x]) + assert len(flat_grad_out.shape) == 1 + # dy/dx = 2 + assert th.allclose(flat_grad_out, th.ones(n_parameters) * 2) From 95dddf4671d3d3ead560848b9f85bb7591e84d2f Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Tue, 28 Dec 2021 16:51:13 +0100 Subject: [PATCH 25/28] Fix doc --- docs/index.rst | 1 + sb3_contrib/common/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/index.rst b/docs/index.rst index 91e24ad0..ac610bd0 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -40,6 +40,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d :maxdepth: 1 :caption: Common + common/utils common/wrappers .. toctree:: diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index bb94f0b0..30f26aba 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -145,9 +145,9 @@ def flat_grad( :param output: functional output to compute the gradient for :param parameters: sequence of ``Parameter`` - :param retain_graph – If ``False``, the graph used to compute the grad will be freed. + :param retain_graph: – If ``False``, the graph used to compute the grad will be freed. Defaults to the value of ``create_graph``. - :param create_graph – If ``True``, graph of the derivative will be constructed, + :param create_graph: – If ``True``, graph of the derivative will be constructed, allowing to compute higher order derivative products. Default: ``False``. :return: Tensor containing the flattened gradients """ From 661fe15f374ee5bef4c745c5f41a46081415fba6 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 11:14:56 +0100 Subject: [PATCH 26/28] Improve names for conjugate gradient --- sb3_contrib/common/utils.py | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/sb3_contrib/common/utils.py b/sb3_contrib/common/utils.py index 30f26aba..daa6863b 100644 --- a/sb3_contrib/common/utils.py +++ b/sb3_contrib/common/utils.py @@ -71,7 +71,7 @@ def quantile_huber_loss( def conjugate_gradient_solver( - matrix_vector_dot_func: Callable[[th.Tensor], th.Tensor], + matrix_vector_dot_fn: Callable[[th.Tensor], th.Tensor], b, max_iter=10, residual_tol=1e-10, @@ -86,7 +86,7 @@ def conjugate_gradient_solver( Reference: - https://epubs.siam.org/doi/abs/10.1137/1.9781611971446.ch6 - :param matrix_vector_dot_func: + :param matrix_vector_dot_fn: a function that right multiplies a matrix A by a vector v :param b: the right hand term in the set of linear equations Ax = b @@ -95,17 +95,18 @@ def conjugate_gradient_solver( :param residual_tol: residual tolerance for early stopping of the solving (default is 1e-10) :return x: - the approximate solution to the system of equations defined by `matrix_vector_dot_func` + the approximate solution to the system of equations defined by `matrix_vector_dot_fn` and b """ # The vector is not initialized at 0 because of the instability issues when the gradient becomes small. # A small random gaussian noise is used for the initialization. 
x = 1e-4 * th.randn_like(b) - residual = b - matrix_vector_dot_func(x) - r_dot = th.matmul(residual, residual) + residual = b - matrix_vector_dot_fn(x) + # Equivalent to th.linalg.norm(residual) ** 2 (L2 norm squared) + residual_squared_norm = th.matmul(residual, residual) - if r_dot < residual_tol: + if residual_squared_norm < residual_tol: # If the gradient becomes extremely small # The denominator in alpha will become zero # Leading to a division by zero @@ -114,22 +115,23 @@ def conjugate_gradient_solver( p = residual.clone() for i in range(max_iter): - Avp = matrix_vector_dot_func(p) + # A @ p (matrix vector multiplication) + A_dot_p = matrix_vector_dot_fn(p) - alpha = r_dot / p.dot(Avp) + alpha = residual_squared_norm / p.dot(A_dot_p) x += alpha * p if i == max_iter - 1: return x - residual -= alpha * Avp - new_r_dot = th.matmul(residual, residual) + residual -= alpha * A_dot_p + new_residual_squared_norm = th.matmul(residual, residual) - if new_r_dot < residual_tol: + if new_residual_squared_norm < residual_tol: return x - beta = new_r_dot / r_dot - r_dot = new_r_dot + beta = new_residual_squared_norm / residual_squared_norm + residual_squared_norm = new_residual_squared_norm p = residual + beta * p From a24e7c064bbcd65abd4e4a62a136e1413c7ee491 Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 11:35:01 +0100 Subject: [PATCH 27/28] Update comments --- sb3_contrib/trpo/trpo.py | 16 ++++++++-------- tests/test_utils.py | 16 ++++++++++++++++ 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/sb3_contrib/trpo/trpo.py b/sb3_contrib/trpo/trpo.py index 66f42dd0..9c461031 100644 --- a/sb3_contrib/trpo/trpo.py +++ b/sb3_contrib/trpo/trpo.py @@ -44,17 +44,15 @@ class TRPO(OnPolicyAlgorithm): (i.e., ``theta_new = theta + alpha^i * step``) :param line_search_max_iter: maximum number of iteration for the backtracking line-search - :param n_critic_updates: number of critic updates per policy updates + :param n_critic_updates: number of critic updates per policy update :param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator :param use_sde: Whether to use generalized State Dependent Exploration (gSDE) instead of action noise exploration (default: False) :param sde_sample_freq: Sample a new noise matrix every n steps when using gSDE Default: -1 (only sample at the beginning of the rollout) :param normalize_advantage: Whether to normalize or not the advantage - :param target_kl: Limit the KL divergence between updates, - because the clipping is not enough to prevent large update - see issue #213 (cf https://github.com/hill-a/stable-baselines/issues/213) - By default, there is no limit on the kl div. + :param target_kl: Target Kullback-Leibler divergence between updates. + Should be small for stability. Values like 0.01, 0.05. :param sub_sampling_factor: Sub-sample the batch to make computation faster see p40-42 of John Schulman thesis http://joschu.net/docs/thesis.pdf :param tensorboard_log: the log location for tensorboard (if None, no logging) @@ -104,7 +102,7 @@ def __init__( gamma=gamma, gae_lambda=gae_lambda, ent_coef=0.0, # entropy bonus is not used by TRPO - vf_coef=0.0, # Value function is optimized separately + vf_coef=0.0, # value function is optimized separately max_grad_norm=0.0, use_sde=use_sde, sde_sample_freq=sde_sample_freq, @@ -167,7 +165,7 @@ def _compute_actor_grad( """ Compute actor gradients for kl div and surrogate objectives. 
- :param kl_div: The KL div objective + :param kl_div: The KL divergence objective :param policy_objective: The surrogate objective ("classic" policy gradient) :return: List of actor params, gradients and gradients shape. """ @@ -330,7 +328,9 @@ def train(self) -> None: # New KL-divergence kl_div = kl_divergence(distribution.distribution, old_distribution.distribution).mean() - # Constraint criteria + # Constraint criteria: + # we need to improve the surrogate policy objective + # while being close enough (in term of kl div) to the old policy if (kl_div < self.target_kl) and (new_policy_objective > policy_objective): is_line_search_success = True break diff --git a/tests/test_utils.py b/tests/test_utils.py index b7455b8a..c434dd23 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,6 +3,7 @@ import torch as th from stable_baselines3.common.utils import set_random_seed +from sb3_contrib import TRPO from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad, quantile_huber_loss @@ -46,3 +47,18 @@ def test_flat_grad(): assert len(flat_grad_out.shape) == 1 # dy/dx = 2 assert th.allclose(flat_grad_out, th.ones(n_parameters) * 2) + + +def test_trpo_warnings(): + """Test that TRPO warns and errors correctly on + problematic rollout buffer sizes""" + + # Only 1 step: advantage normalization will return NaN + with pytest.raises(AssertionError): + TRPO("MlpPolicy", "Pendulum-v0", n_steps=1) + # One step not advantage normalization: ok + TRPO("MlpPolicy", "Pendulum-v0", n_steps=1, normalize_advantage=False, batch_size=1) + + # Truncated mini-batch + with pytest.warns(UserWarning): + TRPO("MlpPolicy", "Pendulum-v0", n_steps=6, batch_size=8) From 342fe531bfa5af13a15099f16fbc71feceb5ba1e Mon Sep 17 00:00:00 2001 From: Antonin Raffin Date: Wed, 29 Dec 2021 11:43:11 +0100 Subject: [PATCH 28/28] Update changelog --- docs/misc/changelog.rst | 4 ++-- sb3_contrib/version.txt | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index c3e87119..1b117e6b 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -4,8 +4,9 @@ Changelog ========== -Release 1.3.1a6 (WIP) +Release 1.3.1a7 (WIP) ------------------------------- +**Add TRPO** Breaking Changes: ^^^^^^^^^^^^^^^^^ @@ -51,7 +52,6 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ - Added ``MaskablePPO`` algorithm (@kronion) -- Added ``TRPO`` (@cyprienc) - ``MaskablePPO`` Dictionary Observation support (@glmcdona) diff --git a/sb3_contrib/version.txt b/sb3_contrib/version.txt index e6eaed88..f6258073 100644 --- a/sb3_contrib/version.txt +++ b/sb3_contrib/version.txt @@ -1 +1 @@ -1.3.1a6 +1.3.1a7
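Editor's note — a closing sketch tying together the two helpers added by this PR (``flat_grad`` and ``conjugate_gradient_solver``), used the same way TRPO's ``hessian_vector_product`` uses them: solving ``(H + damping * I) x = g`` for a toy quadratic whose solution is known. Only the two imports come from the patch; every other name and value here is illustrative:

.. code-block:: python

    import torch as th

    from sb3_contrib.common.utils import conjugate_gradient_solver, flat_grad

    # Toy objective: gradient g = 2 * params, Hessian H = 2 * I
    params = th.nn.Parameter(th.tensor([1.0, -2.0, 0.5]))
    objective = (params ** 2).sum()
    g = flat_grad(objective, [params], create_graph=True)

    cg_damping = 0.1  # same role as TRPO's ``cg_damping`` hyperparameter

    def hessian_vector_product_fn(vector):
        # (H + damping * I) @ vector, obtained by differentiating (g . vector)
        jacobian_vector_product = (g * vector).sum()
        return flat_grad(jacobian_vector_product, [params], retain_graph=True) + cg_damping * vector

    search_direction = conjugate_gradient_solver(hessian_vector_product_fn, g.detach(), max_iter=15)
    # (H + damping * I) = 2.1 * I here, so the exact solution is g / 2.1
    assert th.allclose(search_direction, g.detach() / 2.1, atol=1e-4)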