diff --git a/joyrl/framework/collector.py b/joyrl/framework/collector.py index 2b8c8b5..4fde246 100644 --- a/joyrl/framework/collector.py +++ b/joyrl/framework/collector.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 23:02:13 LastEditor: JiangJi -LastEditTime: 2024-01-04 22:59:23 +LastEditTime: 2024-01-06 22:51:32 Discription: ''' import ray @@ -17,6 +17,7 @@ from joyrl.framework.config import MergedConfig from joyrl.algos.base.data_handler import BaseDataHandler from joyrl.framework.base import Moduler +from joyrl.utils.utils import exec_method class Collector(Moduler): ''' Collector for collecting training data @@ -48,10 +49,12 @@ def pub_msg(self, msg: Msg): except Full: self.logger.warning("[Collector.pub_msg] raw_exps_que is full!") elif msg_type == MsgType.COLLECTOR_GET_TRAINING_DATA: - try: - return self._training_data_que.get(block = False) - except: - return None + return self._get_training_data() + # try: + # return self._training_data_que.get(block = False) + # except: + # # exec_method(self.logger, 'warning', True, "[Collector.pub_msg] training_data_que is empty!") + # return None elif msg_type == MsgType.COLLECTOR_GET_BUFFER_LENGTH: return self.get_buffer_length() else: @@ -69,15 +72,13 @@ def _prepare_training_data(self): ''' while True: training_data = self.data_handler.sample_training_data() - if training_data is None: - continue - else: + if training_data is not None: try: self._training_data_que.put(training_data, block = False) except Full: - # self.logger.warning("[Collector._sample_training_data] training_data_que is full!") + # exec_method(self.logger, 'warning', True, "[Collector._prepare_training_data] training_data_que is full!") pass - + def _get_training_data(self): training_data = self.data_handler.sample_training_data() # sample training data return training_data diff --git a/joyrl/framework/interactor.py b/joyrl/framework/interactor.py index e348043..5516ebf 100644 --- a/joyrl/framework/interactor.py +++ b/joyrl/framework/interactor.py @@ -8,9 +8,9 @@ from joyrl.framework.recorder import Recorder from joyrl.utils.utils import exec_method, create_module -class Interactor(Moduler): +class Interactor: def __init__(self, cfg: MergedConfig, **kwargs) -> None: - super().__init__(cfg, **kwargs) + self.cfg = cfg self.id = kwargs.get('id', 0) self.env = kwargs.get('env', None) self.policy = kwargs.get('policy', None) @@ -18,6 +18,8 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None: self.collector = kwargs['collector'] self.recorder = kwargs['recorder'] self.policy_mgr = kwargs['policy_mgr'] + self.logger = kwargs['logger'] + self.use_ray = kwargs['use_ray'] self.seed = self.cfg.seed + self.id self.exps = [] # reset experiences self.summary = [] # reset summary @@ -39,9 +41,9 @@ def run(self): ''' run in sync mode ''' run_step = 0 # local run step - model_params = exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params - self.policy.put_model_params(model_params) while True: + model_params = exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_GET_MODEL_PARAMS)) # get model params + self.policy.put_model_params(model_params) action = self.policy.get_action(self.curr_obs) obs, reward, terminated, truncated, info = self.env.step(action) interact_transition = {'interactor_id': self.id, 'state': self.curr_obs, 'action': action,'reward': reward, 'next_state': obs, 'done': terminated or truncated, 'info': info} @@ -86,6 +88,8 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None: collector = kwargs.get('collector', None), recorder = self.recorder, policy_mgr = kwargs.get('policy_mgr', None), + logger = self.logger, + use_ray = self.use_ray, ) for i in range(self.n_interactors) ] exec_method(self.logger, 'info', True, f"[InteractorMgr] Create {self.n_interactors} interactors!") @@ -94,7 +98,9 @@ def run(self): ''' run interactors ''' for i in range(self.n_interactors): - exec_method(self.interactors[i], 'run', False) + self.interactors[i].run.remote() + # for i in range(self.n_interactors): + # exec_method(self.interactors[i], 'run', False) def ray_run(self): self.logger.info.remote(f"[InteractorMgr.run] Start interactors!") diff --git a/joyrl/framework/learner.py b/joyrl/framework/learner.py index 080a193..6694667 100644 --- a/joyrl/framework/learner.py +++ b/joyrl/framework/learner.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-02 15:02:30 LastEditor: JiangJi -LastEditTime: 2024-01-04 23:48:37 +LastEditTime: 2024-01-06 22:54:09 Discription: ''' import ray @@ -18,17 +18,19 @@ from joyrl.framework.recorder import Recorder from joyrl.utils.utils import exec_method, create_module -class Learner(Moduler): +class Learner: ''' learner ''' def __init__(self, cfg : MergedConfig, **kwargs) -> None: - super().__init__(cfg,**kwargs) + self.cfg = cfg self.id = kwargs.get('id', 0) self.policy = kwargs.get('policy', None) self.policy_mgr = kwargs.get('policy_mgr', None) self.collector = kwargs.get('collector', None) self.tracker = kwargs.get('tracker', None) self.recorder = kwargs.get('recorder', None) + self.logger = kwargs['logger'] + self.use_ray = kwargs['use_ray'] self._init_update_steps() def _init_update_steps(self): @@ -40,19 +42,21 @@ def _init_update_steps(self): def run(self): run_step = 0 while True: + s_t = time.time() training_data = exec_method(self.collector, 'pub_msg', True, Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA)) - if training_data is None: return - self.policy.learn(**training_data) - global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)) - exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP)) - # put updated model params to policy_mgr - model_params = self.policy.get_model_params() - exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params))) - # put policy summary to recorder - if global_update_step % self.cfg.policy_summary_fre == 0: - policy_summary = [(global_update_step,self.policy.get_summary())] - exec_method(self.recorder, 'pub_msg', True, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary)) - run_step += 1 + if training_data is not None: + self.policy.learn(**training_data) + global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)) + exec_method(self.tracker, 'pub_msg', False, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP)) + # put updated model params to policy_mgr + model_params = self.policy.get_model_params() + exec_method(self.policy_mgr, 'pub_msg', False, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params))) + # put policy summary to recorder + if global_update_step % self.cfg.policy_summary_fre == 0: + policy_summary = [(global_update_step,self.policy.get_summary())] + exec_method(self.recorder, 'pub_msg', False, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary)) + run_step += 1 + # exec_method(self.logger, 'info', True, f"Learner {self.id} finished {run_step} update steps in {time.time() - s_t:.4f}s!") if run_step >= self.n_update_steps: return @@ -70,9 +74,14 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None: policy_mgr = kwargs.get('policy_mgr', None), collector = kwargs.get('collector', None), tracker = kwargs.get('tracker', None), - recorder = self.recorder) + recorder = self.recorder, + logger = self.logger, + use_ray = self.use_ray, + ) for i in range(self.cfg.n_learners)] exec_method(self.logger, 'info', True, f"[LearnerMgr] Create {self.cfg.n_learners} learners!") def run(self): for i in range(self.cfg.n_learners): - exec_method(self.learners[i], 'run', False) + self.learners[i].run.remote() + # for i in range(self.cfg.n_learners): + # exec_method(self.learners[i], 'run', False) diff --git a/joyrl/framework/policy_mgr.py b/joyrl/framework/policy_mgr.py index 26ff4de..2c4536c 100644 --- a/joyrl/framework/policy_mgr.py +++ b/joyrl/framework/policy_mgr.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 23:02:13 LastEditor: JiangJi -LastEditTime: 2024-01-04 23:52:45 +LastEditTime: 2024-01-06 22:48:05 Discription: ''' import time @@ -65,7 +65,7 @@ def _put_model_params(self, msg_data): self.logger.warning.remote(f"[PolicyMgr._put_model_params] saved_model_que is full!") else: self.logger.warning(f"[PolicyMgr._put_model_params] saved_model_que is full!") - time.sleep(0.001) + # time.sleep(0.001) def _get_model_params(self): ''' get policy diff --git a/joyrl/framework/trainer.py b/joyrl/framework/trainer.py index a03335c..dd08d57 100644 --- a/joyrl/framework/trainer.py +++ b/joyrl/framework/trainer.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-02 15:02:30 LastEditor: JiangJi -LastEditTime: 2024-01-04 23:52:09 +LastEditTime: 2024-01-06 22:39:42 Discription: ''' import time @@ -32,7 +32,6 @@ def _print_cfgs(self): ''' def print_cfg(cfg, name = ''): cfg_dict = vars(cfg) - exec_method(self.logger, 'info', True, ''.join(['='] * 80)) exec_method(self.logger, 'info', True, f"{name}:") exec_method(self.logger, 'info', True, ''.join(['='] * 80)) tplt = "{:^20}\t{:^20}\t{:^20}" @@ -59,6 +58,7 @@ def run(self): exec_method(self.interactor_mgr, 'run', False) exec_method(self.learner_mgr, 'run', False) while True: + time.sleep(0.1) if exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_CHECK_TASK_END)): e_t = time.time() exec_method(self.logger, 'info', True, f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s") diff --git a/joyrl/run.py b/joyrl/run.py index 3eb3095..cb6cc06 100644 --- a/joyrl/run.py +++ b/joyrl/run.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 13:16:59 LastEditor: JiangJi -LastEditTime: 2023-12-24 19:00:32 +LastEditTime: 2024-01-06 22:01:23 Discription: ''' import sys,os diff --git a/offline_run.py b/offline_run.py index 228209e..df15918 100644 --- a/offline_run.py +++ b/offline_run.py @@ -5,7 +5,7 @@ Email: johnjim0816@gmail.com Date: 2023-12-22 13:16:59 LastEditor: JiangJi -LastEditTime: 2024-01-04 23:02:39 +LastEditTime: 2024-01-06 22:02:03 Discription: ''' import sys,os @@ -172,8 +172,8 @@ def run(self) -> None: if self.cfg.n_interactors > 1: is_remote = True ray.init() - if self.cfg.online_eval: - online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy) + # if self.cfg.online_eval: + # online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy) tracker = create_module(Tracker, is_remote, {'num_cpus':0}, self.cfg) collector = create_module(Collector, is_remote, {'num_cpus':1}, self.cfg, data_handler = data_handler) policy_mgr = create_module(PolicyMgr, is_remote, {'num_cpus':0}, self.cfg, policy = policy) diff --git a/requirements.txt b/requirements.txt index 39cc524..d1fceea 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,5 +12,5 @@ pygame==2.1.2 glfw==2.5.5 imageio==2.22.4 tensorboard==2.11.2 -ray==2.6.3 +ray[default]==2.6.3 gymnasium==0.28.1 \ No newline at end of file