Commit

[0.4.7] update PPO
johnjim0816 committed Dec 25, 2023
1 parent 9744ede commit 76a7839
Showing 8 changed files with 34 additions and 34 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -86,6 +86,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
| NoisyDQN | [NoisyDQN Paper](https://arxiv.org/pdf/1706.10295.pdf) | [johnjim0816](https://github.com/johnjim0816) | |
| DDPG | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [johnjim0816](https://github.com/johnjim0816) | |
| TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
+| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |

## Why JoyRL?

1 change: 1 addition & 0 deletions docs/README.md
@@ -86,6 +86,7 @@ More tutorials and API documentation are hosted on [JoyRL docs](https://datawhal
| NoisyDQN | [NoisyDQN Paper](https://arxiv.org/pdf/1706.10295.pdf) | [johnjim0816](https://github.com/johnjim0816) | |
| DDPG | [DDPG Paper](https://arxiv.org/abs/1509.02971) | [johnjim0816](https://github.com/johnjim0816) | |
| TD3 | [TD3 Paper](https://arxiv.org/pdf/1802.09477) | [johnjim0816](https://github.com/johnjim0816) | |
+| PPO | [PPO Paper](https://arxiv.org/abs/1707.06347) | [johnjim0816](https://github.com/johnjim0816) | |

## Why JoyRL?

4 changes: 2 additions & 2 deletions joyrl/__init__.py
@@ -5,13 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
-LastEditTime: 2023-12-25 01:01:05
+LastEditTime: 2023-12-25 12:52:54
Discription:
'''
from joyrl import algos, framework, envs, utils
from joyrl.run import run

__version__ = "0.4.6.4"
__version__ = "0.4.7"

__all__ = [
"algos",
4 changes: 2 additions & 2 deletions joyrl/algos/PPO/config.py
@@ -5,10 +5,10 @@
Email: [email protected]
Date: 2023-02-20 21:53:39
LastEditor: JiangJi
-LastEditTime: 2023-05-25 22:19:00
+LastEditTime: 2023-12-25 12:30:31
Discription:
'''
-class AlgoConfig:
+class AlgoConfig(object):
def __init__(self):
self.independ_actor = True # whether to use independent actor
# whether actor and critic share the same optimizer
4 changes: 2 additions & 2 deletions joyrl/algos/PPO/data_handler.py
@@ -5,11 +5,11 @@
Email: [email protected]
Date: 2023-05-17 01:08:36
LastEditor: JiangJi
-LastEditTime: 2023-05-17 13:42:25
+LastEditTime: 2023-12-25 12:30:18
Discription:
'''
import numpy as np
-from algos.base.data_handlers import BaseDataHandler
+from joyrl.algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
def __init__(self, cfg):
super().__init__(cfg)
22 changes: 11 additions & 11 deletions joyrl/algos/PPO/policy.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
-LastEditTime: 2023-12-24 20:11:46
+LastEditTime: 2023-12-25 12:48:39
Discription:
'''
import torch
@@ -15,7 +15,6 @@
from torch.distributions import Categorical,Normal
import torch.utils.data as Data
import numpy as np
-
from joyrl.algos.base.network import ValueNetwork, CriticNetwork, ActorNetwork
from joyrl.algos.base.policy import BasePolicy

@@ -90,11 +89,12 @@ def get_action(self, state, mode='sample', **kwargs):
if self.action_type.lower() == 'continuous':
self.mu, self.sigma = self.actor(state)
else:
-self.probs = self.actor(state)
-if mode == 'sample':
+output = self.actor(state)
+self.probs = output['probs']
+if self.cfg.mode == 'train':
action = self.sample_action(**kwargs)
self.update_policy_transition()
-elif mode == 'predict':
+elif self.cfg.mode == 'test':
action = self.predict_action(**kwargs)
else:
raise NameError('mode must be sample or predict')
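
Note on the hunk above: `get_action` now branches on `self.cfg.mode` (`'train'`/`'test'`) rather than on the `mode` argument, and the discrete actor is expected to return a dict so callers read `output['probs']`. Below is a minimal sketch of a discrete actor matching that output contract; `TinyDiscreteActor` is a hypothetical stand-in for illustration, not JoyRL's `ActorNetwork`.

```python
import torch
import torch.nn as nn

class TinyDiscreteActor(nn.Module):
    """Hypothetical actor whose forward() returns a dict, matching output['probs'] above."""
    def __init__(self, n_states: int, n_actions: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_states, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, n_actions),
        )

    def forward(self, state: torch.Tensor) -> dict:
        logits = self.net(state)
        # softmax over the action dimension gives the categorical probabilities
        return {'probs': torch.softmax(logits, dim=-1)}
```
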
@@ -105,7 +105,6 @@ def update_policy_transition(self):
else:
self.policy_transition = {'value': self.value, 'probs': self.probs, 'log_probs': self.log_probs}
def sample_action(self,**kwargs):
-# sample_count = kwargs.get('sample_count', 0)
if self.action_type.lower() == 'continuous':
mean = self.mu * self.action_scale + self.action_bias
std = self.sigma
@@ -124,7 +123,7 @@ def predict_action(self, **kwargs):
return self.mu.detach().cpu().numpy()[0]
else:
return torch.argmax(self.probs).detach().cpu().numpy().item()
-def train(self, **kwargs):
+def learn(self, **kwargs):
states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones')
if self.action_type.lower() == 'continuous':
mus, sigmas = kwargs.get('mu'), kwargs.get('sigma')
@@ -137,12 +136,12 @@ def train(self, **kwargs):
old_probs = torch.exp(old_log_probs)
else:
old_probs, old_log_probs = kwargs.get('probs'), kwargs.get('log_probs')
-old_probs = torch.cat(old_probs).to(self.device) # shape:[batch_size,n_actions]
+old_probs = torch.stack(old_probs, dim=0).to(device=self.device, dtype=torch.float32) # shape:[batch_size,n_actions]
old_log_probs = torch.tensor(old_log_probs, device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1]
# convert to tensor
states = torch.tensor(np.array(states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states]
# actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1]
-actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
+actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(1) # shape:[batch_size,1]
next_states = torch.tensor(np.array(next_states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states]
rewards = torch.tensor(np.array(rewards), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
dones = torch.tensor(np.array(dones), device=self.device, dtype=torch.float32) # shape:[batch_size,1]
@@ -162,10 +161,11 @@ def train(self, **kwargs):
dist = Normal(mean, std)
new_log_probs = dist.log_prob(old_actions)
else:
-new_probs = self.actor(old_states) # shape:[batch_size,n_actions]
+output = self.actor(old_states)
+new_probs = output['probs'] # shape:[batch_size,n_actions]
dist = Categorical(new_probs)
# get new action probabilities
-new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)) # shape:[batch_size]
+new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)).unsqueeze(dim=1) # shape:[batch_size,1]
# compute ratio (pi_theta / pi_theta__old):
ratio = torch.exp(new_log_probs - old_log_probs) # shape: [batch_size, 1]
# compute surrogate loss
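
The shape fixes above matter for the ratio: `dist.log_prob(old_actions.squeeze(dim=1))` returns shape `[batch_size]`, and subtracting `old_log_probs` of shape `[batch_size,1]` would broadcast to `[batch_size,batch_size]`; the added `.unsqueeze(dim=1)` keeps both sides at `[batch_size,1]`. For reference, a minimal sketch of the clipped surrogate loss this ratio feeds into (PPO-clip as in the paper linked in the README and the `ppo_type: clip` preset; `eps_clip` and `advantages` are assumed inputs, not taken from this diff):

```python
import torch

def ppo_clip_loss(new_log_probs: torch.Tensor,
                  old_log_probs: torch.Tensor,
                  advantages: torch.Tensor,
                  eps_clip: float = 0.2) -> torch.Tensor:
    # ratio pi_theta(a|s) / pi_theta_old(a|s), computed in log space; all inputs shape [batch_size, 1]
    ratio = torch.exp(new_log_probs - old_log_probs)
    surr1 = ratio * advantages                                               # unclipped surrogate
    surr2 = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantages  # clipped surrogate
    # take the elementwise minimum and negate, since optimizers minimize
    return -torch.min(surr1, surr2).mean()
```
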
3 changes: 2 additions & 1 deletion joyrl/algos/__init__.py
@@ -5,12 +5,13 @@
Email: [email protected]
Date: 2023-01-01 16:20:49
LastEditor: JiangJi
-LastEditTime: 2023-05-30 23:57:42
+LastEditTime: 2023-12-25 12:52:47
Discription:
'''
+from joyrl.algos import base,DQN,DoubleDQN,DuelingDQN,NoisyDQN,PPO
__all__ = [
"base",
"QLearning",
"DQN",
"DoubleDQN",
"DuelingDQN",
29 changes: 13 additions & 16 deletions presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml
@@ -1,20 +1,17 @@
general_cfg:
algo_name: PPO
-env_name: gym # env name, differ from env_id in env_cfgs
-device: cpu # device, cpu or cuda
-mode: train # run mode: train, test
-collect_traj: false # if collect trajectories or not
-mp_backend: single # multi-processing mode: single(default), ray
-n_workers: 2 # number of workers if using multi-processing, default 1
-load_checkpoint: false # if load checkpoint or not
-load_path: Train_single_CartPole-v1_DQN_20230515-211721 # if load checkpoint, then config path in 'tasks' dir
-load_model_step: best # load model step
-max_episode: 200 # max episodes, set -1 to keep running
-max_step: 200 # max steps per episode
-seed: 1 # random seed, set 0 not to use seed
-online_eval: true # if online eval or not
-online_eval_episode: 10 # online eval episodes
-model_save_fre: 10 # update step frequency of saving model
+env_name: gym
+device: cpu
+mode: train
+load_checkpoint: false
+load_path: Train_CartPole-v1_PPO_20231225-124842 # if load checkpoint, then config path in 'tasks' dir
+load_model_step: best
+max_episode: -1
+max_step: 200
+seed: 1
+online_eval: true
+online_eval_episode: 10
+model_save_fre: 10
algo_cfg:
ppo_type: clip
action_type: discrete
@@ -42,7 +39,7 @@ algo_cfg:
gamma: 0.99
k_epochs: 4
batch_size: 256
-sgd_batch_size: 128
+sgd_batch_size: 256
env_cfg:
id: CartPole-v1
render_mode: null
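
A minimal sketch of running this preset with the package version bumped in this commit; it assumes `joyrl.run` accepts a `yaml_path` keyword as shown in the JoyRL README examples, and that the preset path matches your local checkout.

```python
import joyrl  # pip install joyrl==0.4.7 (assumed package name/version from this commit)

if __name__ == "__main__":
    # path to the preset edited above; adjust to your local checkout
    yaml_path = "./presets/ClassControl/CartPole-v1/CartPole-v1_PPO.yaml"
    joyrl.run(yaml_path=yaml_path)
```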
