Skip to content

Commit

Permalink
Add type hints to action/obs spaces:
Browse files Browse the repository at this point in the history
  * Should `BranchingPolicy` support also `MultiDiscrete` action spaces?
  • Loading branch information
dantp-ai committed Mar 27, 2024
1 parent 2255f6e commit 8c5fa77
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions test/discrete/test_bdq.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
env = gym.make(args.task)
env = ContinuousToDiscrete(env, args.action_per_branch)

args.state_shape = env.observation_space.shape or env.observation_space.n
if isinstance(env.observation_space, gym.spaces.Box):
args.state_shape = env.observation_space.shape
elif isinstance(env.observation_space, gym.spaces.Discrete):
args.state_shape = int(env.observation_space.n)
assert isinstance(env.action_space, gym.spaces.MultiDiscrete)
args.num_branches = env.action_space.shape[0]

Expand Down Expand Up @@ -100,7 +103,7 @@ def test_bdq(args: argparse.Namespace = get_args()) -> None:
model=net,
optim=optim,
discount_factor=args.gamma,
action_space=env.action_space,
action_space=env.action_space, # type: ignore[arg-type] # TODO: should BranchingPolicy support also `MultiDiscrete` action spaces?
target_update_freq=args.target_update_freq,
)
# collector
Expand Down

0 comments on commit 8c5fa77

Please sign in to comment.