From cb88ffab2a5f4c669cd336a8219c088c59dc7176 Mon Sep 17 00:00:00 2001 From: "K.R. Zentner" <41180126+krzentner@users.noreply.github.com> Date: Mon, 12 Apr 2021 12:57:37 -0700 Subject: [PATCH] Fix SetTaskEnvUpdate when new env is a supertype (#2264) Before, if the new environment was a subtype of the existing environment, the environment was not changed, even though it should be. --- src/garage/sampler/env_update.py | 5 +++-- tests/garage/sampler/test_env_update.py | 24 ++++++++++++++++++++++++ 2 files changed, 27 insertions(+), 2 deletions(-) create mode 100644 tests/garage/sampler/test_env_update.py diff --git a/src/garage/sampler/env_update.py b/src/garage/sampler/env_update.py index cc0cbe379f..de12b72913 100644 --- a/src/garage/sampler/env_update.py +++ b/src/garage/sampler/env_update.py @@ -107,10 +107,11 @@ def __call__(self, old_env=None): Environment: The new, updated environment. """ + # We need exact type equality, not just a subtype + # pylint: disable=unidiomatic-typecheck if old_env is None: return self._make_env() - elif not isinstance(getattr(old_env, 'unwrapped', old_env), - self._env_type): + elif type(getattr(old_env, 'unwrapped', old_env)) != self._env_type: warnings.warn('SetTaskEnvUpdate is closing an environment. This ' 'may indicate a very slow TaskSampler setup.') old_env.close() diff --git a/tests/garage/sampler/test_env_update.py b/tests/garage/sampler/test_env_update.py new file mode 100644 index 0000000000..a62be8d788 --- /dev/null +++ b/tests/garage/sampler/test_env_update.py @@ -0,0 +1,24 @@ +from garage.sampler import SetTaskUpdate + +from tests.fixtures.envs.dummy import DummyBoxEnv + +TEST_TASK = ['test_task'] + + +class MTDummyEnv(DummyBoxEnv): + + def set_task(self, task): + assert task == TEST_TASK + + +class MTDummyEnvSubtype(MTDummyEnv): + pass + + +def test_set_task_update_with_subtype(): + old_env = MTDummyEnvSubtype() + env_update = SetTaskUpdate(MTDummyEnv, TEST_TASK, None) + new_env = env_update(old_env) + assert new_env is not old_env + assert new_env is not None + assert old_env is not None