From 27cf925a39e83250452d3f02329942f2e47f44e2 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Sun, 26 Jan 2025 09:12:41 -0800 Subject: [PATCH] Update [ghstack-poisoned] --- test/test_libs.py | 74 +++++++++++++++++++++++++++++++++++++++- torchrl/envs/gym_like.py | 5 ++- 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/test/test_libs.py b/test/test_libs.py index b3ba8d54c3d..8b74a0d5418 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -7,6 +7,8 @@ import importlib.util import urllib.error +from gym.core import ObsType + _has_isaac = importlib.util.find_spec("isaacgym") is not None if _has_isaac: @@ -23,7 +25,7 @@ from contextlib import nullcontext from pathlib import Path from sys import platform -from typing import Optional, Union +from typing import Optional, Tuple, Union from unittest import mock import numpy as np @@ -634,6 +636,76 @@ def test_torchrl_to_gym(self, backend, numpy): finally: set_gym_backend(gb).set() + @implement_for("gym", None, "0.26") + def test_gym_dict_action_space(self): + pytest.skip("tested for gym > 0.26 - no backward issue") + + @implement_for("gym", "0.26", None) + def test_gym_dict_action_space(self): # noqa: F811 + import gym + from gym import Env + + class CompositeActionEnv(Env): + def __init__(self): + self.action_space = gym.spaces.Dict( + a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1) + ) + self.observation_space = gym.spaces.Box(-1, 1) + + def step(self, action): + return (0.5, 0.0, False, False, {}) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[ObsType, dict]: + return (0.0, {}) + + env = CompositeActionEnv() + torchrl_env = GymWrapper(env) + assert isinstance(torchrl_env.action_spec, Composite) + assert len(torchrl_env.action_keys) == 2 + r = torchrl_env.rollout(10) + assert isinstance(r[0]["a0"], torch.Tensor) + assert isinstance(r[0]["a1"], torch.Tensor) + assert r[0]["observation"] == 0 + assert r[1]["observation"] == 0.5 + + @implement_for("gymnasium") + def test_gym_dict_action_space(self): # noqa: F811 + import gymnasium as gym + from gymnasium import Env + + class CompositeActionEnv(Env): + def __init__(self): + self.action_space = gym.spaces.Dict( + a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1) + ) + self.observation_space = gym.spaces.Box(-1, 1) + + def step(self, action): + return (0.5, 0.0, False, False, {}) + + def reset( + self, + *, + seed: Optional[int] = None, + options: Optional[dict] = None, + ) -> Tuple[ObsType, dict]: + return (0.0, {}) + + env = CompositeActionEnv() + torchrl_env = GymWrapper(env) + assert isinstance(torchrl_env.action_spec, Composite) + assert len(torchrl_env.action_keys) == 2 + r = torchrl_env.rollout(10) + assert isinstance(r[0]["a0"], torch.Tensor) + assert isinstance(r[0]["a1"], torch.Tensor) + assert r[0]["observation"] == 0 + assert r[1]["observation"] == 0.5 + @pytest.mark.parametrize( "env_name", [ diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py index 07d339761b0..bb849847f3a 100644 --- a/torchrl/envs/gym_like.py +++ b/torchrl/envs/gym_like.py @@ -292,7 +292,10 @@ def read_obs( return observations def _step(self, tensordict: TensorDictBase) -> TensorDictBase: - action = tensordict.get(self.action_key) + if len(self.action_keys) == 1: + action = tensordict.get(self.action_key) + else: + action = tensordict.select(*self.action_keys).to_dict() if self._convert_actions_to_numpy: action = self.read_action(action)