-
Notifications
You must be signed in to change notification settings - Fork 97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Feat: sebulba rec ippo #1142
base: develop
Are you sure you want to change the base?
Feat: sebulba rec ippo #1142
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall the system looks correct and reasonable. Well done Simon! I just kept minor requests :)
@@ -0,0 +1,910 @@ | |||
# Copyright 2022 InstaDeep Ltd. All rights reserved. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you can update the typings in Pipeline mava/utils/sebulba.py
to be Union[PPOTransition, RNNPPOTransition]
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This causes errors in the pre-commit. For now I changed both sebulba systems to use the MavaTransition
type-var but this is probably a temporary solution.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you please make an issue for this. I think the best solution is to make a protocol with the all the common things in a transition (actions, obs, done, reward). The challenge is that named tuples don't seem to work with protocols so we'd likely need to switch to a flax/chex dataclass
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great! Pretty much good to go except a few minor style changes to bring it up to date with the latest PPO changes that went in at the end of last year
…nto feat/sebulba_rec_ippo
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Final minor comments
mava/systems/ppo/sebulba/rec_ippo.py
Outdated
(config.arch.num_envs, num_agents), config.network.hidden_state_dim | ||
) | ||
hstates = HiddenStates(init_policy_hstate, init_critic_hstate) | ||
hstates_tpu = tree.map(move_to_device, hstates) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device_put
already tree maps
hstates_tpu = tree.map(move_to_device, hstates) | |
hstates_tpu = move_to_device(hstates) |
mava/systems/ppo/sebulba/rec_ippo.py
Outdated
obs_tpu = tree.map(move_to_device, timestep.observation) | ||
last_dones = tree.map(move_to_device, dones) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
obs_tpu = tree.map(move_to_device, timestep.observation) | |
last_dones = tree.map(move_to_device, dones) | |
obs_tpu = move_to_device(timestep.observation) | |
last_dones = move_to_device(dones) |
Sebulba implementation of recurrent IPPO.