diff --git a/test/discrete/test_bdq.py b/test/discrete/test_bdq.py index 3fe3a7075..e7abe0b8d 100644 --- a/test/discrete/test_bdq.py +++ b/test/discrete/test_bdq.py @@ -52,7 +52,10 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: env = gym.make(args.task) env = ContinuousToDiscrete(env, args.action_per_branch) - args.state_shape = env.observation_space.shape or env.observation_space.n + if isinstance(env.observation_space, gym.spaces.Box): + args.state_shape = env.observation_space.shape + elif isinstance(env.observation_space, gym.spaces.Discrete): + args.state_shape = int(env.observation_space.n) assert isinstance(env.action_space, gym.spaces.MultiDiscrete) args.num_branches = env.action_space.shape[0] @@ -100,7 +103,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None: model=net, optim=optim, discount_factor=args.gamma, - action_space=env.action_space, + action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces? target_update_freq=args.target_update_freq, ) # collector