Pytorch Fully Sharded Data Parallel (FSDP) Integration #147
base: master
Conversation
Please rebase to the current master and push. It's impossible to review with all the changes from past commits.
Force-pushed 2792726 to 3ab379a.
Thanks @parthraut! I'm requesting some changes. Please let me know if anything is unclear.
zeus/utils/framework.py
Outdated
if jax_is_available():
    # JAX cross-device all-reduce not yet implemented
    return sum(object) if operation == "sum" else max(object)

raise RuntimeError("No framework is available.")


def is_distributed() -> bool:
    """Check if the current execution is distributed across multiple devices."""
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        return torch.distributed.is_available() and torch.distributed.is_initialized()
    if jax_is_available():
        return False  # JAX not yet implemented
    return False
- If this was going to be left unimplemented, it should have raised a `NotImplementedError` instead of silently doing the wrong thing.
- This PR will be merged after JAX counterparts are implemented. No need to have a full JAX training script; I'm fine with it being tested manually with a quick script that imports and uses `all_reduce` and `is_distributed` with JAX.
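A quick manual check along those lines might look like the sketch below, assuming JAX is installed (and torch is not, so the JAX code path in `zeus/utils/framework.py` is exercised); the argument names are guessed from the diff excerpt above.

```python
# Hypothetical one-off check, run as a single JAX process.
from zeus.utils.framework import all_reduce, is_distributed

# With a single process, is_distributed() should report False.
print("is_distributed:", is_distributed())

# all_reduce should still reduce locally when nothing is distributed.
print("sum:", all_reduce([1.0, 2.0, 3.0], operation="sum"))
print("max:", all_reduce([1.0, 2.0, 3.0], operation="max"))
```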
I will include JAX impl in this PR
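One possible shape for the JAX counterparts is sketched below. It assumes the intended semantics is an element-wise reduction across processes (mirroring `torch.distributed.all_reduce`) and uses a gather-then-reduce via `jax.experimental.multihost_utils`; this is only an illustration, not necessarily what the PR ends up with.

```python
import jax
import jax.numpy as jnp
from jax.experimental import multihost_utils


def is_distributed() -> bool:
    """True when more than one JAX process participates in the computation."""
    return jax.process_count() > 1


def all_reduce(object: list[float], operation: str) -> list[float]:
    """Element-wise reduce a list of floats across all JAX processes."""
    if jax.process_count() == 1:
        # A single process has nothing to reduce against.
        return object
    # Gather every process's list (new leading axis = process count), then reduce it away.
    gathered = multihost_utils.process_allgather(jnp.asarray(object))
    reduced = gathered.sum(axis=0) if operation == "sum" else gathered.max(axis=0)
    return reduced.tolist()
```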
Thanks!
Line 265 is now broken, because the previous implementation assumed that `len(zeus_monitor.gpu_indices)` gives the current world size. Let's just switch the default `optimum_selector` to `MaxSlowdownConstraint(factor=1.1)`.
Or we could use `torch.distributed.get_world_size` (and something analogous for JAX) by defining a generic framework function `zeus.framework.get_world_size`. What do you think?
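For reference, such a helper could sit in `zeus/utils/framework.py` next to the existing lazy-import machinery (`torch_is_available`, `jax_is_available`, `MODULE_CACHE`) shown in the diff; this is only a sketch of the idea.

```python
def get_world_size() -> int:
    """Number of participating processes, or 1 when training is not distributed."""
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_world_size()
        return 1
    if jax_is_available():
        jax = MODULE_CACHE["jax"]
        return jax.process_count()
    return 1
```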
Nah, I wouldn't bother for this one. Now I think `MaxSlowdownConstraint` is a better default; the original one is from 2022.
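Roughly, the agreed-on change amounts to the default below; the import path is assumed from the Zeus package layout, and the `__init__` line is paraphrased rather than copied from the source.

```python
# Import path assumed; the constructor call comes from the suggestion above.
from zeus.optimizer.power_limit import MaxSlowdownConstraint

# Paraphrase of the new default inside GlobalPowerLimitOptimizer.__init__:
#     self.optimum_selector = optimum_selector or MaxSlowdownConstraint(factor=1.1)
# i.e., pick the lowest power limit whose measured slowdown stays within 10%
# of the fastest setting, unless the caller passes their own selector.
default_selector = MaxSlowdownConstraint(factor=1.1)
```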
Co-authored-by: Jae-Won Chung <[email protected]>
zeus/utils/framework.py
Outdated
return sum(object) if operation == "sum" else max(object)
# Check if not distributed
jax = MODULE_CACHE["jax"]
if jax.process_count() == 1:
https://jax.readthedocs.io/en/latest/multi_process.html#running-multi-process-computations

Should be `jax.device_count()`?
Yes, it should be. Fixed that, thanks.
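For context on the fix: the two counters answer different questions. The values below are for a hypothetical single-process run driving four local GPUs.

```python
import jax

# One launched JAX process (no multi-host setup) ...
print(jax.process_count())  # -> 1
# ... which can still see several accelerators, and that is what matters
# when deciding whether a cross-device reduction is needed.
print(jax.device_count())   # -> 4 in this hypothetical setup
```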
Integrates PyTorch Fully Sharded Data Parallel (FSDP) into Zeus. Now, `GlobalPowerLimitOptimizer` performs distributed operations (all-reduce) to ensure it makes the correct power limit decision.

- Adds `zeus.framework.all_reduce`, which currently invokes `torch.distributed.all_reduce` if torch is the framework.
- Adds a `train_fsdp.py` example and relevant documentation (the overall pattern is sketched below).
- `GlobalPowerLimitOptimizer.__init__` now calls `zeus.framework.is_distributed` and warns the user if multiple GPUs are being monitored in a distributed context.
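The sketch below condenses that pattern. Callback and constructor names follow the Zeus documentation and may differ slightly from the actual `train_fsdp.py`; the script assumes a `torchrun` launch so that `LOCAL_RANK` is set.

```python
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

dist.init_process_group("nccl")
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)

model = FSDP(torch.nn.Linear(1024, 1024).cuda())
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Each rank monitors only its own GPU; the power limit optimizer all-reduces
# measurements across ranks so every rank arrives at the same decision.
plo = GlobalPowerLimitOptimizer(ZeusMonitor(gpu_indices=[local_rank]))

for epoch in range(3):
    plo.on_epoch_begin()
    for _ in range(100):
        plo.on_step_begin()
        x = torch.randn(64, 1024, device="cuda")
        loss = model(x).square().mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        plo.on_step_end()
    plo.on_epoch_end()

dist.destroy_process_group()
```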