
[BUG] KeyError: 'action_key requested but more than one key present in the environment' when using a spaces.Dict action_space #2716

Closed
JosephDenman opened this issue Jan 24, 2025 · 2 comments · Fixed by #2718
Labels: bug (Something isn't working)

Comments


JosephDenman commented Jan 24, 2025

Describe the bug

I have a gymnasium environment of the form:

import gymnasium as gym
from gymnasium import spaces

class MyEnvironment(gym.Env):
    def __init__(self, ...):
        ...

        self.action_space = spaces.Dict(a=spaces.Discrete(2), b=spaces.Box(0, 1, shape=[1]))
        self.observation_space = spaces.Box(0, 1, shape=[1])

        ...

I register the environment and then run:

import gymnasium as gym
from torchrl.envs import GymWrapper

env = GymWrapper(gym.make("MyEnvironment", ...))
env.rollout(40)

But then I see this error:

KeyError: 'action_key requested but more than one key present in the environment'

When I use a Tuple action space instead:

self.action_space = spaces.Tuple([spaces.Box(-1, 1, shape=[1]),
                                  spaces.Box(0, 1, shape=[1])])

I see another error:

  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 2636, in rollout
    tensordicts = self._rollout_stop_early(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 2730, in _rollout_stop_early
    tensordict = self.step(tensordict)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 1506, in step
    next_tensordict = self._step(tensordict)
                      ^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/gym_like.py", line 298, in _step
    action = self.read_action(action)
             ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/libs/gym.py", line 917, in read_action
    action = super().read_action(action)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/gym_like.py", line 188, in read_action
    return self.action_spec.to_numpy(action, safe=False)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/data/tensor_specs.py", line 1287, in to_numpy
    return val.detach().cpu().numpy()
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Internal error: NestedTensorImpl doesn't support sizes. Please file an issue.

The rollout succeeds when the spaces.Box instances have exactly the same arguments, but the documentation doesn't make clear that this is required; it only says that they need to have the same dimension and dtype.
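The observation that identical Box arguments work while differing ones fail is consistent with a "stack only if identical" pattern. The sketch below is a hypothetical, simplified model of that pattern, not TorchRL's code: `BoxSpec`, `DenseStack`, `HeterogeneousStack`, and `stack_specs` are illustrative names. Identical sub-spaces collapse into one spec with an extra leading dimension; differing ones can only be kept side by side, which is one plausible reading of the NestedTensorImpl error in the traceback above.

```python
# Hypothetical, simplified model; this is NOT TorchRL's implementation.
# Assumption: the wrapper tries to merge a Tuple's sub-spaces into one stacked
# spec, which only yields a dense layout when every sub-space is identical.
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass(frozen=True)
class BoxSpec:
    low: float
    high: float
    shape: Tuple[int, ...]


@dataclass(frozen=True)
class DenseStack:
    """One spec repeated along a new leading dimension (a regular array)."""
    spec: BoxSpec
    count: int

    @property
    def shape(self) -> Tuple[int, ...]:
        return (self.count,) + self.spec.shape


@dataclass(frozen=True)
class HeterogeneousStack:
    """Sub-specs kept side by side; there is no single dense array layout."""
    specs: Tuple[BoxSpec, ...]


def stack_specs(specs: List[BoxSpec]) -> Union[DenseStack, HeterogeneousStack]:
    first = specs[0]
    # Equal dataclasses compare field by field: low, high and shape must all match.
    if all(s == first for s in specs):
        return DenseStack(first, len(specs))
    return HeterogeneousStack(tuple(specs))


# Same arguments: merges cleanly into one spec of shape (2, 1).
same = stack_specs([BoxSpec(0, 1, (1,)), BoxSpec(0, 1, (1,))])
# Same shape and dtype but different bounds, as in the Tuple example above.
mixed = stack_specs([BoxSpec(-1, 1, (1,)), BoxSpec(0, 1, (1,))])
print(same.shape)            # (2, 1)
print(type(mixed).__name__)  # HeterogeneousStack
```

Note that in this model matching shapes are not enough: the bounds are part of the spec, so `Box(-1, 1)` and `Box(0, 1)` cannot share one dense spec.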

Expected behavior

The documentation says to be cautious when using spaces, but it also says that TorchRL supports spaces.Dict and spaces.Tuple. I would expect a rollout with either of these action spaces to succeed. Is there an idiomatic way to represent heterogeneous action spaces?

System info

import torchrl, numpy, sys
print(torchrl.__version__, numpy.__version__, sys.version, sys.platform)

0.6.0 1.26.4 3.11.2 (v3.11.2:878ead1ac1, Feb 7 2023, 10:02:41) [Clang 13.0.0 (clang-1300.0.29.30)] darwin

JosephDenman added the bug (Something isn't working) label on Jan 24, 2025
vmoens (Contributor) commented Jan 25, 2025

We don't support tuple spaces atm. I think we could get around that, but we haven't worked on it yet.

Can you print the full stack trace of the KeyError you get when using a dict?

JosephDenman (Author) commented:

For the action_key error, the full stack trace is:

Traceback (most recent call last):
  File "/project/run.py", line 33, in <module>
    run()
  File "/project/run.py", line 30, in run
    env.rollout(10)
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 2635, in rollout
    tensordicts = self._rollout_stop_early(
                  ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 2729, in _rollout_stop_early
    tensordict = self.step(tensordict)
                 ^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 1505, in step
    next_tensordict = self._step(tensordict)
                      ^^^^^^^^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/gym_like.py", line 296, in _step
    action = tensordict.get(self.action_key)
                            ^^^^^^^^^^^^^^^
  File "/project/venv/lib/python3.11/site-packages/torchrl/envs/common.py", line 693, in action_key
    raise KeyError(
KeyError: 'action_key requested but more than one key present in the environment'

Here is run.py, the script that runs the environment:


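import gymnasium as gym
from gymnasium.envs.registration import register
from torchrl.envs import GymWrapper

register(
    id='MyEnvironment',
    entry_point=...,
    disable_env_checker=False
)

def run():
    env = GymWrapper(gym.make("MyEnvironment", ...))
    env.rollout(40)

run()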
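The traceback shows the failure at `tensordict.get(self.action_key)`: the `action_key` property in `torchrl/envs/common.py` raises as soon as the environment has more than one action key, which, judging by the error, is what the spaces.Dict with entries `a` and `b` produces. The sketch below is a simplified, hypothetical model of that single-key accessor pattern, not TorchRL's code; `leaf_keys`, `ToyEnv`, and the plain-dict specs are illustrative stand-ins, and the key layout TorchRL actually builds from a Dict space may differ.

```python
# Simplified, hypothetical model of the failure above; this is NOT TorchRL's
# implementation, only an illustration of the single-key accessor pattern.
from typing import Any, Dict, List, Tuple

NestedKey = Tuple[str, ...]


def leaf_keys(spec: Dict[str, Any], prefix: NestedKey = ()) -> List[NestedKey]:
    """Collect every leaf key of a nested dict spec, depth-first."""
    keys: List[NestedKey] = []
    for name, value in spec.items():
        key = prefix + (name,)
        if isinstance(value, dict):
            keys.extend(leaf_keys(value, key))
        else:
            keys.append(key)
    return keys


class ToyEnv:
    """Stand-in for an environment whose action spec may have several leaves."""

    def __init__(self, action_spec: Dict[str, Any]) -> None:
        self.action_spec = action_spec

    @property
    def action_keys(self) -> List[NestedKey]:
        return leaf_keys(self.action_spec)

    @property
    def action_key(self) -> NestedKey:
        # A single-key accessor has no way to choose between several leaves.
        keys = self.action_keys
        if len(keys) > 1:
            raise KeyError(
                "action_key requested but more than one key present in the environment"
            )
        return keys[0]


# One Box action: a single leaf, so action_key works.
single = ToyEnv({"action": "Box(0, 1, shape=[1])"})
print(single.action_key)  # ('action',)

# A Dict action space with sub-spaces "a" and "b": two leaves, so it raises.
# The key names here are an assumption; TorchRL's actual keys may differ.
multi = ToyEnv({"a": "Discrete(2)", "b": "Box(0, 1, shape=[1])"})
print(multi.action_keys)  # [('a',), ('b',)]
try:
    _ = multi.action_key
except KeyError as err:
    print(err)  # same quoted message as the KeyError in the report
```

Code that must handle several actions would iterate over all the keys (as `action_keys` does here) rather than ask for a single `action_key`.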

vmoens linked a pull request on Jan 26, 2025 that will close this issue
vmoens closed this as completed on Jan 28, 2025