From c216f7303854099f1602a4c104927b9ecca0a159 Mon Sep 17 00:00:00 2001 From: lutogniew Date: Wed, 24 May 2023 17:46:00 +0200 Subject: [PATCH 1/2] Fix env checker single-step-env edge case Before this change, env checker failed to `reset()` the tested environment before calling `step()` when checking for `Inf` / `NaN`. This could cause environments which happened to have only one `step()` available before the episode was terminated to fail. This is now fixed. --- docs/misc/changelog.rst | 5 ++- stable_baselines3/common/env_checker.py | 1 + stable_baselines3/version.txt | 2 +- tests/test_env_checker.py | 57 ++++++++++++++++++++++++- 4 files changed, 62 insertions(+), 3 deletions(-) diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 16a7737e5..d2a3d6060 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -3,7 +3,7 @@ Changelog ========== -Release 2.0.0a10 (WIP) +Release 2.0.0a11 (WIP) -------------------------- **Gymnasium support** @@ -39,6 +39,8 @@ Bug Fixes: - Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel) - Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto) - Fixed loading DQN changes ``target_update_interval`` (@tobirohrer) +- Fixed env checker to properly reset the env before calling ``step()`` when checking + for ``Inf`` and ``NaN`` (@lutogniew) Deprecations: ^^^^^^^^^^^^^ @@ -1346,3 +1348,4 @@ And all the contributors: @Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto +@lutogniew diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index 058710df9..b6ce490df 100644 --- a/stable_baselines3/common/env_checker.py +++ 
b/stable_baselines3/common/env_checker.py @@ -110,6 +110,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act def _check_nan(env: gym.Env) -> None: """Check for Inf and NaN using the VecWrapper.""" vec_env = VecCheckNan(DummyVecEnv([lambda: env])) + vec_env.reset() for _ in range(10): action = np.array([env.action_space.sample()]) _, _, _, _ = vec_env.step(action) diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 7385c4c8b..d70b1bb71 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -2.0.0a10 +2.0.0a11 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index e855e2137..829281d8e 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional, Tuple import gymnasium as gym import numpy as np @@ -112,3 +112,58 @@ def step(self, action): test_env = TestEnv() with pytest.raises(AssertionError, match=error_message): check_env(env=test_env) + + +class StepCalledAfterEnvTerminatedException(Exception): + pass + + +class LimitedStepsTestEnv(gym.Env): + metadata = {"render_modes": ["human"]} + render_mode = None + + action_space = spaces.Discrete(n=2) + observation_space = spaces.Discrete(n=2) + + def __init__(self, steps_before_termination: int = 1): + super().__init__() + + assert steps_before_termination >= 1 + self._steps_before_termination = steps_before_termination + + self._steps_called = 0 + self._terminated = False + + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]: + super().reset(seed=seed) + + self._steps_called = 0 + self._terminated = False + + return 0, {} + + def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]: + self._steps_called += 1 + + if self._terminated: + raise StepCalledAfterEnvTerminatedException + + observation = 0 + reward = 0.0 + self._terminated = 
self._steps_called >= self._steps_before_termination + truncated = False + + return observation, reward, self._terminated, truncated, {} + + def render(self, mode: str = "human") -> None: + pass + + def close(self): + pass + + +def test_check_env_single_step_env(): + test_env = LimitedStepsTestEnv(steps_before_termination=1) + + # This should not throw + check_env(env=test_env, warn=True) From 5a2cde7cb9ccb4dfccb74e7a14c2a214fb1e16c0 Mon Sep 17 00:00:00 2001 From: lutogniew Date: Thu, 25 May 2023 15:34:32 +0200 Subject: [PATCH 2/2] Code review fixes #1 As suggested by Antonin Raffin. --- tests/test_env_checker.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 829281d8e..c0a5e0610 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -114,14 +114,7 @@ def step(self, action): check_env(env=test_env) -class StepCalledAfterEnvTerminatedException(Exception): - pass - - class LimitedStepsTestEnv(gym.Env): - metadata = {"render_modes": ["human"]} - render_mode = None - action_space = spaces.Discrete(n=2) observation_space = spaces.Discrete(n=2) @@ -145,8 +138,7 @@ def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) - def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]: self._steps_called += 1 - if self._terminated: - raise StepCalledAfterEnvTerminatedException + assert not self._terminated observation = 0 reward = 0.0 @@ -155,10 +147,7 @@ def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, An return observation, reward, self._terminated, truncated, {} - def render(self, mode: str = "human") -> None: - pass - - def close(self): + def render(self) -> None: pass