From 8da42485fda32c9ff400dc3a5c1e08d9a46997c0 Mon Sep 17 00:00:00 2001 From: huangshiyu Date: Tue, 8 Aug 2023 13:21:25 +0800 Subject: [PATCH] add shuffle traj to offline env --- examples/behavior_cloning/test_env.py | 7 +++++-- openrl/envs/offline/__init__.py | 1 - openrl/envs/offline/offline_env.py | 19 +++++++++++++++---- openrl/envs/vec_env/base_venv.py | 3 ++- openrl/envs/vec_env/sync_venv.py | 2 ++ .../vec_env/wrappers/vec_monitor_wrapper.py | 1 + 6 files changed, 25 insertions(+), 8 deletions(-) diff --git a/examples/behavior_cloning/test_env.py b/examples/behavior_cloning/test_env.py index bf1c3c21..b69fb16e 100644 --- a/examples/behavior_cloning/test_env.py +++ b/examples/behavior_cloning/test_env.py @@ -11,16 +11,19 @@ def test_env(): cfg = cfg_parser.parse_args() # create environment, set environment parallelism to 9 - env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True) + # env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=True) + env = make("OfflineEnv", env_num=1, cfg=cfg, asynchronous=False) for ep_index in range(10): done = False step = 0 env.reset() + while not np.all(done): obs, reward, done, info = env.step(env.random_action()) + step += 1 - print(ep_index, step) + print(ep_index, step) if __name__ == "__main__": diff --git a/openrl/envs/offline/__init__.py b/openrl/envs/offline/__init__.py index 0c3174e4..4d61f510 100644 --- a/openrl/envs/offline/__init__.py +++ b/openrl/envs/offline/__init__.py @@ -29,7 +29,6 @@ def offline_make(dataset, render_mode, disable_env_checker, **kwargs): env_num = kwargs["env_num"] seed = kwargs.pop("seed", None) assert seed is not None, "seed must be set" - env = OfflineEnv(dataset, env_id, env_num, seed) return env diff --git a/openrl/envs/offline/offline_env.py b/openrl/envs/offline/offline_env.py index 17196548..ad04cae2 100644 --- a/openrl/envs/offline/offline_env.py +++ b/openrl/envs/offline/offline_env.py @@ -37,8 +37,11 @@ def __init__(self, dataset_path, env_id: int, env_num: int, seed: int): self.agent_num = self.dataset.agent_num self.traj_num = len(self.dataset.trajectories["episode_lengths"]) self.traj_index = None + self.epoch_index = None self.traj_length = None self.step_index = None + self.sample_indexes = None + self.seed(seed) def seed(self, seed=None): if seed is not None: @@ -48,11 +51,19 @@ def reset(self, *, seed=None, options=None): if seed is not None: self.seed(seed) - if self.traj_index is None: - self.traj_index = 0 + if self.epoch_index is None: + self.epoch_index = 0 else: - self.traj_index += 1 - self.traj_index %= self.traj_num + self.epoch_index += 1 + self.epoch_index %= self.traj_num + if self.epoch_index == 0: + if self._np_random is None: + self.seed(0) + self.sample_indexes = self._np_random.permutation(self.traj_num) + + assert self.sample_indexes is not None + self.traj_index = self.sample_indexes[self.epoch_index] + self.traj_length = self.dataset.trajectories["episode_lengths"][self.traj_index] assert ( self.traj_length diff --git a/openrl/envs/vec_env/base_venv.py b/openrl/envs/vec_env/base_venv.py index 0a3a2221..f2e54744 100644 --- a/openrl/envs/vec_env/base_venv.py +++ b/openrl/envs/vec_env/base_venv.py @@ -116,6 +116,7 @@ def step(self, actions): :param actions: the action :return: observation, reward, done, information """ + results = self._step(actions) self.vector_render() return results @@ -315,7 +316,7 @@ def random_action(self, infos: Optional[List[Dict[str, Any]]] = None): action_masks = prepare_action_masks( infos, agent_num=self.agent_num, as_batch=False ) - print(action_masks) + return np.array( [ [ diff --git a/openrl/envs/vec_env/sync_venv.py b/openrl/envs/vec_env/sync_venv.py index 433240f0..a670ec33 100644 --- a/openrl/envs/vec_env/sync_venv.py +++ b/openrl/envs/vec_env/sync_venv.py @@ -180,12 +180,14 @@ def _step(self, actions: ActType): Returns: The batched environment step results """ + _actions = iterate_action(self.action_space, actions) observations, infos = [], [] for i, (env, action) in enumerate(zip(self.envs, _actions)): returns = env.step(action) + assert isinstance( returns, tuple ), "step return must be tuple, but got: {}".format(type(returns)) diff --git a/openrl/envs/vec_env/wrappers/vec_monitor_wrapper.py b/openrl/envs/vec_env/wrappers/vec_monitor_wrapper.py index ebc00339..d769e04e 100644 --- a/openrl/envs/vec_env/wrappers/vec_monitor_wrapper.py +++ b/openrl/envs/vec_env/wrappers/vec_monitor_wrapper.py @@ -35,6 +35,7 @@ def use_monitor(self): def step(self, action: ActType, extra_data: Optional[Dict[str, Any]] = None): returns = self.env.step(action, extra_data) + self.vec_info.append(info=returns[-1]) return returns