ModelParallelStrategy for Lightning Trainer

awaelchli committed May 17, 2024
1 parent cd8acc2 commit d806b64
Showing 9 changed files with 953 additions and 60 deletions.
48 changes: 28 additions & 20 deletions src/lightning/fabric/strategies/model_parallel.py
@@ -163,7 +163,13 @@ def _configure_launcher(self) -> None:
     def setup_environment(self) -> None:
         super().setup_environment()
         self._setup_distributed()
-        self._setup_device_mesh()
+        if self._data_parallel_size == "auto":
+            self._data_parallel_size = self.num_nodes
+        if self._tensor_parallel_size == "auto":
+            self._tensor_parallel_size = self.num_processes
+        self._device_mesh = _setup_device_mesh(
+            self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device
+        )
 
     @override
     def setup_module(self, module: TModel) -> TModel:
@@ -303,25 +309,6 @@ def _setup_distributed(self) -> None:
         assert self.cluster_environment is not None
         _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout)
 
-    def _setup_device_mesh(self) -> None:
-        from torch.distributed.device_mesh import init_device_mesh
-
-        if self._data_parallel_size == "auto":
-            self._data_parallel_size = self.num_nodes
-        if self._tensor_parallel_size == "auto":
-            self._tensor_parallel_size = self.num_processes
-        if self._data_parallel_size * self._tensor_parallel_size != self.world_size:
-            raise RuntimeError(
-                f"The sizes `data_parallel_size={self._data_parallel_size}` and"
-                f" `tensor_parallel_size={self._tensor_parallel_size}` multiplied should equal the world size"
-                f" ({self.world_size})."
-            )
-        self._device_mesh = init_device_mesh(
-            device_type=self.root_device.type,
-            mesh_shape=(self._data_parallel_size, self._tensor_parallel_size),
-            mesh_dim_names=("data_parallel", "tensor_parallel"),
-        )
-
     def _get_process_group_backend(self) -> str:
         return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device)
 
@@ -502,6 +489,27 @@ def _load_checkpoint(
     )
 
 
+def _setup_device_mesh(
+    data_parallel_size: int,
+    tensor_parallel_size: int,
+    world_size: int,
+    device: torch.device,
+) -> "DeviceMesh":
+    from torch.distributed.device_mesh import init_device_mesh
+
+    if data_parallel_size * tensor_parallel_size != world_size:
+        raise RuntimeError(
+            f"The sizes `data_parallel_size={data_parallel_size}` and"
+            f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size"
+            f" ({world_size})."
+        )
+    return init_device_mesh(
+        device_type=device.type,
+        mesh_shape=(data_parallel_size, tensor_parallel_size),
+        mesh_dim_names=("data_parallel", "tensor_parallel"),
+    )
+
+
 def _has_dtensor_modules(module: object) -> TypeGuard[Module]:
     from torch.distributed._tensor import DTensor
 
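The hunks above lift device-mesh creation out of the strategy method into a module-level `_setup_device_mesh` helper: the strategy first resolves the "auto" sizes (data-parallel = number of nodes, tensor-parallel = processes per node) and then passes them in explicitly. A minimal standalone sketch of the same logic, not part of this commit, assuming a job launched with `torchrun --nnodes=2 --nproc-per-node=4` (the concrete sizes are illustrative):

import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")  # "nccl" when every process owns a GPU

data_parallel_size, tensor_parallel_size = 2, 4  # "auto" resolves to num_nodes, num_processes
if data_parallel_size * tensor_parallel_size != dist.get_world_size():
    raise RuntimeError("data_parallel_size * tensor_parallel_size must equal the world size")

mesh = init_device_mesh(
    device_type="cpu",  # "cuda" on GPU nodes
    mesh_shape=(data_parallel_size, tensor_parallel_size),
    mesh_dim_names=("data_parallel", "tensor_parallel"),
)
# Each named dimension yields a sub-mesh (and process group) for that axis of parallelism.
print(mesh["data_parallel"].size(), mesh["tensor_parallel"].size())  # 2 4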
13 changes: 13 additions & 0 deletions src/lightning/pytorch/core/module.py
@@ -20,6 +20,7 @@
 from pathlib import Path
 from typing import (
     IO,
+    TYPE_CHECKING,
     Any,
     Callable,
     Dict,
@@ -76,6 +77,9 @@
     OptimizerLRScheduler,
 )
 
+if TYPE_CHECKING:
+    from torch.distributed.device_mesh import DeviceMesh
+
 _ONNX_AVAILABLE = RequirementCache("onnx")
 
 warning_cache = WarningCache()
@@ -142,6 +146,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._fabric: Optional["lf.Fabric"] = None
         self._fabric_optimizers: List[_FabricOptimizer] = []
 
+        # access to device mesh in `configure_model()` hook
+        self._device_mesh: Optional["DeviceMesh"] = None
+
     @overload
     def optimizers(
         self, use_pl_optimizer: Literal[True] = True
@@ -319,6 +326,12 @@ def loggers(self) -> Union[List[Logger], List[FabricLogger]]:
             return self._trainer.loggers
         return []
 
+    @property
+    def device_mesh(self) -> Optional["DeviceMesh"]:
+        """Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the
+        :meth:`configure_model` hook to parallelize the LightningModule."""
+        return self._device_mesh
+
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         trainer = self._trainer
         if trainer:
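The `device_mesh` property added above is populated by the strategy before `configure_model()` runs, so user code can slice the mesh by dimension name when applying a parallelization plan. A hypothetical sketch of a LightningModule using it; the module, layer names, and plan are illustrative and not part of this commit:

import torch.nn as nn
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

import lightning as L


class FeedForwardModule(L.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.w1 = nn.Linear(1024, 4096)
        self.w2 = nn.Linear(4096, 1024)

    def configure_model(self) -> None:
        if self.device_mesh is None:
            # Strategies that don't build a mesh leave the property at None.
            return
        tp_mesh = self.device_mesh["tensor_parallel"]
        # Shard w1 column-wise and w2 row-wise across the tensor-parallel dimension.
        parallelize_module(self, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()})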
2 changes: 2 additions & 0 deletions src/lightning/pytorch/strategies/__init__.py
@@ -18,6 +18,7 @@
 from lightning.pytorch.strategies.ddp import DDPStrategy
 from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy
 from lightning.pytorch.strategies.fsdp import FSDPStrategy
+from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy
 from lightning.pytorch.strategies.parallel import ParallelStrategy
 from lightning.pytorch.strategies.single_device import SingleDeviceStrategy
 from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy  # noqa: F401
@@ -31,6 +32,7 @@
     "DDPStrategy",
     "DeepSpeedStrategy",
     "FSDPStrategy",
+    "ModelParallelStrategy",
     "ParallelStrategy",
     "SingleDeviceStrategy",
     "Strategy",
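With the export in place, the strategy can be selected directly on the Trainer. A hedged usage sketch follows; the constructor arguments are assumed to mirror the Fabric strategy shown above (`data_parallel_size` and `tensor_parallel_size`, both defaulting to "auto") and are not visible in this excerpt of the commit:

import lightning as L
from lightning.pytorch.strategies import ModelParallelStrategy

trainer = L.Trainer(
    accelerator="gpu",
    devices=4,
    num_nodes=2,
    strategy=ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=4),
)
# Leaving both sizes at "auto" resolves them to num_nodes and the per-node process count,
# as in the setup_environment() hunk above.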