Skip to content

Commit

Permalink
Fix SetTaskEnvUpdate when new env is a supertype (#2264)
Browse files Browse the repository at this point in the history
Before, if the new environment was a subtype of the existing
environment, the environment was not changed, even though it should be.
  • Loading branch information
krzentner authored Apr 12, 2021
1 parent 3fb4f50 commit cb88ffa
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/garage/sampler/env_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,11 @@ def __call__(self, old_env=None):
Environment: The new, updated environment.
"""
# We need exact type equality, not just a subtype
# pylint: disable=unidiomatic-typecheck
if old_env is None:
return self._make_env()
elif not isinstance(getattr(old_env, 'unwrapped', old_env),
self._env_type):
elif type(getattr(old_env, 'unwrapped', old_env)) != self._env_type:
warnings.warn('SetTaskEnvUpdate is closing an environment. This '
'may indicate a very slow TaskSampler setup.')
old_env.close()
Expand Down
24 changes: 24 additions & 0 deletions tests/garage/sampler/test_env_update.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from garage.sampler import SetTaskUpdate

from tests.fixtures.envs.dummy import DummyBoxEnv

TEST_TASK = ['test_task']


class MTDummyEnv(DummyBoxEnv):

def set_task(self, task):
assert task == TEST_TASK


class MTDummyEnvSubtype(MTDummyEnv):
pass


def test_set_task_update_with_subtype():
old_env = MTDummyEnvSubtype()
env_update = SetTaskUpdate(MTDummyEnv, TEST_TASK, None)
new_env = env_update(old_env)
assert new_env is not old_env
assert new_env is not None
assert old_env is not None

0 comments on commit cb88ffa

Please sign in to comment.