
Commit

[feat_dev] add onlinetest data to TB
johnjim0816 committed May 31, 2024
1 parent 4f0a05c commit a91caee
Showing 6 changed files with 24 additions and 16 deletions.
3 changes: 1 addition & 2 deletions joyrl/algos/PPO/data_handler.py
@@ -5,11 +5,10 @@
Email: [email protected]
Date: 2023-05-17 01:08:36
LastEditor: JiangJi
LastEditTime: 2024-05-27 13:56:10
LastEditTime: 2024-05-31 11:19:41
Discription:
'''
import numpy as np
import scipy
import torch
from joyrl.algos.base.data_handler import BaseDataHandler
class DataHandler(BaseDataHandler):
2 changes: 2 additions & 0 deletions joyrl/algos/base/buffer.py
@@ -116,8 +116,10 @@ class OnPolicyBufferQue(ReplayBufferQue):
def __init__(self, cfg: MergedConfig):
self.cfg = cfg
self.buffer = deque(maxlen=10)

def push(self, exps: list):
self.buffer.append(exps)

def sample(self,**kwargs):
''' sample all the transitions
'''
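The hunk above only adds blank lines around OnPolicyBufferQue.push, but it shows the buffer's shape: a bounded deque holding lists of experiences. A minimal standalone sketch of that structure follows; the clear-after-sample behaviour is an assumption, not something shown in this diff.

    from collections import deque

    class MiniOnPolicyBuffer:
        """Sketch of the OnPolicyBufferQue shape shown above."""
        def __init__(self, maxlen: int = 10):
            self.buffer = deque(maxlen=maxlen)

        def push(self, exps: list):
            # each push stores one batch of transitions collected by an interactor
            self.buffer.append(exps)

        def sample(self, **kwargs):
            ''' sample all the transitions '''
            batch = [exp for exps in self.buffer for exp in exps]
            self.buffer.clear()  # assumption: on-policy data is consumed once; not shown in this hunk
            return batch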
7 changes: 4 additions & 3 deletions joyrl/framework/collector.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-22 23:02:13
LastEditor: JiangJi
LastEditTime: 2024-05-30 17:42:50
LastEditTime: 2024-05-31 11:21:07
Discription:
'''
import ray
@@ -25,7 +25,7 @@ def __init__(self, cfg: MergedConfig, **kwargs) -> None:
super().__init__(cfg, **kwargs)
self.data_handler = kwargs['data_handler']
self._raw_exps_que = kwargs['raw_exps_que']
self._training_data_que = Queue(maxsize = 2)
self._training_data_que = Queue(maxsize = 1)
self._t_start()

def _t_start(self):
@@ -100,7 +100,8 @@ def _prepare_training_data(self):
s_t = time.time()
consumed_exp_len = 0
except Full:
exec_method(self.logger, 'warning', 'get', "[Collector._prepare_training_data] training_data_que is full!")
pass
# exec_method(self.logger, 'warning', 'get', "[Collector._prepare_training_data] training_data_que is full!")
# time.sleep(0.002)

def _get_training_data(self):
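Two related changes in this file: the training-data queue shrinks to a single slot, and a full queue is now skipped silently instead of logged. A minimal sketch of that pattern, using the standard-library queue.Queue (the Queue used by joyrl may be a different implementation, and next_batch is a hypothetical helper):

    import queue

    def feed_training_data(training_data_que: queue.Queue, next_batch):
        """Producer loop: keep the single-slot queue topped up with training batches."""
        while True:
            batch = next_batch()  # hypothetical helper that assembles one training batch
            try:
                training_data_que.put(batch, block=True, timeout=0.1)
            except queue.Full:
                # the learner has not consumed the previous batch yet;
                # skip quietly instead of logging a warning, as in the hunk above
                pass

With maxsize=1, at most one prepared batch can be pending for the learner at any time.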
6 changes: 3 additions & 3 deletions joyrl/framework/interactor.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2024-02-25 15:46:04
LastEditor: JiangJi
LastEditTime: 2024-05-30 17:55:42
LastEditTime: 2024-05-31 11:33:13
Discription:
'''
import gymnasium as gym
@@ -58,7 +58,7 @@ def _put_exps(self):
self._raw_exps_que.put(self.exps, block=True, timeout=0.1)
break
except:
exec_method(self.logger, 'warning', 'get', "[Interactor._put_exps] raw_exps_que is full!")
# exec_method(self.logger, 'warning', 'get', "[Interactor._put_exps] raw_exps_que is full!")
time.sleep(0.1)
self.exps = []

@@ -75,7 +75,7 @@ def run(self):
''' run in sync mode
'''
run_step = 0 # local run step
while True:
while True:
action = self.policy.get_action(self.curr_obs)
obs, reward, terminated, truncated, info = self.env.step(action)
interact_transition = {'interactor_id': self.id, 'state': self.curr_obs, 'action': action,'reward': reward, 'next_state': obs, 'done': terminated or truncated, 'info': info}
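Unlike the collector, _put_exps above keeps retrying until the shared queue accepts the experiences; only the warning log is commented out. A standalone sketch of that retry loop (names and signature are illustrative, not the actual joyrl API):

    import time

    def put_exps(raw_exps_que, exps, timeout=0.1, backoff=0.1):
        """Block until the shared queue accepts the collected experiences."""
        while True:
            try:
                raw_exps_que.put(exps, block=True, timeout=timeout)
                break
            except Exception:
                # queue is full (or the put timed out); back off briefly and retry
                time.sleep(backoff)
        return []  # the caller then resets its local list, as in self.exps = []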
5 changes: 4 additions & 1 deletion joyrl/framework/tester.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2024-05-30 17:48:29
LastEditTime: 2024-05-31 11:29:45
Discription:
'''
import time
@@ -14,6 +14,7 @@
import threading
from joyrl.framework.config import MergedConfig
from joyrl.framework.base import Moduler
from joyrl.framework.message import Msg, MsgType
from joyrl.utils.utils import exec_method

class OnlineTester(Moduler):
@@ -23,6 +24,7 @@ def __init__(self, cfg : MergedConfig, *args, **kwargs) -> None:
super().__init__(cfg, *args, **kwargs)
self.env = copy.deepcopy(kwargs['env'])
self.policy = copy.deepcopy(kwargs['policy'])
self.recorder = kwargs['recorder']
self.seed = self.cfg.seed
self.best_eval_reward = -float('inf')
self.curr_test_step = -1
@@ -73,6 +75,7 @@ def _eval_policy(self):
pass
mean_eval_reward = sum_eval_reward / self.cfg.online_eval_episode
exec_method(self.logger, 'info', 'get', f"online_eval step: {self.curr_test_step}, online_eval_reward: {mean_eval_reward:.3f}")
exec_method(self.recorder, 'pub_msg', 'remote', Msg(type = MsgType.RECORDER_PUT_SUMMARY, data = [(model_step, {'online_eval_reward': mean_eval_reward})])) # put summary to stats recorder
# logger_info = f"test_step: {self.curr_test_step}, online_eval_reward: {mean_eval_reward:.3f}"
# self.logger.info.remote(logger_info) if self.use_ray else self.logger.info(logger_info)
if mean_eval_reward >= self.best_eval_reward:
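This is the core of the commit: after each online evaluation, the tester publishes (model_step, {'online_eval_reward': ...}) to a stats recorder, which is what ends up in TensorBoard ("TB" in the commit title). The joyrl Recorder itself is not part of this diff; the following is only a sketch of how such a summary could be written with torch's SummaryWriter, with TinySummaryRecorder as a hypothetical stand-in:

    from torch.utils.tensorboard import SummaryWriter

    class TinySummaryRecorder:
        """Hypothetical stand-in for the Recorder that receives RECORDER_PUT_SUMMARY."""
        def __init__(self, log_dir: str = "./tb_logs/online_tester"):
            self.writer = SummaryWriter(log_dir=log_dir)

        def put_summary(self, summaries):
            # summaries is a list of (step, {tag: scalar}) pairs, e.g.
            # [(model_step, {'online_eval_reward': mean_eval_reward})]
            for step, scalars in summaries:
                for tag, value in scalars.items():
                    self.writer.add_scalar(tag, value, global_step=step)
            self.writer.flush()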
17 changes: 10 additions & 7 deletions joyrl/framework/trainer.py
@@ -5,7 +5,7 @@
Email: [email protected]
Date: 2023-12-02 15:02:30
LastEditor: JiangJi
LastEditTime: 2024-05-30 18:02:34
LastEditTime: 2024-05-31 11:27:31
Discription:
'''
import copy
@@ -34,6 +34,7 @@ def __init__(self, cfg: MergedConfig,**kwargs) -> None:
self.env = kwargs['env']
self.policy = kwargs['policy']
self.data_handler = kwargs['data_handler']
# self.cfg.on_policy = False
self._print_cfgs() # print parameters
self._create_shared_data() # create data queues
self._create_modules() # create modules
@@ -46,11 +47,15 @@ def _create_modules(self):
''' create modules
'''
if self.cfg.online_eval:
recorder = ray.remote(Recorder).options(**{'num_cpus': 0}).remote(self.cfg,
name = 'RecorderOnlineTester',
type = 'online_tester')
self.online_tester = OnlineTester(
self.cfg,
name = 'OnlineTester',
env = copy.deepcopy(self.env),
policy = copy.deepcopy(self.policy),
recorder = recorder,
)
self.tracker = ray.remote(Tracker).remote(self.cfg)
self.collector = ray.remote(Collector).options(**{'num_cpus': 1}).remote(
@@ -131,7 +136,6 @@ def run(self):
'''
exec_method(self.logger, 'info', 'get', f"[Trainer.run] Start {self.cfg.mode}ing!")
s_t = time.time()
# self.cfg.on_policy = False
if self.cfg.on_policy:
while True:
ray.get([interactor.run.remote() for interactor in self.interactors])
@@ -143,14 +147,13 @@
ray.shutdown()
break
else:
[interactor.run.remote() for interactor in self.interactors]
[learner.run.remote() for learner in self.learners]
while True:
for interactor in self.interactors:
interactor.run.remote()
for learner in self.learners:
learner.run.remote()
if exec_method(self.tracker, 'pub_msg', 'get', Msg(type = MsgType.TRACKER_CHECK_TASK_END)):
e_t = time.time()
exec_method(self.logger, 'info', 'get', f"[Trainer.run] Finish {self.cfg.mode}ing! Time cost: {e_t - s_t:.3f} s")
time.sleep(5)
ray.shutdown()
break
break
time.sleep(1)
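The off-policy branch of run() is also restructured: interactors and learners are now launched once as long-running Ray tasks before the loop, and the driver only polls the tracker for completion, instead of re-issuing run.remote() on every iteration. A condensed sketch of the new shape, where tracker_done is a hypothetical wrapper around the TRACKER_CHECK_TASK_END check:

    import time

    def run_off_policy(interactors, learners, tracker_done, poll_interval=1.0):
        """Launch the actor tasks once, then only poll for task completion."""
        [interactor.run.remote() for interactor in interactors]  # long-running Ray tasks
        [learner.run.remote() for learner in learners]
        while True:
            if tracker_done():  # stand-in for the tracker pub_msg check above
                break
            time.sleep(poll_interval)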
