Skip to content

Commit

Permalink
[dev4] 0.4.8
Browse files Browse the repository at this point in the history
  • Loading branch information
johnjim0816 committed Jan 6, 2024
1 parent ed60ec7 commit b56bbb6
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 41 deletions.
21 changes: 11 additions & 10 deletions joyrl/framework/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
LastEditTime: 2024-01-04 22:59:23
LastEditTime: 2024-01-06 22:51:32
Discription:
'''
import ray
Expand All @@ -17,6 +17,7 @@
from joyrl.framework.config import MergedConfig
from joyrl.algos.base.data_handler import BaseDataHandler
from joyrl.framework.base import Moduler
from joyrl.utils.utils import exec_method

class Collector(Moduler):
''' Collector for collecting training data
Expand Down Expand Up @@ -48,10 +49,12 @@ def pub_msg(self, msg: Msg):
except Full:
self.logger.warning("[Collector.pub_msg] raw_exps_que is full!")
elif msg_type == MsgType.COLLECTOR_GET_TRAINING_DATA:
try:
return self._training_data_que.get(block = False)
except:
return None
return self._get_training_data()
# try:
# return self._training_data_que.get(block = False)
# except:
# # exec_method(self.logger, 'warning', True, "[Collector.pub_msg] training_data_que is empty!")
# return None
elif msg_type == MsgType.COLLECTOR_GET_BUFFER_LENGTH:
return self.get_buffer_length()
else:
Expand All @@ -69,15 +72,13 @@ def _prepare_training_data(self):
'''
while True:
training_data = self.data_handler.sample_training_data()
if training_data is None:
continue
else:
if training_data is not None:
try:
self._training_data_que.put(training_data, block = False)
except Full:
# self.logger.warning("[Collector._sample_training_data] training_data_que is full!")
# exec_method(self.logger, 'warning', True, "[Collector._prepare_training_data] training_data_que is full!")
pass

def _get_training_data(self):
training_data = self.data_handler.sample_training_data() # sample training data
return training_data
Expand Down
16 changes: 11 additions & 5 deletions joyrl/framework/interactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,18 @@
from joyrl.framework.recorder import Recorder
from joyrl.utils.utils import exec_method, create_module

class Interactor(Moduler):
class Interactor:
def __init__(self, cfg: MergedConfig, **kwargs) -> None:
super().__init__(cfg, **kwargs)
self.cfg = cfg
self.id = kwargs.get('id', 0)
self.env = kwargs.get('env', None)
self.policy = kwargs.get('policy', None)
self.tracker = kwargs['tracker']
self.collector = kwargs['collector']
self.recorder = kwargs['recorder']
self.policy_mgr = kwargs['policy_mgr']
self.logger = kwargs['logger']
self.use_ray = kwargs['use_ray']
self.seed = self.cfg.seed + self.id
self.exps = [] # reset experiences
self.summary = [] # reset summary
Expand All @@ -39,9 +41,9 @@ def run(self):
''' run in sync mode
'''
run_step = 0 # local run step
model_params = exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params
self.policy.put_model_params(model_params)
while True:
model_params = exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params
self.policy.put_model_params(model_params)
action = self.policy.get_action(self.curr_obs)
obs, reward, terminated, truncated, info = self.env.step(action)
interact_transition = {'interactor_id': self.id, 'state': self.curr_obs, 'action': action,'reward': reward, 'next_state': obs, 'done': terminated or truncated, 'info': info}
Expand Down Expand Up @@ -86,6 +88,8 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None:
collector = kwargs.get('collector', None),
recorder = self.recorder,
policy_mgr = kwargs.get('policy_mgr', None),
logger = self.logger,
use_ray = self.use_ray,
) for i in range(self.n_interactors)
]
exec_method(self.logger, 'info', True, f"[InteractorMgr] Create {self.n_interactors} interactors!")
Expand All @@ -94,7 +98,9 @@ def run(self):
''' run interactors
'''
for i in range(self.n_interactors):
exec_method(self.interactors[i], 'run', False)
self.interactors[i].run.remote()
# for i in range(self.n_interactors):
# exec_method(self.interactors[i], 'run', False)

def ray_run(self):
self.logger.info.remote(f"[InteractorMgr.run] Start interactors!")
Expand Down
43 changes: 26 additions & 17 deletions joyrl/framework/learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2024-01-04 23:48:37
LastEditTime: 2024-01-06 22:54:09
Discription:
'''
import ray
Expand All @@ -18,17 +18,19 @@
from joyrl.framework.recorder import Recorder
from joyrl.utils.utils import exec_method, create_module

class Learner(Moduler):
class Learner:
''' learner
'''
def __init__(self, cfg : MergedConfig, **kwargs) -> None:
super().__init__(cfg,**kwargs)
self.cfg = cfg
self.id = kwargs.get('id', 0)
self.policy = kwargs.get('policy', None)
self.policy_mgr = kwargs.get('policy_mgr', None)
self.collector = kwargs.get('collector', None)
self.tracker = kwargs.get('tracker', None)
self.recorder = kwargs.get('recorder', None)
self.logger = kwargs['logger']
self.use_ray = kwargs['use_ray']
self._init_update_steps()

