Fix wrappers type annotation (need PR from Gymnasium)
araffin committed Feb 19, 2023
1 parent 65af7c1 commit ad48559
Showing 4 changed files with 31 additions and 31 deletions.
40 changes: 20 additions & 20 deletions stable_baselines3/common/atari_wrappers.py
@@ -1,20 +1,20 @@
-from typing import Dict, Tuple
+from typing import Dict, SupportsFloat
 
 import gymnasium as gym
 import numpy as np
 from gymnasium import spaces
 
+from stable_baselines3.common.type_aliases import AtariResetReturn, AtariStepReturn
+
 try:
     import cv2  # pytype:disable=import-error
 
     cv2.ocl.setUseOpenCL(False)
 except ImportError:
     cv2 = None
 
-from stable_baselines3.common.type_aliases import Gym26StepReturn
-
 
-class StickyActionEnv(gym.Wrapper):
+class StickyActionEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Sticky action.
@@ -30,17 +30,17 @@ def __init__(self, env: gym.Env, action_repeat_probability: float) -> None:
         self.action_repeat_probability = action_repeat_probability
         assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]
 
-    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
+    def reset(self, **kwargs) -> AtariResetReturn:
         self._sticky_action = 0  # NOOP
         return self.env.reset(**kwargs)
 
-    def step(self, action: int) -> Gym26StepReturn:
+    def step(self, action: int) -> AtariStepReturn:
         if self.np_random.random() >= self.action_repeat_probability:
             self._sticky_action = action
         return self.env.step(self._sticky_action)
 
 
-class NoopResetEnv(gym.Wrapper):
+class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Sample initial states by taking random number of no-ops on reset.
     No-op is assumed to be action 0.
@@ -56,7 +56,7 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
         self.noop_action = 0
         assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]
 
-    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
+    def reset(self, **kwargs) -> AtariResetReturn:
         self.env.reset(**kwargs)
         if self.override_num_noops is not None:
             noops = self.override_num_noops
@@ -72,7 +72,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
         return obs, info
 
 
-class FireResetEnv(gym.Wrapper):
+class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Take action on reset for environments that are fixed until firing.
@@ -84,7 +84,7 @@ def __init__(self, env: gym.Env) -> None:
         assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
         assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]
 
-    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
+    def reset(self, **kwargs) -> AtariResetReturn:
         self.env.reset(**kwargs)
         obs, _, terminated, truncated, _ = self.env.step(1)
         if terminated or truncated:
@@ -95,7 +95,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
         return obs, {}
 
 
-class EpisodicLifeEnv(gym.Wrapper):
+class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Make end-of-life == end-of-episode, but only reset on true game over.
     Done by DeepMind for the DQN and co. since it helps value estimation.
@@ -108,7 +108,7 @@ def __init__(self, env: gym.Env) -> None:
         self.lives = 0
         self.was_real_done = True
 
-    def step(self, action: int) -> Gym26StepReturn:
+    def step(self, action: int) -> AtariStepReturn:
         obs, reward, terminated, truncated, info = self.env.step(action)
         self.was_real_done = terminated or truncated
         # check current lives, make loss of life terminal,
@@ -122,7 +122,7 @@ def step(self, action: int) -> Gym26StepReturn:
         self.lives = lives
         return obs, reward, terminated, truncated, info
 
-    def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
+    def reset(self, **kwargs) -> AtariResetReturn:
         """
         Calls the Gym environment reset, only when lives are exhausted.
         This way all states are still reachable even though lives are episodic,
@@ -146,7 +146,7 @@ def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]:
         return obs, info
 
 
-class MaxAndSkipEnv(gym.Wrapper):
+class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Return only every ``skip``-th frame (frameskipping)
     and return the max between the two last frames.
@@ -164,7 +164,7 @@ def __init__(self, env: gym.Env, skip: int = 4) -> None:
         self._obs_buffer = np.zeros((2,) + env.observation_space.shape, dtype=env.observation_space.dtype)
         self._skip = skip
 
-    def step(self, action: int) -> Gym26StepReturn:
+    def step(self, action: int) -> AtariStepReturn:
         """
         Step the environment with the given action
         Repeat action, sum reward, and max over last observations.
@@ -181,7 +181,7 @@ def step(self, action: int) -> Gym26StepReturn:
                 self._obs_buffer[0] = obs
             if i == self._skip - 1:
                 self._obs_buffer[1] = obs
-            total_reward += reward
+            total_reward += float(reward)
             if done:
                 break
         # Note that the observation on the done=True frame
@@ -201,17 +201,17 @@ class ClipRewardEnv(gym.RewardWrapper):
     def __init__(self, env: gym.Env) -> None:
         super().__init__(env)
 
-    def reward(self, reward: float) -> float:
+    def reward(self, reward: SupportsFloat) -> float:
         """
         Bin reward to {+1, 0, -1} by its sign.
         :param reward:
         :return:
         """
-        return np.sign(reward)
+        return np.sign(float(reward))
 
 
-class WarpFrame(gym.ObservationWrapper):
+class WarpFrame(gym.ObservationWrapper[np.ndarray, int, np.ndarray]):
     """
     Convert to grayscale and warp frames to 84x84 (default)
     as done in the Nature paper and later work.
@@ -246,7 +246,7 @@ def observation(self, frame: np.ndarray) -> np.ndarray:
         return frame[:, :, None]
 
 
-class AtariWrapper(gym.Wrapper):
+class AtariWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
     """
     Atari 2600 preprocessings
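
Note (not part of the diff): gym.Wrapper in Gymnasium is generic in four parameters, in order the wrapper's observation type, the wrapper's action type, and the wrapped environment's observation and action types, and env.step() is typed to return a SupportsFloat reward, which is why MaxAndSkipEnv now casts with float(reward). A minimal sketch of the same annotation pattern, assuming a recent gymnasium install; SignRewardWrapper, PassThroughWrapper and the CartPole usage are illustrative, not part of this commit:

from typing import SupportsFloat

import gymnasium as gym
import numpy as np


class SignRewardWrapper(gym.RewardWrapper):
    """Illustrative wrapper: clip rewards to {-1, 0, +1}, like ClipRewardEnv above."""

    def reward(self, reward: SupportsFloat) -> float:
        # Gymnasium only guarantees the reward supports float(), so cast before using it.
        return float(np.sign(float(reward)))


class PassThroughWrapper(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """Illustrative wrapper annotated like the Atari wrappers in this commit:
    np.ndarray observations and int actions, same types for the inner env."""


if __name__ == "__main__":
    env = PassThroughWrapper(SignRewardWrapper(gym.make("CartPole-v1")))
    obs, info = env.reset(seed=0)
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())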
2 changes: 1 addition & 1 deletion stable_baselines3/common/env_util.py
@@ -92,7 +92,7 @@ def _init() -> gym.Env:
             kwargs = {"render_mode": "rgb_array"}
             kwargs.update(env_kwargs)
             try:
-                env = gym.make(env_id, **kwargs)
+                env = gym.make(env_id, **kwargs)  # type: ignore[arg-type]
             except TypeError:
                 env = gym.make(env_id, **env_kwargs)
         else:
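
Note (not part of the diff): the runtime behaviour is unchanged here; the `# type: ignore[arg-type]` only silences mypy, presumably because the locally built kwargs dict is inferred as Dict[str, str] and does not match gym.make's annotated parameters. A standalone sketch of the same render_mode fallback pattern, with make_env_with_render_mode as a hypothetical helper name rather than SB3 API:

import gymnasium as gym


def make_env_with_render_mode(env_id: str, **env_kwargs) -> gym.Env:
    """Hypothetical helper mirroring the pattern in make_vec_env's _init.

    Request rgb_array rendering by default, but fall back gracefully for
    environments whose constructors do not accept a render_mode keyword.
    """
    kwargs = {"render_mode": "rgb_array", **env_kwargs}
    try:
        return gym.make(env_id, **kwargs)
    except TypeError:
        # The environment rejected render_mode; retry without it.
        return gym.make(env_id, **env_kwargs)


# Usage sketch:
# env = make_env_with_render_mode("CartPole-v1")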
16 changes: 7 additions & 9 deletions stable_baselines3/common/monitor.py
@@ -5,16 +5,14 @@
 import os
 import time
 from glob import glob
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, SupportsFloat, Tuple, Union
 
 import gymnasium as gym
 import numpy as np
 import pandas
+from gymnasium.core import ActType, ObsType
 
-from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn
-
 
-class Monitor(gym.Wrapper):
+class Monitor(gym.Wrapper[ObsType, ActType, ObsType, ActType]):
     """
     A monitor wrapper for Gym environments, it is used to know the episode reward, length, time and other data.
@@ -46,7 +44,7 @@ def __init__(
             env_id = env.spec.id if env.spec is not None else None
             self.results_writer = ResultsWriter(
                 filename,
-                header={"t_start": self.t_start, "env_id": env_id},
+                header={"t_start": self.t_start, "env_id": str(env_id)},
                 extra_keys=reset_keywords + info_keywords,
                 override_existing=override_existing,
             )
@@ -63,7 +61,7 @@ def __init__(
         # extra info about the current episode, that was passed in during reset()
         self.current_reset_info: Dict[str, Any] = {}
 
-    def reset(self, **kwargs) -> Gym26ResetReturn:
+    def reset(self, **kwargs) -> Tuple[ObsType, Dict[str, Any]]:
         """
         Calls the Gym environment reset. Can only be called if the environment is over, or if allow_early_resets is True
@@ -84,7 +82,7 @@ def reset(self, **kwargs) -> Gym26ResetReturn:
             self.current_reset_info[key] = value
         return self.env.reset(**kwargs)
 
-    def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn:
+    def step(self, action: ActType) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]:
         """
         Step the environment with the given action
@@ -94,7 +92,7 @@ def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn:
         if self.needs_reset:
             raise RuntimeError("Tried to step environment that needs reset")
         observation, reward, terminated, truncated, info = self.env.step(action)
-        self.rewards.append(reward)
+        self.rewards.append(float(reward))
         if terminated or truncated:
             self.needs_reset = True
             ep_rew = sum(self.rewards)
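
Note (not part of the diff): Monitor is now generic in the wrapped environment's ObsType/ActType and casts each reward to float before accumulating it. A quick usage sketch, assuming gymnasium and stable-baselines3 are installed; CartPole-v1 is an arbitrary choice:

import gymnasium as gym

from stable_baselines3.common.monitor import Monitor

env = Monitor(gym.make("CartPole-v1"))
obs, info = env.reset(seed=0)
done = False
while not done:
    obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
    done = terminated or truncated
# Once an episode ends, Monitor records its stats under the "episode" key:
# info["episode"] -> {"r": episode return, "l": episode length, "t": elapsed time}
print(info["episode"])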
4 changes: 3 additions & 1 deletion stable_baselines3/common/type_aliases.py
@@ -2,7 +2,7 @@
 
 import sys
 from enum import Enum
-from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, NamedTuple, Optional, SupportsFloat, Tuple, Union
 
 import gymnasium as gym
 import numpy as np
@@ -18,8 +18,10 @@
 GymEnv = Union[gym.Env, vec_env.VecEnv]
 GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int]
 Gym26ResetReturn = Tuple[GymObs, Dict]
+AtariResetReturn = Tuple[np.ndarray, Dict[str, Any]]
 GymStepReturn = Tuple[GymObs, float, bool, Dict]
 Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict]
+AtariStepReturn = Tuple[np.ndarray, SupportsFloat, bool, bool, Dict[str, Any]]
 TensorDict = Dict[Union[str, int], th.Tensor]
 OptimizerStateDict = Dict[str, Any]
 MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback]
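
Note (not part of the diff): the two new aliases give the Atari wrappers a 2-tuple reset return (observation, info dict) and a 5-tuple step return whose reward slot is SupportsFloat rather than float. A sketch of how downstream code can consume them; episode_return is an illustrative helper, not SB3 API:

from typing import List

from stable_baselines3.common.type_aliases import AtariStepReturn


def episode_return(transitions: List[AtariStepReturn]) -> float:
    """Sum rewards from a list of step tuples (hypothetical helper).

    The reward slot of AtariStepReturn is SupportsFloat (it may be a numpy
    scalar), so it is cast with float() before summing, exactly as Monitor
    and MaxAndSkipEnv do in this commit.
    """
    total = 0.0
    for _obs, reward, _terminated, _truncated, _info in transitions:
        total += float(reward)
    return total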
