Skip to content

Commit

Permalink
pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Nov 6, 2024
1 parent 279136c commit e1b74d2
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions tests/metaworld/test_gym_make.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def _get_task_names(
for task_name in envs.get_attr("task_name")
]


@pytest.mark.parametrize("benchmark,env_dict", (("MT10", MT10_V3), ("MT50", MT50_V3)))
@pytest.mark.parametrize("vector_strategy", ("sync", "async"))
def test_mt_benchmarks(benchmark: str, env_dict: EnvDict, vector_strategy: str):
Expand Down Expand Up @@ -91,33 +92,37 @@ def test_mt_benchmarks(benchmark: str, env_dict: EnvDict, vector_strategy: str):
partially_observable = all(envs.get_attr("_partially_observable"))
assert not partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys())
def test_mt1(env_name: str):
metaworld_cls_to_task_name = {v.__name__: k for k, v in ALL_V3_ENVIRONMENTS.items()}
env = gym.make(f"Meta-World/MT1", env_name=env_name)
env = gym.make("Meta-World/MT1", env_name=env_name)
assert isinstance(env.unwrapped, SawyerXYZEnv)
assert len(env.get_wrapper_attr("tasks")) == _N_GOALS
assert metaworld_cls_to_task_name[env.unwrapped.task_name] == env_name

env.reset()
assert not env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_HIDDEN.keys())
def test_goal_hidden(env_name: str):
env = gym.make(f"Meta-World/goal_hidden", env_name=env_name, seed=None)
env = gym.make("Meta-World/goal_hidden", env_name=env_name, seed=None)
assert isinstance(env.unwrapped, SawyerXYZEnv)

env.reset()
assert env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS_GOAL_OBSERVABLE.keys())
def test_goal_observable(env_name: str):
env = gym.make(f"Meta-World/goal_observable", env_name=env_name, seed=None)
env = gym.make("Meta-World/goal_observable", env_name=env_name, seed=None)
assert isinstance(env.unwrapped, SawyerXYZEnv)

env.reset()
assert not env.unwrapped._partially_observable


@pytest.mark.parametrize("env_name", ALL_V3_ENVIRONMENTS.keys())
@pytest.mark.parametrize("split", ("train", "test"))
@pytest.mark.parametrize("vector_strategy", ("sync", "async"))
Expand Down

0 comments on commit e1b74d2

Please sign in to comment.