Skip to content

Commit

Permalink
Fix env checker single-step-env edge case (#1521)
Browse files Browse the repository at this point in the history
Before this change, env checker failed to `reset()` the tested
environment before calling `step()` when checking for `Inf` / `NaN`.
This could cause environments which happened to have only one `step()`
available before the episode was terminated to fail.

This is now fixed.
  • Loading branch information
lutogniew committed May 24, 2023
1 parent 1bfb55d commit 00ebe6a
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 3 deletions.
5 changes: 4 additions & 1 deletion docs/misc/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
Changelog
==========

Release 2.0.0a10 (WIP)
Release 2.0.0a11 (WIP)
--------------------------

**Gymnasium support**
Expand Down Expand Up @@ -39,6 +39,8 @@ Bug Fixes:
- Fixed ``VecExtractDictObs`` does not handle terminal observation (@WeberSamuel)
- Set NumPy version to ``>=1.20`` due to use of ``numpy.typing`` (@troiganto)
- Fixed loading DQN changes ``target_update_interval`` (@tobirohrer)
- Fixed env checker to properly reset the env before calling ``step()`` when checking
for ``Inf`` and ``NaN`` (@lutogniew)

Deprecations:
^^^^^^^^^^^^^
Expand Down Expand Up @@ -1346,3 +1348,4 @@ And all the contributors:
@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875
@anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong
@DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO @jonasreiher @npit @WeberSamuel @troiganto
@lutogniew
1 change: 1 addition & 0 deletions stable_baselines3/common/env_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
vec_env.reset()
for _ in range(10):
action = np.array([env.action_space.sample()])
_, _, _, _ = vec_env.step(action)
Expand Down
2 changes: 1 addition & 1 deletion stable_baselines3/version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
2.0.0a10
2.0.0a11
57 changes: 56 additions & 1 deletion tests/test_env_checker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Optional
from typing import Any, Dict, Optional, Tuple

import gymnasium as gym
import numpy as np
Expand Down Expand Up @@ -112,3 +112,58 @@ def step(self, action):
test_env = TestEnv()
with pytest.raises(AssertionError, match=error_message):
check_env(env=test_env)


class StepCalledAfterEnvTerminatedException(Exception):
pass


class LimitedStepsTestEnv(gym.Env):
metadata = {"render_modes": ["human"]}
render_mode = None

action_space = spaces.Discrete(n=2)
observation_space = spaces.Discrete(n=2)

def __init__(self, steps_before_termination: int = 1):
super().__init__()

assert steps_before_termination >= 1
self._steps_before_termination = steps_before_termination

self._steps_called = 0
self._terminated = False

def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[int, Dict]:
super().reset(seed=seed)

self._steps_called = 0
self._terminated = False

return 0, {}

def step(self, action: np.ndarray) -> Tuple[int, float, bool, bool, Dict[str, Any]]:
self._steps_called += 1

if self._terminated:
raise StepCalledAfterEnvTerminatedException

observation = 0
reward = 0.0
self._terminated = self._steps_called >= self._steps_before_termination
truncated = False

return observation, reward, self._terminated, truncated, {}

def render(self, mode: str = "human") -> None:
pass

def close(self):
pass


def test_check_env_single_step_env():
test_env = LimitedStepsTestEnv(steps_before_termination=1)

# This should not throw
check_env(env=test_env, warn=True)

0 comments on commit 00ebe6a

Please sign in to comment.