PyTorch Fully Sharded Data Parallel (FSDP) Integration #147
base: master
```diff
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import types
-from typing import Literal, List
+from typing import Literal
 from functools import lru_cache
 
 from zeus.utils.logging import get_logger
@@ -105,41 +105,43 @@ def sync_execution(
 
 
 def all_reduce(
-    object: List[int] | List[float], operation: Literal["sum", "max"]
-) -> int | float:
+    object: list[int] | list[float], operation: Literal["sum", "max"]
+) -> list[int] | list[float]:
     """Reduce objects from all replicas through the specified operation.
 
     If running in a distributed setting, the objects are reduced across all replicas.
-    If running in a non-distributed setting, the operation is just done on the single object.
+    If the current execution is not distributed, the object is returned as is.
     """
     if torch_is_available(ensure_cuda=False):
         torch = MODULE_CACHE["torch"]
 
-        # wrap object in a tensor if it is not already
-        if not isinstance(object, torch.Tensor):
-            object = torch.Tensor(object)
+        # if torch.distributed is not available or not initialized, return the object as is
+        if (
+            not torch.distributed.is_available()
+            or not torch.distributed.is_initialized()
+        ):
+            return object
+
+        # wrap object in a tensor
+        tensor = torch.Tensor(object).cuda()
 
         # determine operation
         if operation == "sum":
-            torch_func = torch.sum
             torch_op = torch.distributed.ReduceOp.SUM
         elif operation == "max":
-            torch_func = torch.max
             torch_op = torch.distributed.ReduceOp.MAX
         else:
             raise ValueError(f"all_reduce unsupported operation: {operation}")
 
-        # compute local operation
-        result = torch_func(object)
-
-        # all-reduce only if torch.distributed is available and initialized
-        if torch.distributed.is_available() and torch.distributed.is_initialized():
-            torch.distributed.all_reduce(result.cuda(), op=torch_op)
-        return result.item()
+        torch.distributed.all_reduce(tensor, op=torch_op)
+        return tensor.cpu().tolist()
 
     if jax_is_available():
-        # JAX cross-device all-reduce not yet implemented
-        return sum(object) if operation == "sum" else max(object)
+        # Check if not distributed
+        jax = MODULE_CACHE["jax"]
+        if jax.process_count() == 1:
+            return object
+
+        raise NotImplementedError("JAX distributed all-reduce not yet implemented")
 
     raise RuntimeError("No framework is available.")
 
@@ -150,5 +152,6 @@ def is_distributed() -> bool:
         torch = MODULE_CACHE["torch"]
         return torch.distributed.is_available() and torch.distributed.is_initialized()
     if jax_is_available():
-        return False  # JAX not yet implemented
+        jax = MODULE_CACHE["jax"]
+        return jax.process_count() > 1
     raise RuntimeError("No framework is available.")
```

Review thread attached to the `if jax.process_count() == 1:` check referenced the JAX multi-process documentation (https://jax.readthedocs.io/en/latest/multi_process.html#running-multi-process-computations), with the reply: "yes, it should be. Fixed that, thanks."
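One behavioral note on the `all_reduce` change above: the old version reduced the list to a single scalar locally and then all-reduced that scalar, while the new version all-reduces the list element-wise across replicas and returns a list (and returns the input unchanged when not distributed). A standalone sketch of the difference, with illustrative values:

```python
import torch

values = [1.0, 2.0, 3.0]

# Old behavior (sketch): local reduction to a single scalar, then an all-reduce of that scalar.
old_style = torch.sum(torch.Tensor(values)).item()  # 6.0, a single float

# New behavior (sketch): the whole list is all-reduced element-wise across replicas.
# With two replicas contributing [1, 2, 3] and [4, 5, 6], operation="sum" yields [5.0, 7.0, 9.0];
# in a non-distributed run, the input list comes back unchanged.
```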
Review discussion:

Line 265 is now broken, because the previous implementation assumed that `len(zeus_monitor.gpu_indices)` gives the current world size. Let's just switch the default `optimum_selector` to `MaxSlowdownConstraint(factor=1.1)`.
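For reference, a minimal sketch of what the suggested default would look like at a call site. The module paths, the monitor setup, and the interpretation of `factor=1.1` are my assumptions, not taken from this thread:

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer, MaxSlowdownConstraint

# Hypothetical setup: monitor all GPUs and let the optimizer choose the power limit.
monitor = ZeusMonitor()
optimizer = GlobalPowerLimitOptimizer(
    monitor,
    # MaxSlowdownConstraint(factor=1.1) selects the most energy-efficient power limit
    # whose measured slowdown stays within 1.1x of the maximum-power-limit setting;
    # this is my reading of the selector, not stated in this thread.
    optimum_selector=MaxSlowdownConstraint(factor=1.1),
)
```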
Or we could use `torch.distributed.get_world_size` (and something analogous for JAX) by defining a generic framework function `zeus.framework.get_world_size`. What do you think?
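As an illustration of that proposal (not part of this PR), such a helper could mirror the structure of `all_reduce` and `is_distributed` in the diff above, assuming it lives in the same module with access to `MODULE_CACHE`, `torch_is_available`, and `jax_is_available`:

```python
def get_world_size() -> int:
    """Return the number of replicas, or 1 when not running distributed. (Sketch only.)"""
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_world_size()
        return 1
    if jax_is_available():
        jax = MODULE_CACHE["jax"]
        # jax.process_count() is 1 for single-process runs, matching is_distributed() above.
        return jax.process_count()
    raise RuntimeError("No framework is available.")
```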
Nah, I wouldn't bother for this one. Now I think `MaxSlowdownConstraint` is a better default; the original one is from 2022.