Feat Sebulba recurrent IQL #1148
base: develop
Conversation
I've looked through everything except the system file and it looks good, Sebulba utils especially! Just some relatively minor style changes
mava/utils/config.py
Outdated
# PPO-specific check
if "num_minibatches" in config.system:
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples}) "
        f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )
A thought on this: maybe we can split these up into multiple methods, e.g. check_num_updates, check_num_envs, etc. Then have a check_sebulba_config_ppo, check_anakin_config_ppo and a check_sebulba_config_iql which use the relevant methods?
I split it into base_sebulba_checks and ppo_sebulba_checks. Any more splits feel excessive 🤔
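For reference, a rough sketch of what that split could look like (the config keys and check bodies here are illustrative assumptions, not the actual PR code):

from omegaconf import DictConfig

def check_num_updates(config: DictConfig) -> None:
    # Evaluation must divide the total number of updates evenly.
    assert config.arch.num_updates % config.arch.num_evaluation == 0, (
        "num_updates must be divisible by num_evaluation."
    )

def check_num_envs(config: DictConfig) -> None:
    # Environments must split evenly across the actor devices.
    assert config.arch.num_envs % len(config.arch.actor_device_ids) == 0, (
        "num_envs must be divisible by the number of actor devices."
    )

def base_sebulba_checks(config: DictConfig) -> None:
    check_num_updates(config)
    check_num_envs(config)

def ppo_sebulba_checks(config: DictConfig, num_eval_samples: int) -> None:
    base_sebulba_checks(config)
    # PPO-specific check, as in the diff above.
    assert num_eval_samples % config.system.num_minibatches == 0, (
        f"Number of training samples per evaluator ({num_eval_samples}) "
        f"must be divisible by num_minibatches ({config.system.num_minibatches})."
    )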
# TODO: remove the PPO dependencies when we make Sebulba for other systems
This is a good point though, maybe there's something we can do about it 🤔
Maybe a protocol that has action, obs and reward? Not sure if there are any other common attributes.
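For illustration, something like this could work as a shared interface (the class name and the chex.Array annotation are my assumptions; only the attribute names come from the comment above):

from typing import Protocol

from chex import Array

class HasTransitionFields(Protocol):
    # Attributes common to the PPO and IQL transition types.
    obs: Array
    action: Array
    reward: Array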
mava/utils/sebulba.py
Outdated
def __init__(
    self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):
Can we please add a good docstring here 🙏
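Something along these lines as a starting point? The semantics described below are my assumption (modelled on a Reverb-style samples-per-insert rate limiter), so please adjust to match the actual behaviour:

def __init__(
    self, samples_per_insert: float, min_size_to_sample: int, min_diff: float, max_diff: float
):
    """Rate limiter that keeps sampling and inserting roughly in sync.

    Args:
        samples_per_insert: target number of times each inserted item is sampled.
        min_size_to_sample: minimum number of inserts before sampling may start.
        min_diff: lower bound on (inserts * samples_per_insert - samples); sampling
            waits while the difference is below this bound.
        max_diff: upper bound on the same difference; inserting waits while the
            difference is above it.
    """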
terminated = np.repeat(
    terminated[..., np.newaxis], repeats=self.num_agents, axis=-1
)  # (B,) --> (B, N)
Does this already happen for smax and lbf?
Great work here! Really minor changes required. Happy to merge this pending some benchmarks
next_timestep = env.step(cpu_action)

# Prepare the transition
terminal = (1 - timestep.discount[..., 0, jnp.newaxis]).astype(bool)
Are you sure we want to remove the agent dim here?
The dones flag is removed here and added back in the ScannedRNN. We could either modify the RNN to handle the dones flag with or without the agent dimension, or standardize it by keeping the agent dimension across all scripts. 🤔
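To make the two options concrete, a quick sketch with assumed shapes (B envs, N agents); this is illustrative only, not the PR's code:

import jax.numpy as jnp

batch, num_agents, hidden_dim = 4, 3, 8
done = jnp.array([True, False, False, True])        # (B,), no agent dim
hidden = jnp.ones((batch, num_agents, hidden_dim))
init_hidden = jnp.zeros_like(hidden)

# Option A: keep done as (B,) and broadcast it inside the RNN when resetting.
hidden_a = jnp.where(done[:, jnp.newaxis, jnp.newaxis], init_hidden, hidden)

# Option B: standardize on a per-agent flag of shape (B, N) across all scripts.
done_per_agent = jnp.repeat(done[:, jnp.newaxis], num_agents, axis=-1)
hidden_b = jnp.where(done_per_agent[..., jnp.newaxis], init_hidden, hidden)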
    target: Array,
) -> Tuple[Array, Metrics]:
    # axes switched here to scan over time
    hidden_state, obs_term_or_trunc = prep_inputs_to_scannedrnn(obs, term_or_trunc)
A general comment: I think this would be a lot easier to read if we used done to mean term_or_trunc, which I think is a reasonable thing. Would have to make the change in anakin also though :/
""" | ||
|
||
eps = jnp.maximum( | ||
config.system.eps_min, 1 - (t / config.system.eps_decay) * (1 - config.system.eps_min) |
Would be nice if we could set a different decay per actor, although I think that's out of scope for this PR. Maybe if you could make an issue to add some of the Ape-X DQN features, that would be great.
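Agreed that it's out of scope here, but for the issue: a sketch of the Ape-X style fixed per-actor epsilon (0.4 and 7 are the paper's defaults; how this would be wired into the config is left open):

import jax.numpy as jnp

# Each actor i gets a fixed epsilon: eps_i = base_eps ** (1 + i / (N - 1) * alpha),
# spanning base_eps for actor 0 down to much greedier values for the last actor
# (Ape-X DQN, Horgan et al. 2018).
num_actors = 8
base_eps, alpha = 0.4, 7.0
actor_ids = jnp.arange(num_actors)
actor_eps = base_eps ** (1 + actor_ids / (num_actors - 1) * alpha)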
What?
A recurrent IQL implementation using the Sebulba architecture.
Why?
Off-policy Sebulba base and non-JAX envs in Mava.
How?
Mixed the Sebulba structure from PPO with the learner code from Anakin IQL.