Skip to content

Commit

Permalink
feature(xjx): add env supervisor (opendilab#330)
Browse files Browse the repository at this point in the history
* Process supervisor

Add thread supervisor

Add env supervisor

Env launch and reset (ci skip)

Rewrite ignore_err in supervisor (ci skip)

Support recv req id in supervisor (ci skip)

Test env supervisor (ci skip)

Stash env supervisor (ci skip)

* Add __getattr__ on supervisor (ci skip)

* Step retry (ci skip)

* Add timeout on recv all

* Refactor recv_all

* Refactor recv_all callback

* Test block

* Add different process type

* Add async process in env supervisor

* Add comments, support space attributes

* Fix style (ci skip)

* Fix empty ready_obs

* Support episode num (ci skip)

* Fix env supervisor

* Add test for ready obs

* Add process name

* Fix supervisor

* Change platformtest

* Fix auto reset

* Fix lru cache in 3.6

* Fix Error 32

* Support shared memory

* Prevent missing attrs

* Fix

* Fix env retry once test

* Fix test shared memory

* Fix close

* Fix comments

* Update collector profile test, fix deprecated torch ops
  • Loading branch information
sailxjx authored Jun 9, 2022
1 parent e477a08 commit 242ca2b
Show file tree
Hide file tree
Showing 19 changed files with 1,699 additions and 36 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,11 @@ dockertest:
${DING_DIR}/scripts/docker-test-entry.sh

platformtest:
pytest ${PLATFORM_TEST_DIR} \
pytest ${TEST_DIR} \
--cov-report term-missing \
--cov=${COV_DIR} \
${WORKERS_COMMAND} \
-sv -m unittest \
-sv -m platformtest

benchmark:
pytest ${TEST_DIR} \
Expand Down
21 changes: 2 additions & 19 deletions ding/data/buffer/buffer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from abc import abstractmethod
from abc import abstractmethod, ABC
from typing import Any, List, Optional, Union, Callable
import copy
from dataclasses import dataclass
Expand Down Expand Up @@ -53,7 +53,7 @@ def _copy_buffereddata(d: BufferedData) -> BufferedData:
fastcopy.dispatch[BufferedData] = _copy_buffereddata


class Buffer:
class Buffer(ABC):
"""
Buffer is an abstraction of device storage, third-party services or data structures,
For example, memory queue, sum-tree, redis, or di-store.
Expand Down Expand Up @@ -119,23 +119,6 @@ def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] =
"""
raise NotImplementedError

@abstractmethod
def batch_update(
self,
indices: List[str],
datas: Optional[List[Optional[Any]]] = None,
metas: Optional[List[Optional[dict]]] = None
) -> None:
"""
Overview:
Batch update data and meta by indices, maybe useful in some data architectures.
Arguments:
- indices (:obj:`List[str]`): Index of data.
- datas (:obj:`Optional[List[Optional[Any]]]`): Pure data.
- metas (:obj:`Optional[List[Optional[dict]]]`): Meta information.
"""
raise NotImplementedError

@abstractmethod
def delete(self, index: str):
"""
Expand Down
36 changes: 36 additions & 0 deletions ding/entry/tests/test_serial_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config


@pytest.mark.platformtest
@pytest.mark.unittest
def test_dqn():
config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
Expand All @@ -63,6 +64,7 @@ def test_dqn():
os.popen('rm -rf cartpole_dqn_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_ddpg():
config = [deepcopy(pendulum_ddpg_config), deepcopy(pendulum_ddpg_create_config)]
Expand All @@ -73,6 +75,7 @@ def test_ddpg():
assert False, "pipeline fail"


# @pytest.mark.platformtest
# @pytest.mark.unittest
def test_hybrid_ddpg():
config = [deepcopy(gym_hybrid_ddpg_config), deepcopy(gym_hybrid_ddpg_create_config)]
Expand All @@ -83,6 +86,7 @@ def test_hybrid_ddpg():
assert False, "pipeline fail"


# @pytest.mark.platformtest
# @pytest.mark.unittest
def test_hybrid_pdqn():
config = [deepcopy(gym_hybrid_pdqn_config), deepcopy(gym_hybrid_pdqn_create_config)]
Expand All @@ -93,6 +97,7 @@ def test_hybrid_pdqn():
assert False, "pipeline fail"


# @pytest.mark.platformtest
# @pytest.mark.unittest
def test_hybrid_mpdqn():
config = [deepcopy(gym_hybrid_mpdqn_config), deepcopy(gym_hybrid_mpdqn_create_config)]
Expand All @@ -103,6 +108,7 @@ def test_hybrid_mpdqn():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_dqn_stdim():
config = [deepcopy(cartpole_dqn_stdim_config), deepcopy(cartpole_dqn_stdim_create_config)]
Expand All @@ -116,6 +122,7 @@ def test_dqn_stdim():
os.popen('rm -rf cartpole_dqn_stdim_unittest')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_td3():
config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
Expand All @@ -126,6 +133,7 @@ def test_td3():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_rainbow():
config = [deepcopy(cartpole_rainbow_config), deepcopy(cartpole_rainbow_create_config)]
Expand All @@ -136,6 +144,7 @@ def test_rainbow():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_iqn():
config = [deepcopy(cartpole_iqn_config), deepcopy(cartpole_iqn_create_config)]
Expand All @@ -146,6 +155,7 @@ def test_iqn():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_c51():
config = [deepcopy(cartpole_c51_config), deepcopy(cartpole_c51_create_config)]
Expand All @@ -156,6 +166,7 @@ def test_c51():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_qrdqn():
config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
Expand All @@ -166,6 +177,7 @@ def test_qrdqn():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_ppo():
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
Expand All @@ -177,6 +189,7 @@ def test_ppo():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_ppo_nstep_return():
config = [deepcopy(cartpole_offppo_config), deepcopy(cartpole_offppo_create_config)]
Expand All @@ -188,6 +201,7 @@ def test_ppo_nstep_return():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_sac():
config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
Expand All @@ -199,6 +213,7 @@ def test_sac():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_sac_auto_alpha():
config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
Expand All @@ -211,6 +226,7 @@ def test_sac_auto_alpha():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_sac_log_space():
config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
Expand All @@ -228,6 +244,7 @@ def test_sac_log_space():
args = [item for item in product(*[auto_alpha, log_space])]


@pytest.mark.platformtest
@pytest.mark.unittest
@pytest.mark.parametrize('auto_alpha, log_space', args)
def test_discrete_sac(auto_alpha, log_space):
Expand All @@ -241,6 +258,7 @@ def test_discrete_sac(auto_alpha, log_space):
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_discrete_sac_twin_critic():
config = [deepcopy(cartpole_sac_config), deepcopy(cartpole_sac_create_config)]
Expand All @@ -255,6 +273,7 @@ def test_discrete_sac_twin_critic():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_r2d2():
config = [deepcopy(cartpole_r2d2_config), deepcopy(cartpole_r2d2_create_config)]
Expand All @@ -265,6 +284,7 @@ def test_r2d2():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_impala():
config = [deepcopy(cartpole_impala_config), deepcopy(cartpole_impala_create_config)]
Expand All @@ -275,6 +295,7 @@ def test_impala():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_her_dqn():
bitflip_her_dqn_config.policy.cuda = False
Expand All @@ -284,6 +305,7 @@ def test_her_dqn():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_collaq():
config = [deepcopy(ptz_simple_spread_collaq_config), deepcopy(ptz_simple_spread_collaq_create_config)]
Expand All @@ -298,6 +320,7 @@ def test_collaq():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_coma():
config = [deepcopy(ptz_simple_spread_coma_config), deepcopy(ptz_simple_spread_coma_create_config)]
Expand All @@ -312,6 +335,7 @@ def test_coma():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_qmix():
config = [deepcopy(ptz_simple_spread_qmix_config), deepcopy(ptz_simple_spread_qmix_create_config)]
Expand All @@ -326,6 +350,7 @@ def test_qmix():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_wqmix():
config = [deepcopy(ptz_simple_spread_wqmix_config), deepcopy(ptz_simple_spread_wqmix_create_config)]
Expand All @@ -340,6 +365,7 @@ def test_wqmix():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_qtran():
config = [deepcopy(ptz_simple_spread_qtran_config), deepcopy(ptz_simple_spread_qtran_create_config)]
Expand All @@ -354,6 +380,7 @@ def test_qtran():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_atoc():
config = [deepcopy(ptz_simple_spread_atoc_config), deepcopy(ptz_simple_spread_atoc_create_config)]
Expand All @@ -367,6 +394,7 @@ def test_atoc():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_ppg():
cartpole_ppg_config.policy.use_cuda = False
Expand All @@ -376,6 +404,7 @@ def test_ppg():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_sqn():
config = [deepcopy(cartpole_sqn_config), deepcopy(cartpole_sqn_create_config)]
Expand All @@ -389,6 +418,7 @@ def test_sqn():
os.popen('rm -rf log ckpt*')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_selfplay():
try:
Expand All @@ -397,6 +427,7 @@ def test_selfplay():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_league():
try:
Expand All @@ -405,6 +436,7 @@ def test_league():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_acer():
config = [deepcopy(cartpole_acer_config), deepcopy(cartpole_acer_create_config)]
Expand All @@ -415,6 +447,7 @@ def test_acer():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_cql():
# train expert
Expand Down Expand Up @@ -449,6 +482,7 @@ def test_cql():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_d4pg():
config = [deepcopy(pendulum_d4pg_config), deepcopy(pendulum_d4pg_create_config)]
Expand All @@ -460,6 +494,7 @@ def test_d4pg():
print(repr(e))


@pytest.mark.platformtest
@pytest.mark.unittest
def test_discrete_cql():
# train expert
Expand Down Expand Up @@ -492,6 +527,7 @@ def test_discrete_cql():
os.popen('rm -rf cartpole cartpole_cql')


@pytest.mark.platformtest
@pytest.mark.unittest
def test_td3_bc():
# train expert
Expand Down
4 changes: 4 additions & 0 deletions ding/entry/tests/test_serial_entry_onpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config


@pytest.mark.platformtest
@pytest.mark.unittest
def test_a2c():
config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
Expand All @@ -19,6 +20,7 @@ def test_a2c():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo():
config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
Expand All @@ -30,6 +32,7 @@ def test_onpolicy_ppo():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_mappo():
config = [deepcopy(ptz_simple_spread_mappo_config), deepcopy(ptz_simple_spread_mappo_create_config)]
Expand All @@ -40,6 +43,7 @@ def test_mappo():
assert False, "pipeline fail"


@pytest.mark.platformtest
@pytest.mark.unittest
def test_onpolicy_ppo_continuous():
config = [deepcopy(pendulum_ppo_config), deepcopy(pendulum_ppo_create_config)]
Expand Down
2 changes: 1 addition & 1 deletion ding/envs/common/common_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def compute_denominator(x: torch.Tensor) -> torch.Tensor:
Returns:
- ret (:obj:`torch.Tensor`):
"""
x = x // 2 * 2
x = torch.div(x, 2, rounding_mode='trunc') * 2
x = torch.div(x, 64.)
x = torch.pow(10000., x)
x = torch.div(1., x)
Expand Down
2 changes: 1 addition & 1 deletion ding/envs/env/env_implementation_check.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from tabnanny import check
from typing import Any, Callable, List, Tuple
import numpy as np
from collections import Sequence
from collections.abc import Sequence
from easydict import EasyDict

from ding.envs.env import BaseEnv, BaseEnvTimestep
Expand Down
1 change: 1 addition & 0 deletions ding/envs/env_manager/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@
from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2
from .gym_vector_env_manager import GymVectorEnvManager
# Do not import PoolEnvManager, because it depends on installation of `envpool`
from .env_supervisor import EnvSupervisor
Loading

0 comments on commit 242ca2b

Please sign in to comment.