From 78927d976a6ed1a68b9433a633585a11d859e82d Mon Sep 17 00:00:00 2001 From: unknown Date: Fri, 19 May 2023 17:56:30 +0800 Subject: [PATCH] update module, buffer and driver for off-policy algorithm --- openrl/algorithms/dqn.py | 155 +++++++++++++--------- openrl/buffers/offpolicy_buffer.py | 30 +++++ openrl/buffers/offpolicy_replay_data.py | 169 +++++++++++++++++++++++- openrl/drivers/offpolicy_driver.py | 80 ++++------- openrl/modules/dqn_module.py | 16 ++- openrl/runners/common/dqn_agent.py | 1 + 6 files changed, 331 insertions(+), 120 deletions(-) diff --git a/openrl/algorithms/dqn.py b/openrl/algorithms/dqn.py index 9f0a1be6..b0b7ca97 100644 --- a/openrl/algorithms/dqn.py +++ b/openrl/algorithms/dqn.py @@ -21,7 +21,7 @@ import numpy as np import torch import torch.nn as nn -from torch.nn.parallel import DistributedDataParallel +import torch.nn.functional as F from openrl.algorithms.base_algorithm import BaseAlgorithm from openrl.modules.networks.utils.distributed_utils import reduce_tensor @@ -37,59 +37,61 @@ def __init__( agent_num: int = 1, device: Union[str, torch.device] = "cpu", ) -> None: - self._use_share_model = cfg.use_share_model - self.use_joint_action_loss = cfg.use_joint_action_loss super(DQNAlgorithm, self).__init__(cfg, init_module, agent_num, device) + self.gamma = cfg.gamma + def dqn_update(self, sample, turn_on=True): for optimizer in self.algo_module.optimizers.values(): optimizer.zero_grad() ( obs_batch, + _, + next_obs_batch, + _, rnn_states_batch, + rnn_states_critic_batch, actions_batch, value_preds_batch, - return_batch, + rewards_batch, masks_batch, active_masks_batch, + old_action_log_probs_batch, + adv_targ, available_actions_batch, ) = sample value_preds_batch = check(value_preds_batch).to(**self.tpdv) - return_batch = check(return_batch).to(**self.tpdv) + rewards_batch = check(rewards_batch).to(**self.tpdv) active_masks_batch = check(active_masks_batch).to(**self.tpdv) if self.use_amp: with torch.cuda.amp.autocast(): - ( - loss_list, - value_loss, - policy_loss, - dist_entropy, - ratio, - ) = self.prepare_loss( + loss_list = self.prepare_loss( obs_batch, + next_obs_batch, rnn_states_batch, actions_batch, masks_batch, available_actions_batch, value_preds_batch, - return_batch, + rewards_batch, active_masks_batch, turn_on, ) for loss in loss_list: self.algo_module.scaler.scale(loss).backward() else: - loss_list, value_loss, policy_loss, dist_entropy, ratio = self.prepare_loss( + loss_list = self.prepare_loss( obs_batch, + next_obs_batch, rnn_states_batch, actions_batch, masks_batch, available_actions_batch, value_preds_batch, - return_batch, + rewards_batch, active_masks_batch, turn_on, ) @@ -97,42 +99,10 @@ def dqn_update(self, sample, turn_on=True): loss.backward() if "transformer" in self.algo_module.models: - if self._use_max_grad_norm: - grad_norm = nn.utils.clip_grad_norm_( - self.algo_module.models["transformer"].parameters(), - self.max_grad_norm, - ) - else: - grad_norm = get_gard_norm( - self.algo_module.models["transformer"].parameters() - ) - critic_grad_norm = grad_norm - actor_grad_norm = grad_norm - + raise NotImplementedError else: - if self._use_share_model: - actor_para = self.algo_module.models["model"].get_actor_para() - else: - actor_para = self.algo_module.models["policy"].parameters() - - if self._use_max_grad_norm: - actor_grad_norm = nn.utils.clip_grad_norm_( - actor_para, self.max_grad_norm - ) - else: - actor_grad_norm = get_gard_norm(actor_para) - - if self._use_share_model: - critic_para = 
self.algo_module.models["model"].get_critic_para()
-            else:
-                critic_para = self.algo_module.models["critic"].parameters()
-
-            if self._use_max_grad_norm:
-                critic_grad_norm = nn.utils.clip_grad_norm_(
-                    critic_para, self.max_grad_norm
-                )
-            else:
-                critic_grad_norm = get_gard_norm(critic_para)
+            actor_para = self.algo_module.models["q_net"].parameters()
+            actor_grad_norm = get_gard_norm(actor_para)
 
         if self.use_amp:
             for optimizer in self.algo_module.optimizers.values():
@@ -149,14 +119,7 @@ def dqn_update(self, sample, turn_on=True):
         if self.world_size > 1:
             torch.cuda.synchronize()
 
-        return (
-            value_loss,
-            critic_grad_norm,
-            policy_loss,
-            dist_entropy,
-            actor_grad_norm,
-            ratio,
-        )
+        return loss
 
     def cal_value_loss(
         self,
@@ -208,16 +171,86 @@ def to_single_np(self, input):
     def prepare_loss(
         self,
         obs_batch,
+        next_obs_batch,
         rnn_states_batch,
         actions_batch,
         masks_batch,
         available_actions_batch,
         value_preds_batch,
-        return_batch,
+        rewards_batch,
         active_masks_batch,
         turn_on,
     ):
-        raise NotImplementedError
+        loss_list = []
+        critic_masks_batch = masks_batch
+
+        (
+            q_values,
+            max_next_q_values,
+        ) = self.algo_module.evaluate_actions(
+            obs_batch,
+            next_obs_batch,
+            rnn_states_batch,
+            rewards_batch,
+            actions_batch,
+            masks_batch,
+            available_actions_batch,
+            active_masks_batch,
+            critic_masks_batch=critic_masks_batch,
+        )
+
+        q_targets = rewards_batch + self.gamma * max_next_q_values
+        q_loss = torch.mean(F.mse_loss(q_values, q_targets))  # mean squared error loss
+
+        loss_list.append(q_loss)
+        return loss_list
 
     def train(self, buffer, turn_on=True):
-        raise NotImplementedError
+        train_info = {}
+
+        train_info["q_loss"] = 0
+
+        if self.world_size > 1:
+            train_info["reduced_q_loss"] = 0
+
+        # todo: add rnn and transformer support
+        # update once
+        for _ in range(1):
+            if "transformer" in self.algo_module.models:
+                raise NotImplementedError
+            elif self._use_recurrent_policy:
+                raise NotImplementedError
+            elif self._use_naive_recurrent:
+                raise NotImplementedError
+            else:
+                data_generator = buffer.feed_forward_generator(
+                    None, self.num_mini_batch
+                )
+
+            for sample in data_generator:
+                q_loss = self.dqn_update(sample, turn_on)
+
+                if self.world_size > 1:
+                    train_info["reduced_q_loss"] += reduce_tensor(
+                        q_loss.data, self.world_size
+                    )
+
+                train_info["q_loss"] += q_loss.item()
+
+        num_updates = 1 * self.num_mini_batch
+
+        for k in train_info.keys():
+            train_info[k] /= num_updates
+
+        for optimizer in self.algo_module.optimizers.values():
+            if hasattr(optimizer, "sync_lookahead"):
+                optimizer.sync_lookahead()
+
+        return train_info
\ No newline at end of file
diff --git a/openrl/buffers/offpolicy_buffer.py b/openrl/buffers/offpolicy_buffer.py
index 31ad2ab8..05f2a0e2 100644
--- a/openrl/buffers/offpolicy_buffer.py
+++ b/openrl/buffers/offpolicy_buffer.py
@@ -36,6 +36,36 @@ def __init__(
             episode_length,
         )
 
+    def insert(
+        self,
+        raw_obs,
+        next_raw_obs,
+        rnn_states,
+        rnn_states_critic,
+        actions,
+        action_log_probs,
+        value_preds,
+        rewards,
+        masks,
+        bad_masks=None,
+        active_masks=None,
+        available_actions=None,
+    ):
+        self.data.insert(
+            raw_obs,
+            next_raw_obs,
+            rnn_states,
+            rnn_states_critic,
+            actions,
+            action_log_probs,
+            value_preds,
+            rewards,
+            masks,
+            bad_masks,
+            active_masks,
+            available_actions,
+        )
+
     def get_buffer_size(self):
         if self.data.first_insert_flag:
             return self.data.step
diff --git a/openrl/buffers/offpolicy_replay_data.py b/openrl/buffers/offpolicy_replay_data.py
index 13c76cd3..05aa9cab 100644
---
a/openrl/buffers/offpolicy_replay_data.py +++ b/openrl/buffers/offpolicy_replay_data.py @@ -24,7 +24,12 @@ from openrl.buffers.replay_data import ReplayData from openrl.buffers.utils.obs_data import ObsData -from openrl.buffers.utils.util import get_critic_obs, get_policy_obs +from openrl.buffers.utils.util import ( + get_critic_obs, + get_policy_obs, + get_critic_obs_space, + get_policy_obs_space +) class OffPolicyReplayData(ReplayData): @@ -45,6 +50,27 @@ def __init__( data_client, episode_length, ) + + policy_obs_shape = get_policy_obs_space(obs_space) + critic_obs_shape = get_critic_obs_space(obs_space) + self.next_policy_obs = np.zeros( + ( + self.episode_length + 1, + self.n_rollout_threads, + num_agents, + *policy_obs_shape, + ), + dtype=np.float32, + ) + self.next_critic_obs = np.zeros( + ( + self.episode_length + 1, + self.n_rollout_threads, + num_agents, + *critic_obs_shape, + ), + dtype=np.float32, + ) self.first_insert_flag = True def dict_insert(self, data): @@ -53,9 +79,15 @@ def dict_insert(self, data): self.critic_obs[key][self.step + 1] = data["critic_obs"][key].copy() for key in self.policy_obs.keys(): self.policy_obs[key][self.step + 1] = data["policy_obs"][key].copy() + for key in self.next_policy_obs.keys(): + self.next_policy_obs[key][self.step + 1] = data["next_policy_obs"][key].copy() + for key in self.next_critic_obs.keys(): + self.next_critic_obs[key][self.step + 1] = data["next_critic_obs"][key].copy() else: self.critic_obs[self.step + 1] = data["critic_obs"].copy() self.policy_obs[self.step + 1] = data["policy_obs"].copy() + self.next_policy_obs[self.step + 1] = data["next_policy_obs"].copy() + self.next_critic_obs[self.step + 1] = data["next_critic_obs"].copy() if "rnn_states" in data: self.rnn_states[self.step + 1] = data["rnn_states"].copy() @@ -87,6 +119,7 @@ def dict_insert(self, data): def insert( self, raw_obs, + next_raw_obs, rnn_states, rnn_states_critic, actions, @@ -100,14 +133,22 @@ def insert( ): critic_obs = get_critic_obs(raw_obs) policy_obs = get_policy_obs(raw_obs) + next_critic_obs = get_critic_obs(next_raw_obs) + next_policy_obs = get_policy_obs(next_raw_obs) if self._mixed_obs: for key in self.critic_obs.keys(): self.critic_obs[key][self.step + 1] = critic_obs[key].copy() for key in self.policy_obs.keys(): self.policy_obs[key][self.step + 1] = policy_obs[key].copy() + for key in self.next_critic_obs.keys(): + self.next_critic_obs[key][self.step + 1] = next_critic_obs[key].copy() + for key in self.next_policy_obs.keys(): + self.next_policy_obs[key][self.step + 1] = next_policy_obs[key].copy() else: self.critic_obs[self.step + 1] = critic_obs.copy() self.policy_obs[self.step + 1] = policy_obs.copy() + self.next_critic_obs[self.step + 1] = next_critic_obs.copy() + self.next_policy_obs[self.step + 1] = next_policy_obs.copy() self.rnn_states[self.step + 1] = rnn_states.copy() self.rnn_states_critic[self.step + 1] = rnn_states_critic.copy() @@ -136,9 +177,16 @@ def after_update(self): self.critic_obs[key][0] = self.critic_obs[key][-1].copy() for key in self.policy_obs.keys(): self.policy_obs[key][0] = self.policy_obs[key][-1].copy() + for key in self.next_critic_obs.keys(): + self.next_critic_obs[key][0] = self.next_critic_obs[key][-1].copy() + for key in self.next_policy_obs.keys(): + self.next_policy_obs[key][0] = self.next_policy_obs[key][-1].copy() else: self.critic_obs[0] = self.critic_obs[-1].copy() self.policy_obs[0] = self.policy_obs[-1].copy() + self.next_critic_obs[0] = self.next_critic_obs[-1].copy() + self.next_policy_obs[0] = 
self.next_policy_obs[-1].copy() + self.rnn_states[0] = self.rnn_states[-1].copy() self.rnn_states_critic[0] = self.rnn_states_critic[-1].copy() self.masks[0] = self.masks[-1].copy() @@ -149,3 +197,122 @@ def after_update(self): def compute_returns(self, next_value, value_normalizer=None): pass + + def feed_forward_generator( + self, + advantages, + num_mini_batch=None, + mini_batch_size=None, + critic_obs_process_func=None, + ): + episode_length, n_rollout_threads, num_agents = self.rewards.shape[0:3] + batch_size = n_rollout_threads * episode_length * num_agents + + if mini_batch_size is None: + assert ( + batch_size >= num_mini_batch + ), ( + "DQN requires the number of processes ({}) " + "* number of steps ({}) * number of agents ({}) = {} " + "to be greater than or equal to the number of DQN mini batches ({})." + "".format( + n_rollout_threads, + episode_length, + num_agents, + n_rollout_threads * episode_length * num_agents, + num_mini_batch, + ) + ) + mini_batch_size = batch_size // num_mini_batch + + sampler = BatchSampler( + SubsetRandomSampler(range(batch_size)), mini_batch_size, drop_last=True + ) + + if self._mixed_obs: + critic_obs = {} + policy_obs = {} + next_critic_obs = {} + next_policy_obs = {} + for key in self.critic_obs.keys(): + critic_obs[key] = self.critic_obs[key][:-1].reshape( + -1, *self.critic_obs[key].shape[3:] + ) + for key in self.policy_obs.keys(): + policy_obs[key] = self.policy_obs[key][:-1].reshape( + -1, *self.policy_obs[key].shape[3:] + ) + for key in self.next_critic_obs.keys(): + next_critic_obs[key] = self.next_critic_obs[key][:-1].reshape( + -1, *self.next_critic_obs[key].shape[3:] + ) + for key in self.next_policy_obs.keys(): + next_policy_obs[key] = self.next_policy_obs[key][:-1].reshape( + -1, *self.next_policy_obs[key].shape[3:] + ) + else: + critic_obs = self.critic_obs[:-1].reshape(-1, *self.critic_obs.shape[3:]) + policy_obs = self.policy_obs[:-1].reshape(-1, *self.policy_obs.shape[3:]) + next_critic_obs = self.next_critic_obs[:-1].reshape(-1, *self.next_critic_obs.shape[3:]) + next_policy_obs = self.next_policy_obs[:-1].reshape(-1, *self.next_policy_obs.shape[3:]) + + rnn_states = self.rnn_states[:-1].reshape(-1, *self.rnn_states.shape[3:]) + rnn_states_critic = self.rnn_states_critic[:-1].reshape( + -1, *self.rnn_states_critic.shape[3:] + ) + actions = self.actions.reshape(-1, self.actions.shape[-1]) + if self.available_actions is not None: + available_actions = self.available_actions[:-1].reshape( + -1, self.available_actions.shape[-1] + ) + value_preds = self.value_preds[:-1].reshape(-1, 1) + rewards = self.rewards[:-1].reshape(-1, 1) + masks = self.masks[:-1].reshape(-1, 1) + active_masks = self.active_masks[:-1].reshape(-1, 1) + action_log_probs = self.action_log_probs.reshape( + -1, self.action_log_probs.shape[-1] + ) + if advantages is not None: + advantages = advantages.reshape(-1, 1) + + for indices in sampler: + # obs size [T+1 N M Dim]-->[T N M Dim]-->[T*N*M,Dim]-->[index,Dim] + if self._mixed_obs: + critic_obs_batch = {} + policy_obs_batch = {} + next_critic_obs_batch = {} + next_policy_obs_batch = {} + for key in critic_obs.keys(): + critic_obs_batch[key] = critic_obs[key][indices] + for key in policy_obs.keys(): + policy_obs_batch[key] = policy_obs[key][indices] + for key in next_critic_obs.keys(): + next_critic_obs_batch[key] = next_critic_obs[key][indices] + for key in next_policy_obs.keys(): + next_policy_obs_batch[key] = next_policy_obs[key][indices] + else: + critic_obs_batch = critic_obs[indices] + policy_obs_batch = 
policy_obs[indices] + next_critic_obs_batch = next_critic_obs[indices] + next_policy_obs_batch = next_policy_obs[indices] + + rnn_states_batch = rnn_states[indices] + rnn_states_critic_batch = rnn_states_critic[indices] + actions_batch = actions[indices] + if self.available_actions is not None: + available_actions_batch = available_actions[indices] + else: + available_actions_batch = None + value_preds_batch = value_preds[indices] + rewards_batch = rewards[indices] + masks_batch = masks[indices] + active_masks_batch = active_masks[indices] + old_action_log_probs_batch = action_log_probs[indices] + if advantages is None: + adv_targ = None + else: + adv_targ = advantages[indices] + if critic_obs_process_func is not None: + critic_obs_batch = critic_obs_process_func(critic_obs_batch) + + yield critic_obs_batch, policy_obs_batch, next_critic_obs_batch, next_policy_obs_batch,rnn_states_batch, rnn_states_critic_batch, actions_batch, value_preds_batch, rewards_batch, masks_batch, active_masks_batch, old_action_log_probs_batch, adv_targ, available_actions_batch diff --git a/openrl/drivers/offpolicy_driver.py b/openrl/drivers/offpolicy_driver.py index c6430b62..2d9b9319 100644 --- a/openrl/drivers/offpolicy_driver.py +++ b/openrl/drivers/offpolicy_driver.py @@ -57,12 +57,7 @@ def _inner_loop( self.buffer.after_update() else: train_infos = { - "value_loss": 0, - "policy_loss": 0, - "dist_entropy": 0, - "actor_grad_norm": 0, - "critic_grad_norm": 0, - "ratio": 0, + "q_loss": 0 } self.total_num_steps = ( @@ -77,14 +72,13 @@ def _inner_loop( def add2buffer(self, data): ( obs, + next_obs, rewards, dones, infos, - values, + q_values, actions, - action_log_probs, rnn_states, - rnn_states_critic, ) = data rnn_states[dones] = np.zeros( @@ -92,20 +86,20 @@ def add2buffer(self, data): dtype=np.float32, ) - rnn_states_critic[dones] = np.zeros( - (dones.sum(), *self.buffer.data.rnn_states_critic.shape[3:]), - dtype=np.float32, - ) masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32) masks[dones] = np.zeros((dones.sum(), 1), dtype=np.float32) + rnn_states_critic = rnn_states + action_log_probs = actions + self.buffer.insert( obs, + next_obs, rnn_states, rnn_states_critic, actions, action_log_probs, - values, + q_values, rewards, masks, ) @@ -114,32 +108,41 @@ def actor_rollout(self): self.trainer.prep_rollout() import time + q_values, actions, rnn_states = self.act(0) + extra_data = { + "q_values": q_values, + "step": 0, + "buffer": self.buffer, + } + obs, rewards, dones, infos = self.envs.step(actions, extra_data) + + # todo how to handle next obs in initialized state and terminal state + next_obs, rewards, dones, infos = self.envs.step(actions, extra_data) for step in range(self.episode_length): - values, actions, action_log_probs, rnn_states, rnn_states_critic = self.act( + q_values, actions, rnn_states = self.act( step ) extra_data = { - "values": values, - "action_log_probs": action_log_probs, + "q_values": q_values, "step": step, "buffer": self.buffer, } - obs, rewards, dones, infos = self.envs.step(actions, extra_data) + # todo how to handle next obs in initialized state and terminal state + next_obs, rewards, dones, infos = self.envs.step(actions, extra_data) data = ( obs, + next_obs, rewards, dones, infos, - values, + q_values, actions, - action_log_probs, rnn_states, - rnn_states_critic, ) - + obs = next_obs self.add2buffer(data) batch_rew_infos = self.envs.batch_rewards(self.buffer) @@ -151,33 +154,6 @@ def actor_rollout(self): else: return batch_rew_infos - @torch.no_grad() 
-    def compute_returns(self):
-        self.trainer.prep_rollout()
-
-        next_values = self.trainer.algo_module.get_values(
-            self.buffer.data.get_batch_data("critic_obs", -1),
-            np.concatenate(self.buffer.data.rnn_states_critic[-1]),
-            np.concatenate(self.buffer.data.masks[-1]),
-        )
-
-        next_values = np.array(
-            np.split(_t2n(next_values), self.learner_n_rollout_threads)
-        )
-        if "critic" in self.trainer.algo_module.models and isinstance(
-            self.trainer.algo_module.models["critic"], DistributedDataParallel
-        ):
-            value_normalizer = self.trainer.algo_module.models[
-                "critic"
-            ].module.value_normalizer
-        elif "model" in self.trainer.algo_module.models and isinstance(
-            self.trainer.algo_module.models["model"], DistributedDataParallel
-        ):
-            value_normalizer = self.trainer.algo_module.models["model"].value_normalizer
-        else:
-            value_normalizer = self.trainer.algo_module.get_critic_value_normalizer()
-        self.buffer.compute_returns(next_values, value_normalizer)
-
     @torch.no_grad()
     def act(
         self,
@@ -205,12 +181,12 @@ def act(
             * step
         )
         if random.random() > epsilon:
-            action = q_values.argmax().item()
+            actions = q_values.argmax().item()
         else:
-            action = q_values.argmax().item()
+            actions = np.random.randint(q_values.shape[-1])  # explore: random action
 
         return (
             q_values,
-            action,
+            actions,
             rnn_states,
         )
diff --git a/openrl/modules/dqn_module.py b/openrl/modules/dqn_module.py
index 2a46458b..4b2ddbe9 100644
--- a/openrl/modules/dqn_module.py
+++ b/openrl/modules/dqn_module.py
@@ -95,8 +95,11 @@ def get_values(self, obs, rnn_states_critic, masks):
 
     def evaluate_actions(
         self,
-        obs,
-        rnn_states,
+        obs_batch,
+        next_obs_batch,
+        rnn_states_batch,
+        rewards_batch,
+        actions_batch,
         masks,
         available_actions=None,
         masks_batch=None,
+        critic_masks_batch=None,
@@ -105,10 +108,11 @@ def evaluate_actions(
             masks_batch = masks
 
         q_values, _ = self.models["q_net"](
-            obs, rnn_states, masks_batch, available_actions
+            obs_batch, rnn_states_batch, masks_batch, available_actions
        )
+        max_next_q_values, _ = self.models["target_q_net"](
+            next_obs_batch, rnn_states_batch, masks_batch, available_actions
+        )
 
-        return q_values
+        return q_values, max_next_q_values
 
     def act(self, obs, rnn_states_actor, masks, available_actions=None):
         model = self.models["q_net"]
@@ -119,8 +123,8 @@ def act(self, obs, rnn_states_actor, masks, available_actions=None):
             masks,
             available_actions,
         )
-        action = q_values.argmax().item()
-        return action, rnn_states_actor
+
+        return q_values, rnn_states_actor
 
     def get_critic_value_normalizer(self):
         return self.models["q_net"].value_normalizer
diff --git a/openrl/runners/common/dqn_agent.py b/openrl/runners/common/dqn_agent.py
index 4e9381f8..fb3a4949 100644
--- a/openrl/runners/common/dqn_agent.py
+++ b/openrl/runners/common/dqn_agent.py
@@ -71,6 +71,7 @@ def train(self: SelfAgent, total_time_steps: int) -> None:
             self._env.observation_space,
             self._env.action_space,
             data_client=None,
+            episode_length=self._cfg.buffer_size,
         )
 
         logger = Logger(
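
Note on the TD target in prepare_loss()/evaluate_actions(): as written in these hunks, the loss compares the q_net output against rewards_batch + gamma * the target network's raw output, without selecting Q(s, a) for the taken action, taking a max over next-state actions, or masking the bootstrap term at terminal transitions. For reference, a minimal sketch of the standard DQN target computation follows; q_net, target_q_net, and the mask convention (masks == 0 once an episode has ended) are illustrative assumptions, not identifiers from this patch.

    # Sketch only: standard DQN TD target under the assumptions stated above.
    # q_net / target_q_net map [batch, obs_dim] observations to [batch, n_actions] Q-values;
    # actions are integer indices of shape [batch, 1]; rewards and masks are [batch, 1].
    import torch
    import torch.nn.functional as F

    def dqn_loss(q_net, target_q_net, obs, actions, rewards, next_obs, masks, gamma=0.99):
        # Q(s, a) for the actions actually taken: gather along the action dimension.
        q_taken = q_net(obs).gather(1, actions.long())
        with torch.no_grad():
            # max_a' Q_target(s', a'), detached so only the online network receives gradients.
            max_next_q = target_q_net(next_obs).max(dim=1, keepdim=True)[0]
        # Bootstrap only on non-terminal transitions (masks == 0 when done).
        q_targets = rewards + gamma * masks * max_next_q
        return F.mse_loss(q_taken, q_targets)

The target network is also usually refreshed on a schedule, e.g. by copying q_net's weights into target_q_net every fixed number of updates or by a Polyak average; the hunks above do not include that step yet.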