diff --git a/test/test_libs.py b/test/test_libs.py
index b3ba8d54c3d..8b74a0d5418 100644
--- a/test/test_libs.py
+++ b/test/test_libs.py
@@ -7,6 +7,8 @@
 import importlib.util
 import urllib.error
 
+from gym.core import ObsType
+
 _has_isaac = importlib.util.find_spec("isaacgym") is not None
 
 if _has_isaac:
@@ -23,7 +25,7 @@
 from contextlib import nullcontext
 from pathlib import Path
 from sys import platform
-from typing import Optional, Union
+from typing import Optional, Tuple, Union
 from unittest import mock
 
 import numpy as np
@@ -634,6 +636,76 @@ def test_torchrl_to_gym(self, backend, numpy):
         finally:
             set_gym_backend(gb).set()
 
+    @implement_for("gym", None, "0.26")
+    def test_gym_dict_action_space(self):
+        pytest.skip("tested for gym > 0.26 - no backward issue")
+
+    @implement_for("gym", "0.26", None)
+    def test_gym_dict_action_space(self):  # noqa: F811
+        import gym
+        from gym import Env
+
+        class CompositeActionEnv(Env):
+            def __init__(self):
+                self.action_space = gym.spaces.Dict(
+                    a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1)
+                )
+                self.observation_space = gym.spaces.Box(-1, 1)
+
+            def step(self, action):
+                return (0.5, 0.0, False, False, {})
+
+            def reset(
+                self,
+                *,
+                seed: Optional[int] = None,
+                options: Optional[dict] = None,
+            ) -> Tuple[ObsType, dict]:
+                return (0.0, {})
+
+        env = CompositeActionEnv()
+        torchrl_env = GymWrapper(env)
+        assert isinstance(torchrl_env.action_spec, Composite)
+        assert len(torchrl_env.action_keys) == 2
+        r = torchrl_env.rollout(10)
+        assert isinstance(r[0]["a0"], torch.Tensor)
+        assert isinstance(r[0]["a1"], torch.Tensor)
+        assert r[0]["observation"] == 0
+        assert r[1]["observation"] == 0.5
+
+    @implement_for("gymnasium")
+    def test_gym_dict_action_space(self):  # noqa: F811
+        import gymnasium as gym
+        from gymnasium import Env
+
+        class CompositeActionEnv(Env):
+            def __init__(self):
+                self.action_space = gym.spaces.Dict(
+                    a0=gym.spaces.Discrete(2), a1=gym.spaces.Box(-1, 1)
+                )
+                self.observation_space = gym.spaces.Box(-1, 1)
+
+            def step(self, action):
+                return (0.5, 0.0, False, False, {})
+
+            def reset(
+                self,
+                *,
+                seed: Optional[int] = None,
+                options: Optional[dict] = None,
+            ) -> Tuple[ObsType, dict]:
+                return (0.0, {})
+
+        env = CompositeActionEnv()
+        torchrl_env = GymWrapper(env)
+        assert isinstance(torchrl_env.action_spec, Composite)
+        assert len(torchrl_env.action_keys) == 2
+        r = torchrl_env.rollout(10)
+        assert isinstance(r[0]["a0"], torch.Tensor)
+        assert isinstance(r[0]["a1"], torch.Tensor)
+        assert r[0]["observation"] == 0
+        assert r[1]["observation"] == 0.5
+
     @pytest.mark.parametrize(
         "env_name",
         [
diff --git a/torchrl/envs/gym_like.py b/torchrl/envs/gym_like.py
index 07d339761b0..bb849847f3a 100644
--- a/torchrl/envs/gym_like.py
+++ b/torchrl/envs/gym_like.py
@@ -292,7 +292,10 @@ def read_obs(
         return observations
 
     def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
-        action = tensordict.get(self.action_key)
+        if len(self.action_keys) == 1:
+            action = tensordict.get(self.action_key)
+        else:
+            action = tensordict.select(*self.action_keys).to_dict()
         if self._convert_actions_to_numpy:
             action = self.read_action(action)