-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
ed60ec7
commit b56bbb6
Showing
8 changed files
with
57 additions
and
41 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 23:02:13 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-01-04 22:59:23 | ||
LastEditTime: 2024-01-06 22:51:32 | ||
Discription: | ||
''' | ||
import ray | ||
|
@@ -17,6 +17,7 @@ | |
from joyrl.framework.config import MergedConfig | ||
from joyrl.algos.base.data_handler import BaseDataHandler | ||
from joyrl.framework.base import Moduler | ||
from joyrl.utils.utils import exec_method | ||
|
||
class Collector(Moduler): | ||
''' Collector for collecting training data | ||
|
@@ -48,10 +49,12 @@ def pub_msg(self, msg: Msg): | |
except Full: | ||
self.logger.warning("[Collector.pub_msg] raw_exps_que is full!") | ||
elif msg_type == MsgType.COLLECTOR_GET_TRAINING_DATA: | ||
try: | ||
return self._training_data_que.get(block = False) | ||
except: | ||
return None | ||
return self._get_training_data() | ||
# try: | ||
# return self._training_data_que.get(block = False) | ||
# except: | ||
# # exec_method(self.logger, 'warning', True, "[Collector.pub_msg] training_data_que is empty!") | ||
# return None | ||
elif msg_type == MsgType.COLLECTOR_GET_BUFFER_LENGTH: | ||
return self.get_buffer_length() | ||
else: | ||
|
@@ -69,15 +72,13 @@ def _prepare_training_data(self): | |
''' | ||
while True: | ||
training_data = self.data_handler.sample_training_data() | ||
if training_data is None: | ||
continue | ||
else: | ||
if training_data is not None: | ||
try: | ||
self._training_data_que.put(training_data, block = False) | ||
except Full: | ||
# self.logger.warning("[Collector._sample_training_data] training_data_que is full!") | ||
# exec_method(self.logger, 'warning', True, "[Collector._prepare_training_data] training_data_que is full!") | ||
pass | ||
|
||
def _get_training_data(self): | ||
training_data = self.data_handler.sample_training_data() # sample training data | ||
return training_data | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-02 15:02:30 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-01-04 23:48:37 | ||
LastEditTime: 2024-01-06 22:54:09 | ||
Discription: | ||
''' | ||
import ray | ||
|
@@ -18,17 +18,19 @@ | |
from joyrl.framework.recorder import Recorder | ||
from joyrl.utils.utils import exec_method, create_module | ||
|
||
class Learner(Moduler): | ||
class Learner: | ||
''' learner | ||
''' | ||
def __init__(self, cfg : MergedConfig, **kwargs) -> None: | ||
super().__init__(cfg,**kwargs) | ||
self.cfg = cfg | ||
self.id = kwargs.get('id', 0) | ||
self.policy = kwargs.get('policy', None) | ||
self.policy_mgr = kwargs.get('policy_mgr', None) | ||
self.collector = kwargs.get('collector', None) | ||
self.tracker = kwargs.get('tracker', None) | ||
self.recorder = kwargs.get('recorder', None) | ||
self.logger = kwargs['logger'] | ||
self.use_ray = kwargs['use_ray'] | ||
self._init_update_steps() | ||
|
||
def _init_update_steps(self): | ||
|
@@ -40,19 +42,21 @@ def _init_update_steps(self): | |
def run(self): | ||
run_step = 0 | ||
while True: | ||
s_t = time.time() | ||
training_data = exec_method(self.collector, 'pub_msg', True, Msg(type = MsgType.COLLECTOR_GET_TRAINING_DATA)) | ||
if training_data is None: return | ||
self.policy.learn(**training_data) | ||
global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)) | ||
exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP)) | ||
# put updated model params to policy_mgr | ||
model_params = self.policy.get_model_params() | ||
exec_method(self.policy_mgr, 'pub_msg', True, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params))) | ||
# put policy summary to recorder | ||
if global_update_step % self.cfg.policy_summary_fre == 0: | ||
policy_summary = [(global_update_step,self.policy.get_summary())] | ||
exec_method(self.recorder, 'pub_msg', True, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary)) | ||
run_step += 1 | ||
if training_data is not None: | ||
self.policy.learn(**training_data) | ||
global_update_step = exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_GET_UPDATE_STEP)) | ||
exec_method(self.tracker, 'pub_msg', False, Msg(type = MsgType.TRACKER_INCREASE_UPDATE_STEP)) | ||
# put updated model params to policy_mgr | ||
model_params = self.policy.get_model_params() | ||
exec_method(self.policy_mgr, 'pub_msg', False, Msg(type = MsgType.MODEL_MGR_PUT_MODEL_PARAMS, data = (global_update_step, model_params))) | ||
# put policy summary to recorder | ||
if global_update_step % self.cfg.policy_summary_fre == 0: | ||
policy_summary = [(global_update_step,self.policy.get_summary())] | ||
exec_method(self.recorder, 'pub_msg', False, Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = policy_summary)) | ||
run_step += 1 | ||
# exec_method(self.logger, 'info', True, f"Learner {self.id} finished {run_step} update steps in {time.time() - s_t:.4f}s!") | ||
if run_step >= self.n_update_steps: | ||
return | ||
|
||
|
@@ -70,9 +74,14 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None: | |
policy_mgr = kwargs.get('policy_mgr', None), | ||
collector = kwargs.get('collector', None), | ||
tracker = kwargs.get('tracker', None), | ||
recorder = self.recorder) | ||
recorder = self.recorder, | ||
logger = self.logger, | ||
use_ray = self.use_ray, | ||
) | ||
for i in range(self.cfg.n_learners)] | ||
exec_method(self.logger, 'info', True, f"[LearnerMgr] Create {self.cfg.n_learners} learners!") | ||
def run(self): | ||
for i in range(self.cfg.n_learners): | ||
exec_method(self.learners[i], 'run', False) | ||
self.learners[i].run.remote() | ||
# for i in range(self.cfg.n_learners): | ||
# exec_method(self.learners[i], 'run', False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 23:02:13 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-01-04 23:52:45 | ||
LastEditTime: 2024-01-06 22:48:05 | ||
Discription: | ||
''' | ||
import time | ||
|
@@ -65,7 +65,7 @@ def _put_model_params(self, msg_data): | |
self.logger.warning.remote(f"[PolicyMgr._put_model_params] saved_model_que is full!") | ||
else: | ||
self.logger.warning(f"[PolicyMgr._put_model_params] saved_model_que is full!") | ||
time.sleep(0.001) | ||
# time.sleep(0.001) | ||
|
||
def _get_model_params(self): | ||
''' get policy | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-02 15:02:30 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-01-04 23:52:09 | ||
LastEditTime: 2024-01-06 22:39:42 | ||
Discription: | ||
''' | ||
import time | ||
|
@@ -32,7 +32,6 @@ def _print_cfgs(self): | |
''' | ||
def print_cfg(cfg, name = ''): | ||
cfg_dict = vars(cfg) | ||
exec_method(self.logger, 'info', True, ''.join(['='] * 80)) | ||
exec_method(self.logger, 'info', True, f"{name}:") | ||
exec_method(self.logger, 'info', True, ''.join(['='] * 80)) | ||
tplt = "{:^20}\t{:^20}\t{:^20}" | ||
|
@@ -59,6 +58,7 @@ def run(self): | |
exec_method(self.interactor_mgr, 'run', False) | ||
exec_method(self.learner_mgr, 'run', False) | ||
while True: | ||
time.sleep(0.1) | ||
if exec_method(self.tracker, 'pub_msg', True, Msg(type = MsgType.TRACKER_CHECK_TASK_END)): | ||
e_t = time.time() | ||
exec_method(self.logger, 'info', True, f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s") | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 13:16:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2023-12-24 19:00:32 | ||
LastEditTime: 2024-01-06 22:01:23 | ||
Discription: | ||
''' | ||
import sys,os | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ | |
Email: [email protected] | ||
Date: 2023-12-22 13:16:59 | ||
LastEditor: JiangJi | ||
LastEditTime: 2024-01-04 23:02:39 | ||
LastEditTime: 2024-01-06 22:02:03 | ||
Discription: | ||
''' | ||
import sys,os | ||
|
@@ -172,8 +172,8 @@ def run(self) -> None: | |
if self.cfg.n_interactors > 1: | ||
is_remote = True | ||
ray.init() | ||
if self.cfg.online_eval: | ||
online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy) | ||
# if self.cfg.online_eval: | ||
# online_tester = create_module(OnlineTester, False, {'num_cpus':0}, self.cfg, env = env, policy = policy) | ||
tracker = create_module(Tracker, is_remote, {'num_cpus':0}, self.cfg) | ||
collector = create_module(Collector, is_remote, {'num_cpus':1}, self.cfg, data_handler = data_handler) | ||
policy_mgr = create_module(PolicyMgr, is_remote, {'num_cpus':0}, self.cfg, policy = policy) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,5 +12,5 @@ pygame==2.1.2 | |
glfw==2.5.5 | ||
imageio==2.22.4 | ||
tensorboard==2.11.2 | ||
ray==2.6.3 | ||
ray[default]==2.6.3 | ||
gymnasium==0.28.1 |