PyTorch Fully Sharded Data Parallel (FSDP) Integration #147

Open · wants to merge 21 commits into base: master

Changes from 7 commits
10 changes: 2 additions & 8 deletions examples/power_limit_optimizer/README.md
@@ -26,14 +26,14 @@ You just need to download and extract the ImageNet data and mount it to the Dock

## Multi-GPU Distributed Training (PyTorch DDP and FSDP)

When using `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` in a multi-GPU Distributed context, launch one instance of `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` per local rank (per GPU on each node), and pass in the local rank to `ZeusMonitor` as shown below:
When using `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` in a multi-GPU Distributed context, construct one instance of `ZeusMonitor` and/or `GlobalPowerLimitOptimizer` per local rank (per GPU on each node), and pass in the local rank to `ZeusMonitor` as shown below:

```python
monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
plo = GlobalPowerLimitOptimizer(monitor)
```

Ensure that only one GPU is monitored per `ZeusMonitor`. Internally, `GlobalPowerLimitOptimizer` performs an [All-Reduce](https://pytorch.org/docs/stable/distributed.html) to synchronize before making a power limit decision.
Ensure that only one GPU is monitored per `ZeusMonitor`. Internally, `GlobalPowerLimitOptimizer` performs an [AllReduce](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/usage/collectives.html) to aggregate time and energy measurements across all GPUs before making a power limit decision.
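
Under `torchrun`, the local rank of each process is exposed through the `LOCAL_RANK` environment variable, so the per-process setup above can be written as the following sketch (import paths inferred from this repository's layout):

```python
import os

from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer

local_rank = int(os.environ["LOCAL_RANK"])  # torchrun sets LOCAL_RANK for every process it spawns
monitor = ZeusMonitor(gpu_indices=[local_rank])  # each process monitors only its own GPU
plo = GlobalPowerLimitOptimizer(monitor)
```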

## Example command

@@ -59,12 +59,6 @@ torchrun \
--nnodes 1 \
--nproc_per_node=gpu `# Number of processes per node, should be equal to the number of GPUs.` \
train_fsdp.py \
--batch-size 64 `# Batch size for training.` \
--test-batch-size 1000 `# Batch size for testing.` \
--epochs 10 `# Number of epochs to train.` \
--lr 1.0 `# Learning rate.` \
--gamma 0.7 `# Learning rate step gamma.` \
--save-model `# Save the trained model.` \
[DATA_DIR]
```

28 changes: 13 additions & 15 deletions examples/power_limit_optimizer/train_dp.py
@@ -197,21 +197,19 @@ def main():
sampler=val_sampler,
)

# The rank 0 process will monitor and optimize the power limit of all GPUs.
if args.gpu == 0:
callback_set: list[Callback] = [
GlobalPowerLimitOptimizer(
monitor=ZeusMonitor(gpu_indices=args.gpu), # Since there is only one GPU per process, monitor it (give it local rank).
optimum_selector=MaxSlowdownConstraint(
factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
),
warmup_steps=10,
profile_steps=40,
pl_step=25,
)
]
else:
callback_set = []
# All processes will monitor and optimize the power limit of all GPUs (one process per GPU).
callback_set: list[Callback] = [
GlobalPowerLimitOptimizer(
monitor=ZeusMonitor(gpu_indices=args.gpu), # Since there is only one GPU per process, monitor it (give it local rank).
optimum_selector=MaxSlowdownConstraint(
factor=get_env("ZEUS_MAX_SLOWDOWN", float, 1.1),
),
warmup_steps=10,
profile_steps=40,
pl_step=25,
)
]

callbacks = CallbackSet(callback_set)

for epoch in range(args.epochs):
1 change: 0 additions & 1 deletion zeus/monitor/energy.py
@@ -187,7 +187,6 @@ def __init__(
except ZeusCPUInitError:
self.cpus = EmptyCPUs()
except ZeusCPUNoPermissionError as err:
self.cpus = EmptyCPUs()
if cpu_indices:
raise RuntimeError(
"Root privilege is required to read RAPL metrics. See "
33 changes: 17 additions & 16 deletions zeus/optimizer/power_limit.py
Member:
Line 265 is now broken, because the previous implementation assumed that `len(zeus_monitor.gpu_indices)` gives the current world size. Let's just switch the default `optimum_selector` to `MaxSlowdownConstraint(factor=1.1)`.

Collaborator Author (parthraut):
Or we could use `torch.distributed.get_world_size` (and something analogous for JAX) by defining a generic framework function `zeus.framework.get_world_size`. What do you think?
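
A rough sketch of such a helper, reusing the utilities already defined in `zeus/utils/framework.py` (hypothetical; this helper is not part of the PR):

```python
def get_world_size() -> int:
    """Return the number of participating replicas, or 1 if not distributed."""
    if torch_is_available(ensure_cuda=False):
        torch = MODULE_CACHE["torch"]
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            return torch.distributed.get_world_size()
        return 1
    if jax_is_available():
        jax = MODULE_CACHE["jax"]
        return jax.process_count()
    raise RuntimeError("No framework is available.")
```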

Member:
Nah, I wouldn't bother for this one. Now I think `MaxSlowdownConstraint` is a better default; the original one is from 2022.
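
Concretely, the suggested default would make a plain construction behave like the explicit form already used in `train_dp.py`. A sketch, assuming `MaxSlowdownConstraint` is importable from `zeus.optimizer.power_limit`:

```python
from zeus.monitor import ZeusMonitor
from zeus.optimizer.power_limit import GlobalPowerLimitOptimizer, MaxSlowdownConstraint

monitor = ZeusMonitor(gpu_indices=[0])
# Passing the selector explicitly here; the proposal is to make this the default
# whenever optimum_selector is not given.
plo = GlobalPowerLimitOptimizer(monitor, optimum_selector=MaxSlowdownConstraint(factor=1.1))
```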

@@ -203,23 +203,22 @@ class GlobalPowerLimitOptimizer(Callback):

This optimizer uses the JIT profiling log to determine the optimal power limit.

Non-distributed training (Single GPU or Multi-GPU on a single node):
Launch one instance of `ZeusMonitor` and `GlobalPowerLimitOptimizer`, and have `ZeusMonitor` track all desired GPUs.
For example, to track all GPUs on a single node:
```python
monitor = ZeusMonitor(gpu_indices=None) # monitor all GPUs
plo = GlobalPowerLimitOptimizer(monitor)
```
## Usage with distributed data parallelism

The global power limit optimizer expects one process to control each GPU used for training.
For instance, `torchrun` will automatically spawn one process for each GPU on the node.
Correspondingly, the [`ZeusMonitor`][zeus.monitor.energy.ZeusMonitor] instance passed in
should be monitoring **one GPU**: the one being managed by the current process. The index of
this GPU would typically match the local rank of the process. In the case of PyTorch, users would have
called `torch.cuda.set_device` early on, so `torch.cuda.current_device` will give you the GPU index.
`GlobalPowerLimitOptimizer` will internally do an AllReduce across all GPUs to aggregate
time and energy measurements, and then select the globally optimal power limit.


Distributed training (Multi-GPU on multiple nodes):
`ZeusMonitor` and `GlobalPowerLimitOptimizer` make the assumption that each GPU is monitored by one and only one instance of `ZeusMonitor` to ensure correct functionality.
Therefore, it is recommended to launch one instance of `ZeusMonitor` and `GlobalPowerLimitOptimizer`
per device (per GPU on each node), and pass in the local rank to `ZeusMonitor` as shown below:
```python
monitor = ZeusMonitor(gpu_indices=[local_rank]) # pass in local rank to gpu_indices.
plo = GlobalPowerLimitOptimizer(monitor)
```
Internally, `GlobalPowerLimitOptimizer` performs an all-reduce over all devices to compute the optimal power limit.
"""

def __init__(
@@ -420,10 +419,12 @@ def on_step_begin(self) -> None:
self.measurements.append(
PowerLimitMeasurement(
power_limit=self.state.current_power_limit // 1000,
energy=all_reduce(
list(measurement.gpu_energy.values()), operation="sum"
energy=sum(
all_reduce(
list(measurement.gpu_energy.values()), operation="sum"
)
),
time=all_reduce([measurement.time], operation="max"),
time=max(all_reduce([measurement.time], operation="max")),
)
)
# If we're done profiling all power limits, compute the optimal
45 changes: 24 additions & 21 deletions zeus/utils/framework.py
@@ -3,7 +3,7 @@
from __future__ import annotations

import types
from typing import Literal, List
from typing import Literal
from functools import lru_cache

from zeus.utils.logging import get_logger
@@ -105,41 +105,43 @@ def sync_execution(


def all_reduce(
object: List[int] | List[float], operation: Literal["sum", "max"]
) -> int | float:
object: list[int] | list[float], operation: Literal["sum", "max"]
) -> list[int] | list[float]:
"""Reduce objects from all replicas through the specified operation.

If running in a distributed setting, the objects are reduced across all replicas.
If running in a non-distributed setting, the operation is just done on the single object.
If the current execution is not distributed, the object is returned as is.
"""
if torch_is_available(ensure_cuda=False):
torch = MODULE_CACHE["torch"]

# wrap object in a tensor if it is not already
if not isinstance(object, torch.Tensor):
object = torch.Tensor(object)
# if torch.distributed is not available or not initialized, return the object as is
if (
not torch.distributed.is_available()
or not torch.distributed.is_initialized()
):
return object

# wrap object in a tensor
tensor = torch.Tensor(object).cuda()

# determine operation
if operation == "sum":
torch_func = torch.sum
torch_op = torch.distributed.ReduceOp.SUM
elif operation == "max":
torch_func = torch.max
torch_op = torch.distributed.ReduceOp.MAX
else:
raise ValueError(f"all_reduce unsupported operation: {operation}")

# compute local operation
result = torch_func(object)

# all-reduce only if torch.distributed is available and initialized
if torch.distributed.is_available() and torch.distributed.is_initialized():
torch.distributed.all_reduce(result.cuda(), op=torch_op)
return result.item()
torch.distributed.all_reduce(tensor, op=torch_op)
return tensor.cpu().tolist()

if jax_is_available():
# JAX cross-device all-reduce not yet implemented
return sum(object) if operation == "sum" else max(object)
# Check if not distributed
jax = MODULE_CACHE["jax"]
if jax.process_count() == 1:
Collaborator Author (parthraut), Dec 21, 2024:
yes, it should be. Fixed that, thanks

return object

raise NotImplementedError("JAX distributed all-reduce not yet implemented")

raise RuntimeError("No framework is available.")

@@ -150,5 +152,6 @@ def is_distributed() -> bool:
torch = MODULE_CACHE["torch"]
return torch.distributed.is_available() and torch.distributed.is_initialized()
if jax_is_available():
return False # JAX not yet implemented
jax = MODULE_CACHE["jax"]
return jax.process_count() > 1
raise RuntimeError("No framework is available.")