-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9744ede
commit 76a7839
Showing
8 changed files
with
34 additions
and
34 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,13 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-25 01:01:05 | ||
LastEditTime: 2023-12-25 12:52:54 | ||
Discription: | ||
''' | ||
from joyrl import algos, framework, envs, utils | ||
from joyrl.run import run | ||
|
||
__version__ = "0.4.6.4" | ||
__version__ = "0.4.7" | ||
|
||
__all__ = [ | ||
"algos", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,10 +5,10 @@ | |
Email: [email protected] | ||
Date: 2023-02-20 21:53:39 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-05-25 22:19:00 | ||
LastEditTime: 2023-12-25 12:30:31 | ||
Discription: | ||
''' | ||
class AlgoConfig: | ||
class AlgoConfig(object): | ||
def __init__(self): | ||
self.independ_actor = True # whether to use independent actor | ||
# whether actor and critic share the same optimizer | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,11 +5,11 @@ | |
Email: [email protected] | ||
Date: 2023-05-17 01:08:36 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-05-17 13:42:25 | ||
LastEditTime: 2023-12-25 12:30:18 | ||
Discription: | ||
''' | ||
import numpy as np | ||
from algos.base.data_handlers import BaseDataHandler | ||
from joyrl.algos.base.data_handler import BaseDataHandler | ||
class DataHandler(BaseDataHandler): | ||
def __init__(self, cfg): | ||
super().__init__(cfg) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 23:02:13 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-24 20:11:46 | ||
LastEditTime: 2023-12-25 12:48:39 | ||
Discription: | ||
''' | ||
import torch | ||
|
@@ -15,7 +15,6 @@ | |
from torch.distributions import Categorical,Normal | ||
import torch.utils.data as Data | ||
import numpy as np | ||
|
||
from joyrl.algos.base.network import ValueNetwork, CriticNetwork, ActorNetwork | ||
from joyrl.algos.base.policy import BasePolicy | ||
|
||
|
@@ -90,11 +89,12 @@ def get_action(self, state, mode='sample', **kwargs): | |
if self.action_type.lower() == 'continuous': | ||
self.mu, self.sigma = self.actor(state) | ||
else: | ||
self.probs = self.actor(state) | ||
if mode == 'sample': | ||
output = self.actor(state) | ||
self.probs = output['probs'] | ||
if self.cfg.mode == 'train': | ||
action = self.sample_action(**kwargs) | ||
self.update_policy_transition() | ||
elif mode == 'predict': | ||
elif self.cfg.mode == 'test': | ||
action = self.predict_action(**kwargs) | ||
else: | ||
raise NameError('mode must be sample or predict') | ||
|
@@ -105,7 +105,6 @@ def update_policy_transition(self): | |
else: | ||
self.policy_transition = {'value': self.value, 'probs': self.probs, 'log_probs': self.log_probs} | ||
def sample_action(self,**kwargs): | ||
# sample_count = kwargs.get('sample_count', 0) | ||
if self.action_type.lower() == 'continuous': | ||
mean = self.mu * self.action_scale + self.action_bias | ||
std = self.sigma | ||
|
@@ -124,7 +123,7 @@ def predict_action(self, **kwargs): | |
return self.mu.detach().cpu().numpy()[0] | ||
else: | ||
return torch.argmax(self.probs).detach().cpu().numpy().item() | ||
def train(self, **kwargs): | ||
def learn(self, **kwargs): | ||
states, actions, next_states, rewards, dones = kwargs.get('states'), kwargs.get('actions'), kwargs.get('next_states'), kwargs.get('rewards'), kwargs.get('dones') | ||
if self.action_type.lower() == 'continuous': | ||
mus, sigmas = kwargs.get('mu'), kwargs.get('sigma') | ||
|
@@ -137,12 +136,12 @@ def train(self, **kwargs): | |
old_probs = torch.exp(old_log_probs) | ||
else: | ||
old_probs, old_log_probs = kwargs.get('probs'), kwargs.get('log_probs') | ||
old_probs = torch.cat(old_probs).to(self.device) # shape:[batch_size,n_actions] | ||
old_probs = torch.stack(old_probs, dim=0).to(device=self.device, dtype=torch.float32) # shape:[batch_size,n_actions] | ||
old_log_probs = torch.tensor(old_log_probs, device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1] | ||
# convert to tensor | ||
states = torch.tensor(np.array(states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states] | ||
# actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(dim=1) # shape:[batch_size,1] | ||
actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32) # shape:[batch_size,1] | ||
actions = torch.tensor(np.array(actions), device=self.device, dtype=torch.float32).unsqueeze(1) # shape:[batch_size,1] | ||
next_states = torch.tensor(np.array(next_states), device=self.device, dtype=torch.float32) # shape:[batch_size,n_states] | ||
rewards = torch.tensor(np.array(rewards), device=self.device, dtype=torch.float32) # shape:[batch_size,1] | ||
dones = torch.tensor(np.array(dones), device=self.device, dtype=torch.float32) # shape:[batch_size,1] | ||
|
@@ -162,10 +161,11 @@ def train(self, **kwargs): | |
dist = Normal(mean, std) | ||
new_log_probs = dist.log_prob(old_actions) | ||
else: | ||
new_probs = self.actor(old_states) # shape:[batch_size,n_actions] | ||
output = self.actor(old_states) | ||
new_probs = output['probs'] # shape:[batch_size,n_actions] | ||
dist = Categorical(new_probs) | ||
# get new action probabilities | ||
new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)) # shape:[batch_size] | ||
new_log_probs = dist.log_prob(old_actions.squeeze(dim=1)).unsqueeze(dim=1) # shape:[batch_size,1] | ||
# compute ratio (pi_theta / pi_theta__old): | ||
ratio = torch.exp(new_log_probs - old_log_probs) # shape: [batch_size, 1] | ||
# compute surrogate loss | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,12 +5,13 @@ | |
Email: [email protected] | ||
Date: 2023-01-01 16:20:49 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-05-30 23:57:42 | ||
LastEditTime: 2023-12-25 12:52:47 | ||
Discription: | ||
''' | ||
from joyrl.algos import base,DQN,DoubleDQN,DuelingDQN,NoisyDQN,PPO | ||
__all__ = [ | ||
"base", | ||
"QLearning", | ||
"DQN", | ||
"DoubleDQN", | ||
"DuelingDQN", | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters