From 76a783938a0da658885598793c75fa4aecec17c2 Mon Sep 17 00:00:00 2001
From: johnjim0816
Date: Mon, 25 Dec 2023 12:53:38 +0800
Subject: [PATCH] [0.4.7] update PPO

---
 README.md                              |  1 +
 docs/README.md                         |  1 +
 joyrl/__init__.py                      |  4 +--
 joyrl/algos/PPO/config.py              |  4 +--
 joyrl/algos/PPO/data_handler.py        |  4 +--
 joyrl/algos/PPO/policy.py              | 22 +++++++-------
 joyrl/algos/__init__.py                |  3 +-
 .../CartPole-v1/CartPole-v1_PPO.yaml   | 29 +++++++++----------
 8 files changed, 34 insertions(+), 34 deletions(-)

diff --git a/README.md b/README.md
index a86871e..817257e 100644
--- a/README.md
+++ b/README.md
@@ -86,6 +86,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
 | NoisyDQN | [NoisyDQN Paper](https://arxiv.org/pdf/1706.10295.pdf) | [johnjim0816](https://github.com/johnjim0816) | |
 | DDPG | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [johnjim0816](https://github.com/johnjim0816) | |
 | TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
+| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |
 
 ## Why JoyRL?
diff --git a/docs/README.md b/docs/README.md
index a86871e..817257e 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -86,6 +86,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
 | NoisyDQN | [NoisyDQN Paper](https://arxiv.org/pdf/1706.10295.pdf) | [johnjim0816](https://github.com/johnjim0816) | |
 | DDPG | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [johnjim0816](https://github.com/johnjim0816) | |
 | TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
+| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |
 
 ## Why JoyRL?
diff --git a/joyrl/__init__.py b/joyrl/__init__.py
index 633cc82..006d8f4 100644
--- a/joyrl/__init__.py
+++ b/joyrl/__init__.py
@@ -5,13 +5,13 @@ Email: johnjim0816@gmail.com
 Date: 2023-01-01 16:20:49
 LastEditor: JiangJi
-LastEditTime: 2023-12-25 01:01:05
+LastEditTime: 2023-12-25 12:52:54
 Discription: 
 '''
 from joyrl import algos, framework, envs, utils
 from joyrl.run import run
-__version__ = "0.4.6.4"
+__version__ = "0.4.7"
 __all__ = [
     "algos",
diff --git a/joyrl/algos/PPO/config.py b/joyrl/algos/PPO/config.py
index 208b034..b8917a9 100644
--- a/joyrl/algos/PPO/config.py
+++ b/joyrl/algos/PPO/config.py
@@ -5,10 +5,10 @@ Email: johnjim0816@gmail.com
 Date: 2023-02-20 21:53:39
 LastEditor: JiangJi
-LastEditTime: 2023-05-25 22:19:00
+LastEditTime: 2023-12-25 12:30:31
 Discription: 
 '''
-class AlgoConfig:
+class AlgoConfig(object):
     def __init__(self):
         self.independ_actor = True # whether to use independent actor
         # whether actor and critic share the same optimizer
diff --git a/joyrl/algos/PPO/data_handler.py b/joyrl/algos/PPO/data_handler.py
index 763074f..1175f3e 100644
--- a/joyrl/algos/PPO/data_handler.py
+++ b/joyrl/algos/PPO/data_handler.py
@@ -5,11 +5,11 @@ Email: johnjim0816@gmail.com
 Date: 2023-05-17 01:08:36
 LastEditor: JiangJi
-LastEditTime: 2023-05-17 13:42:25
+LastEditTime: 2023-12-25 12:30:18
 Discription: 
 '''
 import numpy as np
-from algos.base.data_handlers import BaseDataHandler
+from joyrl.algos.base.data_handler import BaseDataHandler
 class DataHandler(BaseDataHandler):
     def __init__(self, cfg):
         super().__init__(cfg)
diff --git a/joyrl/algos/PPO/policy.py b/joyrl/algos/PPO/policy.py
index 5f8646a..9989cd9 100644
--- a/joyrl/algos/PPO/policy.py
+++ b/joyrl/algos/PPO/policy.py
@@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com
 Date: 2023-12-22 23:02:13
 LastEditor: JiangJi
-LastEditTime: 2023-12-24 20:11:46
+LastEditTime: 2023-12-25 12:48:39
 Discription: 
 '''
 import torch
@@ -15,7 +15,6 @@ from torch.distributions import Categorical,Normal
 import torch.utils.data as Data
 import numpy as np
-
 from joyrl.algos.base.network import ValueNetwork, CriticNetwork, ActorNetwork
 from joyrl.algos.base.policy import BasePolicy
@@ -90,11 +89,12 @@ def get_action(self, state, mode='sample', **kwargs):
         if self.action_type.lower() == 'continuous':
             self.mu, self.sigma = self.actor(state)
         else:
-            self.probs = self.actor(state)
-        if mode == 'sample':
+            output = self.actor(state)
+            self.probs = output['probs']
+        if self.cfg.mode == 'train':
             action = self.sample_action(**kwargs)
             self.update_policy_transition()
-        elif mode == 'predict':
+        elif self.cfg.mode == 'test':
             action = self.predict_action(**kwargs)
         else:
             raise NameError('mode must be sample or predict')
@@ -105,7 +105,6 @@ def update_policy_transition(self):
         else:
             self.policy_transition = {'value': self.value, 'probs': self.probs, 'log_probs': self.log_probs}
     def sample_action(self,**kwargs):
-        # sample_count = kwargs.get('sample_count', 0)
         if self.action_type.lower() == 'continuous':
             mean = self.mu * self.action_scale + self.action_bias
             std = self.sigma
@@ -124,7 +123,7 @@ def predict_action(self, **kwargs):
             return self.mu.detach().cpu().numpy()[0]
         else:
             return torch.argmax(self.probs).detach().cpu().numpy().item()
-    def train(self, **kwargs):
+    def learn(self, **kwargs):
         states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
         if self.action_type.lower() == 'continuous':
             mus, sigmas = kwargs.get('mu'), kwargs.get('sigma')
@@ -137,12 +136,12 @@ def train(self, **kwargs):
             old_probs = torch.exp(old_log_probs)
         else:
             old_probs, old_log_probs = kwargs.get('probs'), kwargs.get('log_probs')
-            old_probs = torch.cat(old_probs).to(self.device) # shape:[batch_size,n_actions]
+            old_probs = torch.stack(old_probs, dim=0).to(device=self.device, dtype=torch.float32) # shape:[batch_size,n_actions]
             old_log_probs = torch.tensor(old_log_probs, device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1]
         # convert to tensor
         states = torch.tensor(np.array(states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states]
         # actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1]
-        actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
+        actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(1) # shape:[batch_size,1]
         next_states = torch.tensor(np.array(next_states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states]
         rewards = torch.tensor(np.array(rewards), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
         dones = torch.tensor(np.array(dones), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
@@ -162,10 +161,11 @@ def train(self, **kwargs):
                 dist = Normal(mean, std)
                 new_log_probs = dist.log_prob(old_actions)
             else:
-                new_probs = self.actor(old_states) # shape:[batch_size,n_actions]
+                output = self.actor(old_states)
+                new_probs = output['probs'] # shape:[batch_size,n_actions]
                 dist = Categorical(new_probs)
                 # get new action probabilities
-                new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)) # shape:[batch_size]
+                new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)).unsqueeze(dim=1) # shape:[batch_size,1]
             # compute ratio (pi_theta / pi_theta__old):
             ratio = torch.exp(new_log_probs - old_log_probs) # shape: [batch_size, 1]
             # compute surrogate loss
diff --git a/joyrl/algos/__init__.py b/joyrl/algos/__init__.py
index 8900155..0fd351b 100644
--- a/joyrl/algos/__init__.py
+++ b/joyrl/algos/__init__.py
@@ -5,12 +5,13 @@ Email: johnjim0816@gmail.com
 Date: 2023-01-01 16:20:49
 LastEditor: JiangJi
-LastEditTime: 2023-05-30 23:57:42
+LastEditTime: 2023-12-25 12:52:47
 Discription: 
 '''
 from joyrl.algos import base,DQN,DoubleDQN,DuelingDQN,NoisyDQN,PPO
 __all__ = [
     "base",
+    "QLearning",
     "DQN",
     "DoubleDQN",
     "DuelingDQN",
diff --git a/presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml b/presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml
index 5f2fe4e..137b8fa 100644
--- a/presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml
+++ b/presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml
@@ -1,20 +1,17 @@
 general_cfg:
   algo_name: PPO
-  env_name: gym # env name, differ from env_id in env_cfgs
-  device: cpu # device, cpu or cuda
-  mode: train # run mode: train, test
-  collect_traj: false # if collect trajectories or not
-  mp_backend: single # multi-processing mode: single(default), ray
-  n_workers: 2 # number of workers if using multi-processing, default 1
-  load_checkpoint: false # if load checkpoint or not
-  load_path: Train_single_CartPole-v1_DQN_20230515-211721 # if load checkpoint, then config path in 'tasks' dir
-  load_model_step: best # load model step
-  max_episode: 200 # max episodes, set -1 to keep running
-  max_step: 200 # max steps per episode
-  seed: 1 # random seed, set 0 not to use seed
-  online_eval: true # if online eval or not
-  online_eval_episode: 10 # online eval episodes
-  model_save_fre: 10 # update step frequency of saving model
+  env_name: gym
+  device: cpu
+  mode: train
+  load_checkpoint: false
+  load_path: Train_CartPole-v1_PPO_20231225-124842 # if load checkpoint, then config path in 'tasks' dir
+  load_model_step: best
+  max_episode: -1
+  max_step: 200
+  seed: 1
+  online_eval: true
+  online_eval_episode: 10
+  model_save_fre: 10
 algo_cfg:
   ppo_type: clip
   action_type: discrete
@@ -42,7 +39,7 @@ algo_cfg:
   gamma: 0.99
   k_epochs: 4
   batch_size: 256
-  sgd_batch_size: 128
+  sgd_batch_size: 256
 env_cfg:
   id: CartPole-v1
   render_mode: null
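
The policy.py hunks above end right at the ratio and surrogate-loss computation of the renamed learn() method. For reference, here is a minimal, self-contained PyTorch sketch of the PPO clipped surrogate objective that this block computes; it is an illustration, not the repository's code, and the function name, the advantages argument, and the eps_clip default are assumptions.

import torch

def ppo_clip_loss(new_log_probs: torch.Tensor,
                  old_log_probs: torch.Tensor,
                  advantages: torch.Tensor,
                  eps_clip: float = 0.2) -> torch.Tensor:
    # all inputs assumed to be shaped [batch_size, 1], matching the shapes
    # the patch enforces with unsqueeze(dim=1)
    ratio = torch.exp(new_log_probs - old_log_probs)  # pi_theta / pi_theta_old
    surr1 = ratio * advantages                        # unclipped surrogate
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages  # clipped surrogate
    return -torch.min(surr1, surr2).mean()            # negated because optimizers minimize

Keeping new_log_probs and old_log_probs both at shape [batch_size, 1] (the reason for the added unsqueeze(dim=1)) makes the subtraction element-wise; a [batch_size] tensor minus a [batch_size, 1] tensor would instead broadcast to a [batch_size, batch_size] matrix.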
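
On the torch.cat to torch.stack change for old_probs: if each stored probs entry is a 1-D tensor of shape [n_actions], torch.cat flattens the list into a single [batch_size * n_actions] vector, while torch.stack(..., dim=0) builds the intended [batch_size, n_actions] batch. A minimal check, with made-up values and assuming 1-D per-step tensors:

import torch

# two decision steps in a 2-action environment (values are hypothetical)
probs_per_step = [torch.tensor([0.7, 0.3]), torch.tensor([0.4, 0.6])]

flat = torch.cat(probs_per_step)              # shape [4]: batch structure is lost
batched = torch.stack(probs_per_step, dim=0)  # shape [2, 2]: [batch_size, n_actions]

print(flat.shape, batched.shape)  # torch.Size([4]) torch.Size([2, 2])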
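
The preset now sets sgd_batch_size equal to batch_size (256). Below is a rough sketch of how the three values in the YAML could interact, reusing the torch.utils.data import already present in policy.py; the tensor names and the 4-dimensional CartPole-v1 state are stand-ins, not the framework's actual training loop.

import torch
import torch.utils.data as Data

# a collected rollout of batch_size = 256 transitions (CartPole-v1 states are 4-dimensional)
states = torch.randn(256, 4)
advantages = torch.randn(256, 1)

dataset = Data.TensorDataset(states, advantages)
loader = Data.DataLoader(dataset, batch_size=256, shuffle=True)  # sgd_batch_size = 256

for _ in range(4):                     # k_epochs = 4 passes over the same rollout
    for mb_states, mb_advantages in loader:
        pass                           # compute the clipped surrogate loss and step the optimizer

Under that reading, raising sgd_batch_size from 128 to 256 turns each of the k_epochs passes into a single full-batch update instead of two mini-batch updates.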