def _init_update_steps(self):
Expand All @@ -40,19 +42,21 @@ def _init_update_steps(self):
def run(self):
run_step = 0
while True:
s_t = time.time()
training_data = exec_method(self.collector, 'pub_msg', True, Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA))
if training_data is None: return
self.policy.learn(**training_data)
global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP))
exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP))
# put updated model params to policy_mgr
model_params = self.policy.get_model_params()
exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params)))
# put policy summary to recorder
if global_update_step % self.cfg.policy_summary_fre == 0:
policy_summary = [(global_update_step,self.policy.get_summary())]
exec_method(self.recorder, 'pub_msg', True, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary))
run_step += 1
if training_data is not None:
self.policy.learn(**training_data)
global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP))
exec_method(self.tracker, 'pub_msg', False, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP))
# put updated model params to policy_mgr
model_params = self.policy.get_model_params()
exec_method(self.policy_mgr, 'pub_msg', False, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params)))
# put policy summary to recorder
if global_update_step % self.cfg.policy_summary_fre == 0:
policy_summary = [(global_update_step,self.policy.get_summary())]
exec_method(self.recorder, 'pub_msg', False, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary))
run_step += 1
# exec_method(self.logger, 'info', True, f"Learner {self.id} finished {run_step} update steps in {time.time() - s_t:.4f}s!")
if run_step >= self.n_update_steps:
return

Expand All @@ -70,9 +74,14 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None:
policy_mgr = kwargs.get('policy_mgr', None),
collector = kwargs.get('collector', None),
tracker = kwargs.get('tracker', None),
recorder = self.recorder)
recorder = self.recorder,
logger = self.logger,
use_ray = self.use_ray,
)
for i in range(self.cfg.n_learners)]
exec_method(self.logger, 'info', True, f"[LearnerMgr] Create {self.cfg.n_learners} learners!")
def run(self):
for i in range(self.cfg.n_learners):
exec_method(self.learners[i], 'run', False)
self.learners[i].run.remote()
# for i in range(self.cfg.n_learners):
# exec_method(self.learners[i], 'run', False)
4 changes: 2 additions & 2 deletions joyrl/framework/policy_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
LastEditTime: 2024-01-04 23:52:45
LastEditTime: 2024-01-06 22:48:05
Discription:
'''
import time
Expand Down Expand Up @@ -65,7 +65,7 @@ def _put_model_params(self, msg_data):
self.logger.warning.remote(f"[PolicyMgr._put_model_params] saved_model_que is full!")
else:
self.logger.warning(f"[PolicyMgr._put_model_params] saved_model_que is full!")
time.sleep(0.001)
# time.sleep(0.001)

def _get_model_params(self):
''' get policy
Expand Down
4 changes: 2 additions & 2 deletions joyrl/framework/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2024-01-04 23:52:09
LastEditTime: 2024-01-06 22:39:42
Discription:
'''
import time
Expand All @@ -32,7 +32,6 @@ def _print_cfgs(self):
'''
def print_cfg(cfg, name = ''):
cfg_dict = vars(cfg)
exec_method(self.logger, 'info', True, ''.join(['='] * 80))
exec_method(self.logger, 'info', True, f"{name}:")
exec_method(self.logger, 'info', True, ''.join(['='] * 80))
tplt = "{:^20}\t{:^20}\t{:^20}"
Expand All @@ -59,6 +58,7 @@ def run(self):
exec_method(self.interactor_mgr, 'run', False)
exec_method(self.learner_mgr, 'run', False)
while True:
time.sleep(0.1)
if exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_CHECK_TASK_END)):
e_t = time.time()
exec_method(self.logger, 'info', True, f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s")
Expand Down
2 changes: 1 addition & 1 deletion joyrl/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 13:16:59
LastEditor: JiangJi
LastEditTime: 2023-12-24 19:00:32
LastEditTime: 2024-01-06 22:01:23
Discription:
'''
import sys,os
Expand Down
6 changes: 3 additions & 3 deletions offline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 13:16:59
LastEditor: JiangJi
LastEditTime: 2024-01-04 23:02:39
LastEditTime: 2024-01-06 22:02:03
Discription:
'''
import sys,os
Expand Down Expand Up @@ -172,8 +172,8 @@ def run(self) -> None:
if self.cfg.n_interactors > 1:
is_remote = True
ray.init()
if self.cfg.online_eval:
online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy)
# if self.cfg.online_eval:
# online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy)
tracker = create_module(Tracker, is_remote, {'num_cpus':0}, self.cfg)
collector = create_module(Collector, is_remote, {'num_cpus':1}, self.cfg, data_handler = data_handler)
policy_mgr = create_module(PolicyMgr, is_remote, {'num_cpus':0}, self.cfg, policy = policy)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,5 @@ pygame==2.1.2
glfw==2.5.5
imageio==2.22.4
tensorboard==2.11.2
ray==2.6.3
ray[default]==2.6.3
gymnasium==0.28.1

0 comments on commit b56bbb6

Please sign in to comment.