From d806b646a0c7321002b5d520a8a300614f03cad8 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 15:48:19 +0200 Subject: [PATCH 1/8] ModelParallelStrategy for Lightning Trainer --- .../fabric/strategies/model_parallel.py | 48 +-- src/lightning/pytorch/core/module.py | 13 + src/lightning/pytorch/strategies/__init__.py | 2 + .../pytorch/strategies/model_parallel.py | 318 ++++++++++++++++ .../connectors/accelerator_connector.py | 2 + .../strategies/test_model_parallel.py | 7 +- .../test_model_parallel_integration.py | 72 ++-- .../strategies/test_model_parallel.py | 211 +++++++++++ .../test_model_parallel_integration.py | 340 ++++++++++++++++++ 9 files changed, 953 insertions(+), 60 deletions(-) create mode 100644 src/lightning/pytorch/strategies/model_parallel.py create mode 100644 tests/tests_pytorch/strategies/test_model_parallel.py create mode 100644 tests/tests_pytorch/strategies/test_model_parallel_integration.py diff --git a/src/lightning/fabric/strategies/model_parallel.py b/src/lightning/fabric/strategies/model_parallel.py index 4141ea454ca51..6afef809819f8 100644 --- a/src/lightning/fabric/strategies/model_parallel.py +++ b/src/lightning/fabric/strategies/model_parallel.py @@ -163,7 +163,13 @@ def _configure_launcher(self) -> None: def setup_environment(self) -> None: super().setup_environment() self._setup_distributed() - self._setup_device_mesh() + if self._data_parallel_size == "auto": + self._data_parallel_size = self.num_nodes + if self._tensor_parallel_size == "auto": + self._tensor_parallel_size = self.num_processes + self._device_mesh = _setup_device_mesh( + self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device + ) @override def setup_module(self, module: TModel) -> TModel: @@ -303,25 +309,6 @@ def _setup_distributed(self) -> None: assert self.cluster_environment is not None _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) - def _setup_device_mesh(self) -> None: - from torch.distributed.device_mesh import init_device_mesh - - if self._data_parallel_size == "auto": - self._data_parallel_size = self.num_nodes - if self._tensor_parallel_size == "auto": - self._tensor_parallel_size = self.num_processes - if self._data_parallel_size * self._tensor_parallel_size != self.world_size: - raise RuntimeError( - f"The sizes `data_parallel_size={self._data_parallel_size}` and" - f" `tensor_parallel_size={self._tensor_parallel_size}` multiplied should equal the world size" - f" ({self.world_size})." - ) - self._device_mesh = init_device_mesh( - device_type=self.root_device.type, - mesh_shape=(self._data_parallel_size, self._tensor_parallel_size), - mesh_dim_names=("data_parallel", "tensor_parallel"), - ) - def _get_process_group_backend(self) -> str: return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) @@ -502,6 +489,27 @@ def _load_checkpoint( ) +def _setup_device_mesh( + data_parallel_size: int, + tensor_parallel_size: int, + world_size: int, + device: torch.device, +) -> "DeviceMesh": + from torch.distributed.device_mesh import init_device_mesh + + if data_parallel_size * tensor_parallel_size != world_size: + raise RuntimeError( + f"The sizes `data_parallel_size={data_parallel_size}` and" + f" `tensor_parallel_size={tensor_parallel_size}` multiplied should equal the world size" + f" ({world_size})." 
+ ) + return init_device_mesh( + device_type=device.type, + mesh_shape=(data_parallel_size, tensor_parallel_size), + mesh_dim_names=("data_parallel", "tensor_parallel"), + ) + + def _has_dtensor_modules(module: object) -> TypeGuard[Module]: from torch.distributed._tensor import DTensor diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index 3cb55566fb8b7..de9968f340346 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -20,6 +20,7 @@ from pathlib import Path from typing import ( IO, + TYPE_CHECKING, Any, Callable, Dict, @@ -76,6 +77,9 @@ OptimizerLRScheduler, ) +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + _ONNX_AVAILABLE = RequirementCache("onnx") warning_cache = WarningCache() @@ -142,6 +146,9 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._fabric: Optional["lf.Fabric"] = None self._fabric_optimizers: List[_FabricOptimizer] = [] + # access to device mesh in `conigure_model()` hook + self._device_mesh: Optional["DeviceMesh"] = None + @overload def optimizers( self, use_pl_optimizer: Literal[True] = True @@ -319,6 +326,12 @@ def loggers(self) -> Union[List[Logger], List[FabricLogger]]: return self._trainer.loggers return [] + @property + def device_mesh(self) -> Optional["DeviceMesh"]: + """Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the + :meth:`configure_model` hook to parallelize the LightningModule.""" + return self._device_mesh + def _call_batch_hook(self, hook_name: str, *args: Any) -> Any: trainer = self._trainer if trainer: diff --git a/src/lightning/pytorch/strategies/__init__.py b/src/lightning/pytorch/strategies/__init__.py index 14ffe52870ba5..9c2b2a6a3a621 100644 --- a/src/lightning/pytorch/strategies/__init__.py +++ b/src/lightning/pytorch/strategies/__init__.py @@ -18,6 +18,7 @@ from lightning.pytorch.strategies.ddp import DDPStrategy from lightning.pytorch.strategies.deepspeed import DeepSpeedStrategy from lightning.pytorch.strategies.fsdp import FSDPStrategy +from lightning.pytorch.strategies.model_parallel import ModelParallelStrategy from lightning.pytorch.strategies.parallel import ParallelStrategy from lightning.pytorch.strategies.single_device import SingleDeviceStrategy from lightning.pytorch.strategies.single_xla import SingleDeviceXLAStrategy # noqa: F401 @@ -31,6 +32,7 @@ "DDPStrategy", "DeepSpeedStrategy", "FSDPStrategy", + "ModelParallelStrategy", "ParallelStrategy", "SingleDeviceStrategy", "Strategy", diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py new file mode 100644 index 0000000000000..e50513f944584 --- /dev/null +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -0,0 +1,318 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
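A minimal sketch of how the new `LightningModule.device_mesh` property introduced above is meant to be consumed: the strategy assigns the mesh before `configure_model()` runs, and the hook applies the parallelization. The module and tensor-parallel plan below are illustrative only, mirroring the feed-forward example used in the integration tests further down.

```python
# Sketch: consuming `self.device_mesh` inside `configure_model()` (illustrative module/plan).
import torch
import torch.nn as nn
import torch.nn.functional as F
from lightning.pytorch import LightningModule


class ParallelFeedForward(LightningModule):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(32, 64)
        self.w2 = nn.Linear(32, 64)
        self.w3 = nn.Linear(64, 32)

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

    def configure_model(self):
        # `self.device_mesh` is populated by ModelParallelStrategy before this hook runs
        from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module

        tp_mesh = self.device_mesh["tensor_parallel"]
        plan = {"w1": ColwiseParallel(), "w2": ColwiseParallel(), "w3": RowwiseParallel()}
        parallelize_module(self, tp_mesh, plan)

    def training_step(self, batch):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters())
```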
+from contextlib import contextmanager, nullcontext +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Mapping, Optional, Union + +import torch +from lightning_utilities.core.rank_zero import rank_zero_only as utils_rank_zero_only +from torch import Tensor +from torch.optim import Optimizer +from typing_extensions import override + +import lightning.pytorch as pl +from lightning.fabric.plugins import CheckpointIO +from lightning.fabric.plugins.collectives.torch_collective import default_pg_timeout +from lightning.fabric.strategies.model_parallel import _setup_device_mesh +from lightning.fabric.utilities.distributed import ( + _distributed_is_initialized, + _get_default_process_group_backend_for_device, + _init_dist_connection, + _sync_ddp_if_available, +) +from lightning.fabric.utilities.distributed import group as _group +from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_3 +from lightning.fabric.utilities.init import _materialize_distributed_module +from lightning.fabric.utilities.optimizer import _optimizers_to_device +from lightning.fabric.utilities.seed import reset_seed +from lightning.fabric.utilities.types import _PATH, ReduceOp +from lightning.pytorch.core.optimizer import LightningOptimizer +from lightning.pytorch.strategies.launchers.subprocess_script import _SubprocessScriptLauncher +from lightning.pytorch.strategies.parallel import ParallelStrategy +from lightning.pytorch.strategies.strategy import TBroadcast +from lightning.pytorch.trainer.states import TrainerFn +from lightning.pytorch.utilities.model_helpers import is_overridden +from lightning.pytorch.utilities.rank_zero import rank_zero_only + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + + +class ModelParallelStrategy(ParallelStrategy): + """Enables user-defined parallelism applied to a model. + + .. warning:: This is an :ref:`experimental ` feature. + + Currently supports up to 2D parallelism. Specifically, it supports the combination of + Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still + experimental in PyTorch. Requires PyTorch 2.3 or newer. + + Arguments: + data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which + sets this size to the number of nodes in the cluster. + tensor_parallel_size: The number of devices within a tensor-parallel group. Defaults to ``"auto"``, which + sets this size to the number of GPUs in a single node. + save_distributed_checkpoint: If ``True``, each rank saves its shard of weights and optimizer states to a file. + The checkpoint is a folder with as many files as the world size. + If ``False``, the full weights and optimizer states get assembled on rank 0 and saved to a single file. 
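As a usage sketch for the docstring above: with the `"auto"` defaults, `data_parallel_size` falls back to the number of nodes and `tensor_parallel_size` to the number of devices per node; explicit sizes must multiply to the world size. The snippet assumes a single node with 4 GPUs, and `MyParallelModule` is a placeholder for a LightningModule that overrides `configure_model()`.

```python
# Sketch: wiring the strategy into the Trainer. On one node with 4 GPUs,
# a 2 x 2 mesh satisfies data_parallel_size * tensor_parallel_size == world_size.
import lightning.pytorch as pl
from lightning.pytorch.strategies import ModelParallelStrategy

strategy = ModelParallelStrategy(data_parallel_size=2, tensor_parallel_size=2)
trainer = pl.Trainer(accelerator="cuda", devices=4, strategy=strategy, max_steps=10)
# `MyParallelModule` must override `configure_model()` and apply the parallelization
# there; otherwise `setup()` raises a TypeError.
# trainer.fit(MyParallelModule())
```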
+ + """ + + def __init__( + self, + data_parallel_size: Union[Literal["auto"], int] = "auto", + tensor_parallel_size: Union[Literal["auto"], int] = "auto", + save_distributed_checkpoint: bool = True, + process_group_backend: Optional[str] = None, + timeout: Optional[timedelta] = default_pg_timeout, + ) -> None: + super().__init__() + if not _TORCH_GREATER_EQUAL_2_3: + raise ImportError(f"{type(self).__name__} requires PyTorch 2.3 or higher.") + self._data_parallel_size = data_parallel_size + self._tensor_parallel_size = tensor_parallel_size + self._save_distributed_checkpoint = save_distributed_checkpoint + self._process_group_backend: Optional[str] = process_group_backend + self._timeout: Optional[timedelta] = timeout + self._device_mesh: Optional["DeviceMesh"] = None + self.num_nodes = 1 + + @property + def device_mesh(self) -> "DeviceMesh": + if self._device_mesh is None: + raise RuntimeError("Accessing the device mesh before processes have initialized is not allowed.") + return self._device_mesh + + @property + @override + def checkpoint_io(self) -> CheckpointIO: + raise NotImplementedError(f"The `{type(self).__name__}` does not use the `CheckpointIO` plugin interface.") + + @checkpoint_io.setter + @override + def checkpoint_io(self, io: CheckpointIO) -> None: + raise NotImplementedError(f"The `{type(self).__name__}` does not support setting a `CheckpointIO` plugin.") + + @property + @override + def root_device(self) -> torch.device: + assert self.parallel_devices is not None + return self.parallel_devices[self.local_rank] + + @property + def num_processes(self) -> int: + return len(self.parallel_devices) if self.parallel_devices is not None else 0 + + @property + @override + def distributed_sampler_kwargs(self) -> Dict[str, Any]: + assert self.device_mesh is not None + data_parallel_mesh = self.device_mesh["data_parallel"] + return {"num_replicas": data_parallel_mesh.size(), "rank": data_parallel_mesh.get_local_rank()} + + @property + def process_group_backend(self) -> Optional[str]: + return self._process_group_backend + + @property + @override + def restore_checkpoint_after_setup(self) -> bool: + return True + + @property + @override + def lightning_restore_optimizer(self) -> bool: + return False + + @override + def _configure_launcher(self) -> None: + assert self.cluster_environment is not None + if not self.cluster_environment.creates_processes_externally: + self._launcher = _SubprocessScriptLauncher(self.cluster_environment, self.num_processes, self.num_nodes) + + @override + def setup_environment(self) -> None: + super().setup_environment() + self._setup_distributed() + if self._data_parallel_size == "auto": + self._data_parallel_size = self.num_nodes + if self._tensor_parallel_size == "auto": + self._tensor_parallel_size = self.num_processes + self._device_mesh = _setup_device_mesh( + self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device + ) + # Users can access device mesh in `LightningModule.configure_model()` + self.lightning_module._device_mesh = self._device_mesh + + @override + def setup(self, trainer: "pl.Trainer") -> None: + from torch.distributed.fsdp import FullyShardedDataParallel + + assert self.accelerator is not None + self.accelerator.setup(trainer) + + if not is_overridden("configure_model", self.lightning_module): + raise TypeError( + f"When using the {type(self).__name__}, you are required to override the `configure_model()` hook in" + f" the LightningModule and apply parallelization there." 
+ ) + if any(isinstance(mod, FullyShardedDataParallel) for mod in self.model.modules()): + raise TypeError( + "Found modules that are wrapped with `torch.distributed.fsdp.FullyShardedDataParallel`." + f" The `{self.__class__.__name__}` only supports the new FSDP2 APIs in PyTorch >= 2.3." + ) + + _materialize_distributed_module(self.model, self.root_device) + + self.model = self.precision_plugin.convert_module(self.model) + self.model_to_device() # move all remaining layers if any left on CPU. + + self.barrier() + + if trainer.state.fn == TrainerFn.FITTING: + self.setup_optimizers(trainer) + self.setup_precision_plugin() + if trainer.state.fn == TrainerFn.FITTING: + _optimizers_to_device(self.optimizers, self.root_device) + + @override + def setup_optimizers(self, trainer: "pl.Trainer") -> None: + # If we're setting up for evaluation after fitting, we need to discard the optimizers + # since we're rewrapping the model, otherwise optimizer param references are no longer valid + # and subsequent checkpoint saving can fail + self._reset_optimizers_and_schedulers() + + return super().setup_optimizers(trainer) + + @override + def model_to_device(self) -> None: + assert self.model is not None + self.model.to(self.root_device) + + @contextmanager + @override + def tensor_init_context(self, empty_init: Optional[bool] = None) -> Generator[None, None, None]: + # Materializaton happens in `setup()` + empty_init_context = torch.device("meta") if empty_init else nullcontext() + with empty_init_context, self.precision_plugin.tensor_init_context(): + yield + + @override + def barrier(self, name: Optional[str] = None) -> None: + if not _distributed_is_initialized(): + return + if torch.distributed.get_backend() == "nccl": + torch.distributed.barrier(device_ids=self._determine_device_ids()) + else: + torch.distributed.barrier() + + @override + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: + if not _distributed_is_initialized(): + return obj + + obj = [obj] + torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) + return obj[0] + + @override + def reduce( + self, + tensor: Union[Tensor, Any], + group: Optional[Any] = None, + reduce_op: Optional[Union[ReduceOp, str]] = "mean", + ) -> Tensor: + if isinstance(tensor, Tensor): + return _sync_ddp_if_available(tensor, group, reduce_op=reduce_op) + return tensor + + def _determine_device_ids(self) -> List[int]: + return [self.root_device.index] + + @override + def teardown(self) -> None: + assert self.cluster_environment is not None + assert self.accelerator is not None + self.cluster_environment.teardown() + self.precision_plugin.teardown() + self.accelerator.teardown() + + @override + def lightning_module_state_dict(self) -> Dict[str, Any]: + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict + + state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True) + assert self.model is not None + return get_model_state_dict(self.model, options=state_dict_options) + + @override + def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = True) -> None: + # Override to do nothing, the strategy already loaded the states in `load_checkpoint()` + pass + + @override + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + from torch.distributed.fsdp 
import OptimStateKeyType + + state_dict_options = StateDictOptions(full_state_dict=(not self._save_distributed_checkpoint), cpu_offload=True) + if isinstance(optimizer, LightningOptimizer): + optimizer = optimizer._optimizer + + assert self.model is not None + state_dict = get_optimizer_state_dict(self.model, optimizer, options=state_dict_options) + if not self._save_distributed_checkpoint: + # Store the optimizer state dict in standard format + state_dict = FSDP.rekey_optim_state_dict(state_dict, OptimStateKeyType.PARAM_ID, self.model) + return state_dict + + @override + def load_optimizer_state_dict(self, checkpoint: Mapping[str, Any]) -> None: + # Override to do nothing, the strategy already loaded the states in `load_checkpoint()` + pass + + @override + def save_checkpoint( + self, checkpoint: Dict[str, Any], filepath: _PATH, storage_options: Optional[Any] = None + ) -> None: + if storage_options is not None: + raise TypeError( + f"`{type(self).__name__}.save_checkpoint(..., storage_options=...)` is not supported because" + f" `{type(self).__name__}` does not use the `CheckpointIO`." + ) + raise NotImplementedError("Checkpoint saving is not yet implemented.") + + @override + def load_checkpoint(self, checkpoint_path: _PATH) -> Dict[str, Any]: + raise NotImplementedError("Checkpoint loading is not yet implemented.") + + def _setup_distributed(self) -> None: + super().setup_environment() + reset_seed() + self.set_world_ranks() + self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None + _init_dist_connection(self.cluster_environment, self._process_group_backend, timeout=self._timeout) + + def _get_process_group_backend(self) -> str: + return self._process_group_backend or _get_default_process_group_backend_for_device(self.root_device) + + def set_world_ranks(self) -> None: + if self.cluster_environment is not None: + self.cluster_environment.set_global_rank(self.node_rank * self.num_processes + self.local_rank) + self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) + # `LightningEnvironment.set_global_rank` will do this too, but we cannot rely on that implementation detail + # additionally, for some implementations, the setter is a no-op, so it's safer to access the getter + rank_zero_only.rank = utils_rank_zero_only.rank = self.global_rank diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index a191859c06c43..6a350030ea0f7 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -53,6 +53,7 @@ DDPStrategy, DeepSpeedStrategy, FSDPStrategy, + ModelParallelStrategy, ParallelStrategy, SingleDeviceStrategy, SingleDeviceXLAStrategy, @@ -600,6 +601,7 @@ def is_distributed(self) -> bool: DDPStrategy, FSDPStrategy, DeepSpeedStrategy, + ModelParallelStrategy, XLAStrategy, ] if _habana_available_and_importable(): diff --git a/tests/tests_fabric/strategies/test_model_parallel.py b/tests/tests_fabric/strategies/test_model_parallel.py index 54efb999a8cf1..03b9268b3158e 100644 --- a/tests/tests_fabric/strategies/test_model_parallel.py +++ b/tests/tests_fabric/strategies/test_model_parallel.py @@ -118,8 +118,7 @@ def test_parallelize_fn_call(): @RunIf(min_torch="2.3") def test_no_backward_sync(): - """Test that the backward sync control calls `.no_sync()`, and only on a module wrapped in - FullyShardedDataParallel.""" + """Test that the 
backward sync control disables gradient sync on modules that benefit from it.""" from torch.distributed._composable.fsdp import FSDP strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m)) @@ -141,7 +140,7 @@ def test_no_backward_sync(): @RunIf(min_torch="2.3") def test_save_checkpoint_storage_options(tmp_path): - """Test that the FSDP strategy does not accept storage options for saving checkpoints.""" + """Test that the strategy does not accept storage options for saving checkpoints.""" strategy = ModelParallelStrategy(parallelize_fn=(lambda m, _: m)) with pytest.raises( TypeError, match=escape("ModelParallelStrategy.save_checkpoint(..., storage_options=...)` is not") @@ -326,7 +325,7 @@ def test_load_raw_checkpoint_optimizer_unsupported(tmp_path): @RunIf(min_torch="2.3") -@mock.patch("lightning.fabric.strategies.ModelParallelStrategy._setup_device_mesh") +@mock.patch("lightning.fabric.strategies.model_parallel._setup_device_mesh") @mock.patch("torch.distributed.init_process_group") def test_set_timeout(init_process_group_mock, _): """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" diff --git a/tests/tests_fabric/strategies/test_model_parallel_integration.py b/tests/tests_fabric/strategies/test_model_parallel_integration.py index 1f12822c69ee6..1562d8d785263 100644 --- a/tests/tests_fabric/strategies/test_model_parallel_integration.py +++ b/tests/tests_fabric/strategies/test_model_parallel_integration.py @@ -29,42 +29,6 @@ from tests_fabric.helpers.runif import RunIf -@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) -def test_setup_device_mesh(): - from torch.distributed.device_mesh import DeviceMesh - - for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): - strategy = ModelParallelStrategy( - parallelize_fn=(lambda m, _: m), - data_parallel_size=dp_size, - tensor_parallel_size=tp_size, - ) - fabric = Fabric(accelerator="auto", devices=4, strategy=strategy) - fabric.launch() - - device_mesh = fabric.strategy.device_mesh - assert isinstance(device_mesh, DeviceMesh) - assert device_mesh.device_type == fabric.device.type - assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel") - assert device_mesh.size(0) == dp_size - assert device_mesh.size(1) == tp_size - assert device_mesh.ndim == 2 - - fabric.barrier() - - # Passing "auto" will select internode and intranode dimensions automatically - strategy = ModelParallelStrategy( - parallelize_fn=(lambda m, _: m), - data_parallel_size="auto", - tensor_parallel_size="auto", - ) - fabric = Fabric(accelerator="auto", devices=4, num_nodes=1, strategy=strategy) - fabric.launch() - assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel") - assert fabric.strategy.device_mesh.size(0) == 1 - assert fabric.strategy.device_mesh.size(1) == 4 - - class FeedForward(nn.Module): def __init__(self): super().__init__() @@ -113,6 +77,42 @@ def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): return model +@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) +def test_setup_device_mesh(): + from torch.distributed.device_mesh import DeviceMesh + + for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): + strategy = ModelParallelStrategy( + parallelize_fn=(lambda m, _: m), + data_parallel_size=dp_size, + tensor_parallel_size=tp_size, + ) + fabric = Fabric(accelerator="auto", devices=4, strategy=strategy) + fabric.launch() + + device_mesh = fabric.strategy.device_mesh + assert isinstance(device_mesh, DeviceMesh) + assert device_mesh.device_type == 
fabric.device.type + assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel") + assert device_mesh.size(0) == dp_size + assert device_mesh.size(1) == tp_size + assert device_mesh.ndim == 2 + + fabric.barrier() + + # Passing "auto" will select internode and intranode dimensions automatically + strategy = ModelParallelStrategy( + parallelize_fn=(lambda m, _: m), + data_parallel_size="auto", + tensor_parallel_size="auto", + ) + fabric = Fabric(accelerator="auto", devices=4, num_nodes=1, strategy=strategy) + fabric.launch() + assert fabric.strategy.device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel") + assert fabric.strategy.device_mesh.size(0) == 1 + assert fabric.strategy.device_mesh.size(1) == 4 + + @RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2) def test_tensor_parallel(): from torch.distributed._tensor import DTensor diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py new file mode 100644 index 0000000000000..15e492882e254 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -0,0 +1,211 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import timedelta +from re import escape +from unittest import mock +from unittest.mock import Mock + +import pytest +import torch +import torch.nn as nn +from lightning import LightningModule +from lightning.pytorch.plugins.environments import LightningEnvironment +from lightning.pytorch.strategies import ModelParallelStrategy + +from tests_pytorch.helpers.runif import RunIf + + +@mock.patch("lightning.pytorch.strategies.model_parallel._TORCH_GREATER_EQUAL_2_3", False) +def test_torch_greater_equal_2_3(): + with pytest.raises(ImportError, match="ModelParallelStrategy requires PyTorch 2.3 or higher"): + ModelParallelStrategy() + + +@RunIf(min_torch="2.3") +def test_device_mesh_access(): + strategy = ModelParallelStrategy() + with pytest.raises(RuntimeError, match="Accessing the device mesh .* not allowed"): + _ = strategy.device_mesh + + +@RunIf(min_torch="2.3") +@pytest.mark.parametrize( + ("num_nodes", "devices", "invalid_dp_size", "invalid_tp_size"), + [ + (1, 4, 1, 1), + (1, 4, 2, 3), + (1, 4, 4, 2), + (2, 4, 1, 4), + (2, 4, 2, 1), + ], +) +def test_validate_device_mesh_dimensions(num_nodes, devices, invalid_dp_size, invalid_tp_size): + """Test passing sizes that don't multiply to the world size raises an error.""" + strategy = ModelParallelStrategy( + data_parallel_size=invalid_dp_size, + tensor_parallel_size=invalid_tp_size, + ) + strategy._setup_distributed = Mock() + strategy._accelerator = Mock() + strategy.cluster_environment = Mock( + world_size=Mock(return_value=(num_nodes * devices)), local_rank=Mock(return_value=1) + ) + strategy.parallel_devices = [torch.device("cpu")] * devices + strategy.num_nodes = num_nodes + with pytest.raises(RuntimeError, match="multiplied should equal the world size"): + strategy.setup_environment() + + +@RunIf(min_torch="2.3") +def 
test_checkpoint_io_unsupported(): + """Test that the ModelParallel strategy does not support the `CheckpointIO` plugin.""" + strategy = ModelParallelStrategy() + with pytest.raises(NotImplementedError, match="does not use the `CheckpointIO` plugin"): + _ = strategy.checkpoint_io + + with pytest.raises(NotImplementedError, match="does not support setting a `CheckpointIO` plugin"): + strategy.checkpoint_io = Mock() + + +@RunIf(min_torch="2.3") +def test_fsdp_v1_modules_unsupported(): + """Test that the strategy won't allow setting up a module wrapped with the legacy FSDP API.""" + from torch.distributed.fsdp import FullyShardedDataParallel + + class Model(LightningModule): + def configure_model(self): + pass + + model = Model() + model.modules = Mock(return_value=[Mock(spec=FullyShardedDataParallel)]) + strategy = ModelParallelStrategy() + strategy.model = model + strategy._lightning_module = model + strategy._accelerator = Mock() + + with pytest.raises(TypeError, match="only supports the new FSDP2 APIs in PyTorch >= 2.3"): + strategy.setup(Mock()) + + +@RunIf(min_torch="2.3") +def test_configure_model_required(): + class Model1(LightningModule): + pass + + class Model2(LightningModule): + def configure_model(self): + pass + + model = Model1() + strategy = ModelParallelStrategy() + strategy.model = model + strategy._lightning_module = model + strategy._accelerator = Mock() + strategy._parallel_devices = [torch.device("cpu")] + + with pytest.raises(TypeError, match="you are required to override the `configure_model"): + strategy.setup(Mock()) + + model = Model2() + strategy.model = model + strategy._lightning_module = model + strategy.setup(Mock()) + + +@RunIf(min_torch="2.3") +def test_save_checkpoint_storage_options(tmp_path): + """Test that the strategy does not accept storage options for saving checkpoints.""" + strategy = ModelParallelStrategy() + with pytest.raises( + TypeError, match=escape("ModelParallelStrategy.save_checkpoint(..., storage_options=...)` is not") + ): + strategy.save_checkpoint(checkpoint=Mock(), filepath=tmp_path, storage_options=Mock()) + + +@RunIf(min_torch="2.3") +def test_save_checkpoint_path_exists(): + pytest.skip("Checkpoint saving and loading not implemented") + + +@RunIf(min_torch="2.3") +def test_load_full_checkpoint_support(): + pytest.skip("Checkpoint saving and loading not implemented") + + +@RunIf(min_torch="2.3") +def test_load_unknown_checkpoint_type(): + pytest.skip("Checkpoint saving and loading not implemented") + + +@RunIf(min_torch="2.3") +@mock.patch("lightning.pytorch.strategies.model_parallel._setup_device_mesh") +@mock.patch("torch.distributed.init_process_group") +def test_set_timeout(init_process_group_mock, _): + """Test that the timeout gets passed to the ``torch.distributed.init_process_group`` function.""" + test_timedelta = timedelta(seconds=30) + strategy = ModelParallelStrategy(timeout=test_timedelta) + strategy._lightning_module = Mock() + strategy.parallel_devices = [torch.device("cpu")] + strategy.cluster_environment = LightningEnvironment() + strategy.accelerator = Mock() + strategy.setup_environment() + process_group_backend = strategy._get_process_group_backend() + global_rank = strategy.cluster_environment.global_rank() + world_size = strategy.cluster_environment.world_size() + init_process_group_mock.assert_called_with( + process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta + ) + + +@RunIf(min_torch="2.3") +def test_meta_device_materialization(): + """Test that the `setup()` method 
materializes meta-device tensors in the LightningModule.""" + + class NoResetParameters(nn.Module): + def __init__(self): + super().__init__() + self.weight = nn.Parameter(torch.ones(4, 4)) + + class CustomModel(LightningModule): + def __init__(self): + super().__init__() + # nn.Sequential as a parameterless module + self.layer1 = nn.Sequential(NoResetParameters(), NoResetParameters()) + self.layer2 = nn.Linear(4, 4) + self.register_buffer("buffer", torch.rand(2)) + + def reset_parameters(self): + self.buffer.fill_(1.0) + + def configure_model(self) -> None: + pass + + with torch.device("meta"): + model = CustomModel() + assert model.layer1[0].weight.is_meta + assert model.layer2.weight.is_meta + assert model.buffer.is_meta + + strategy = ModelParallelStrategy() + strategy._accelerator = Mock() + strategy._device_mesh = Mock() + strategy._parallel_devices = [torch.device("cpu")] + strategy._lightning_module = model + strategy.model = model + + with pytest.warns(UserWarning, match=r"`reset_parameters\(\)` method for re-initialization: NoResetParameters"): + strategy.setup(Mock()) + assert all(not p.is_meta for p in model.parameters()) + assert all(not b.is_meta for b in model.buffers()) diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py new file mode 100644 index 0000000000000..86a62a4e77dc6 --- /dev/null +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -0,0 +1,340 @@ +# Copyright The Lightning AI team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
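For context on `test_meta_device_materialization` above, here is a small plain-PyTorch sketch of the meta-device workflow the strategy relies on. The `to_empty()` plus `reset_parameters()` pair stands in for what `_materialize_distributed_module` does internally and is an approximation, not the exact implementation.

```python
# Sketch: parameters created under the meta device hold no storage until
# they are re-created on a real device and re-initialized.
import torch
import torch.nn as nn

with torch.device("meta"):
    layer = nn.Linear(4, 4)
assert layer.weight.is_meta  # no memory allocated yet

# Roughly what materialization in `setup()` amounts to: allocate real storage,
# then re-run the module's `reset_parameters()` (hence the warning for modules
# that do not define one).
layer = layer.to_empty(device="cpu")
layer.reset_parameters()
assert not layer.weight.is_meta
```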
+import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F +from lightning.pytorch import LightningModule, Trainer, seed_everything +from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset +from lightning.pytorch.strategies import ModelParallelStrategy +from torch.utils.data import DataLoader, DistributedSampler +from torchmetrics.classification import Accuracy + +from tests_pytorch.helpers.runif import RunIf + + +class FeedForward(nn.Module): + def __init__(self): + super().__init__() + self.w1 = nn.Linear(32, 64) + self.w2 = nn.Linear(32, 64) + self.w3 = nn.Linear(64, 32) + + def forward(self, x): + return self.w3(F.silu(self.w1(x)) * self.w2(x)) + + +def _parallelize_feed_forward_tp(model, device_mesh): + from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module + + tp_mesh = device_mesh["tensor_parallel"] + tp_plan = { + "w1": ColwiseParallel(), + "w2": ColwiseParallel(), + "w3": RowwiseParallel(), + } + parallelize_module(model, tp_mesh, tp_plan) + return model + + +def _parallelize_feed_forward_fsdp2(model, device_mesh): + from torch.distributed._composable.fsdp.fully_shard import fully_shard + + dp_mesh = device_mesh["data_parallel"] + assert dp_mesh.ndim == 1 # Hybrid-sharding not supported + + # Fully-shard each layer + fully_shard(model.w1, mesh=dp_mesh) + fully_shard(model.w2, mesh=dp_mesh) + fully_shard(model.w3, mesh=dp_mesh) + + # TODO: Re-enable activation checkpointing + # Currently, state dict keys get prefixed with '_checkpoint_wrapper' in the keys + # which leads to mismatches when loading weights into a checkpoint-wrapped module. + # PyTorch should handle this automatically. + + # model = checkpoint_wrapper(model) + + return model + + +def _parallelize_feed_forward_fsdp2_tp(model, device_mesh): + model = _parallelize_feed_forward_tp(model, device_mesh) + model = _parallelize_feed_forward_fsdp2(model, device_mesh) + return model + + +class TemplateModel(LightningModule): + def __init__(self): + super().__init__() + self.model = FeedForward() + + def training_step(self, batch): + output = self.model(batch) + return output.sum() + + def train_dataloader(self): + dataset_size = 8 + dataset = RandomDataset(32, dataset_size) + return DataLoader(dataset, batch_size=2) + + def configure_optimizers(self): + return torch.optim.AdamW(self.model.parameters()) + + +class FSDP2Model(TemplateModel): + def configure_model(self): + _parallelize_feed_forward_fsdp2(self.model, device_mesh=self.device_mesh) + + +class TensorParallelModel(TemplateModel): + def configure_model(self): + _parallelize_feed_forward_tp(self.model, device_mesh=self.device_mesh) + + +class FSDP2TensorParallelModel(TemplateModel): + def configure_model(self): + _parallelize_feed_forward_fsdp2_tp(self.model, device_mesh=self.device_mesh) + + +@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) +def test_setup_device_mesh(): + from torch.distributed.device_mesh import DeviceMesh + + for dp_size, tp_size in ((1, 4), (4, 1), (2, 2)): + strategy = ModelParallelStrategy( + data_parallel_size=dp_size, + tensor_parallel_size=tp_size, + ) + trainer = Trainer( + accelerator="auto", + devices=4, + strategy=strategy, + logger=False, + enable_checkpointing=False, + max_steps=1, + ) + + class Model(BoringModel): + def configure_model(self): + device_mesh = self.device_mesh + assert isinstance(device_mesh, DeviceMesh) + assert device_mesh.device_type == model.device.type + assert device_mesh.mesh_dim_names == ("data_parallel", 
"tensor_parallel") + assert device_mesh.size(0) == dp_size + assert device_mesh.size(1) == tp_size + assert device_mesh.ndim == 2 + + model = Model() + trainer.fit(model) + + # Passing "auto" will select internode and intranode dimensions automatically + strategy = ModelParallelStrategy( + data_parallel_size="auto", + tensor_parallel_size="auto", + ) + trainer = Trainer( + accelerator="auto", + devices=4, + num_nodes=1, + strategy=strategy, + logger=False, + enable_checkpointing=False, + max_steps=1, + ) + + class Model(BoringModel): + def configure_model(self): + device_mesh = self.device_mesh + assert device_mesh.mesh_dim_names == ("data_parallel", "tensor_parallel") + assert device_mesh.size(0) == 1 + assert device_mesh.size(1) == 4 + + model = Model() + trainer.fit(model) + + +@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=2) +def test_tensor_parallel(): + from torch.distributed._tensor import DTensor + + class Model(TensorParallelModel): + def on_train_start(self): + device_mesh = self.device_mesh + optimizer = self.optimizers() + assert all( + tensor.device_mesh == device_mesh["tensor_parallel"] for tensor in optimizer.param_groups[0]["params"] + ) + assert all(isinstance(weight, DTensor) for weight in self.model.parameters()) + assert self.model.w1.weight.device_mesh == device_mesh["tensor_parallel"] + + # No data sharding, all GPUs get the same input inside a TP group + dataloader = self.trainer.train_dataloader + assert len(dataloader) == 8 // dataloader.batch_size + assert isinstance(dataloader.sampler, DistributedSampler) + + def training_step(self, batch): + # All batches must be identical across TP group + batches = self.all_gather(batch) + assert all(torch.equal(batches[0], batches[i]) for i in range(1, len(batches))) + return super().training_step(batch) + + trainer = Trainer( + accelerator="auto", + devices=2, + strategy=ModelParallelStrategy(), + max_steps=2, + enable_checkpointing=False, + logger=False, + ) + + seed_everything(0) + with trainer.init_module(empty_init=True): + model = Model() + + trainer.fit(model) + + +@RunIf(min_torch="2.3", standalone=True, min_cuda_gpus=4) +def test_fsdp2_tensor_parallel(): + from torch.distributed._tensor import DTensor + + class Model(FSDP2TensorParallelModel): + def on_train_start(self): + optimizer = self.optimizers() + assert all(isinstance(weight, DTensor) for weight in self.model.parameters()) + assert all(isinstance(tensor, DTensor) for tensor in optimizer.param_groups[0]["params"]) + assert self.model.w1.weight.device_mesh.ndim == 2 + assert self.model.w1.weight.device_mesh.size(0) == 2 + assert self.model.w1.weight.device_mesh.size(1) == 2 + assert all(weight.device.type != "meta" for weight in self.model.parameters()) + assert all(tensor.device_mesh.ndim == 2 for tensor in optimizer.param_groups[0]["params"]) + assert all(tensor.device.type != "meta" for tensor in optimizer.param_groups[0]["params"]) + + # No data sharding across TP dimension, sharding across data-parallel dimension only + device_mesh = self.device_mesh + dp_mesh = device_mesh["data_parallel"] + dataloader = self.trainer.train_dataloader + assert len(dataloader) == 8 // dataloader.batch_size // dp_mesh.size() + assert isinstance(dataloader.sampler, DistributedSampler) + + def training_step(self, batch): + batches = self.all_gather(batch) + dp_mesh = self.device_mesh["data_parallel"] + tp_mesh = self.device_mesh["tensor_parallel"] + + # Batches across the TP dimension must be identical + batches_tp = batches[tp_mesh.mesh] + assert 
all(torch.equal(batches_tp[0], batches_tp[i]) for i in range(1, len(batches_tp))) + # Batches across the DP dimension must be different + batches_dp = batches[dp_mesh.mesh] + assert all(not torch.equal(batches_dp[0], batches_dp[i]) for i in range(1, len(batches_dp))) + + return super().training_step(batch) + + strategy = ModelParallelStrategy( + data_parallel_size=2, + tensor_parallel_size=2, + ) + trainer = Trainer( + accelerator="auto", + devices=4, + strategy=strategy, + max_steps=2, + enable_checkpointing=False, + logger=False, + ) + + seed_everything(0) + with trainer.init_module(empty_init=True): + model = Model() + + trainer.fit(model) + + +@RunIf(min_torch="2.3", min_cuda_gpus=2) +def test_modules_without_parameters(tmp_path): + """Test that TorchMetrics get moved to the device despite not having any parameters.""" + + class MetricsModel(TensorParallelModel): + def __init__(self): + super().__init__() + self.metric = Accuracy("multiclass", num_classes=10) + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def setup(self, stage) -> None: + assert self.metric.device == self.metric.tp.device == torch.device("cpu") + + def training_step(self, batch): + assert self.metric.device.type == self.metric.tp.device.type == "cuda" + self.metric(torch.rand(2, 10, device=self.device), torch.randint(0, 10, size=(2,), device=self.device)) + return super().training_step(batch) + + model = MetricsModel() + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=ModelParallelStrategy(), + max_steps=1, + enable_checkpointing=False, + logger=False, + ) + trainer.fit(model) + + +@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True) +@pytest.mark.parametrize( + ("precision", "expected_dtype"), + [ + ("32-true", torch.float32), + ("16-true", torch.float16), + pytest.param("bf16-true", torch.bfloat16, marks=RunIf(bf16_cuda=True)), + ], +) +def test_module_init_context(precision, expected_dtype, tmp_path): + """Test that the module under the init-context gets moved to the right device and dtype.""" + + class Model(FSDP2Model): + def on_train_start(self): + assert self.model.w1.weight.device == torch.device("cuda", self.local_rank) + assert self.model.w1.weight.dtype == expected_dtype + optimizer = self.optimizers(use_pl_optimizer=False) + assert optimizer.param_groups[0]["params"][0].device.type == "cuda" + + def _run_setup_assertions(empty_init, expected_device): + trainer = Trainer( + default_root_dir=tmp_path, + accelerator="cuda", + devices=2, + strategy=ModelParallelStrategy(), + precision=precision, + max_steps=1, + barebones=True, + enable_checkpointing=False, + logger=False, + ) + with trainer.init_module(empty_init=empty_init): + model = Model() + + # The model is on the CPU/meta-device until after `ModelParallelStrategy.setup()` + assert model.model.w1.weight.device == expected_device + assert model.model.w1.weight.dtype == expected_dtype + trainer.fit(model) + + # Case 1: No empty init + _run_setup_assertions(empty_init=False, expected_device=torch.device("cpu")) + + # Case 2: Empty-init with PyTorch >= 2.1 supports meta device + _run_setup_assertions(empty_init=True, expected_device=torch.device("meta")) From 3a54bf75f779815306e6757f70f680bf29fca1c7 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 15:52:31 +0200 Subject: [PATCH 2/8] mypy --- src/lightning/pytorch/strategies/model_parallel.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/model_parallel.py 
b/src/lightning/pytorch/strategies/model_parallel.py index e50513f944584..58a9d2e682fdb 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -153,12 +153,14 @@ def setup_environment(self) -> None: self._data_parallel_size, self._tensor_parallel_size, self.world_size, self.root_device ) # Users can access device mesh in `LightningModule.configure_model()` + assert self.lightning_module is not None self.lightning_module._device_mesh = self._device_mesh @override def setup(self, trainer: "pl.Trainer") -> None: from torch.distributed.fsdp import FullyShardedDataParallel + assert self.model is not None assert self.accelerator is not None self.accelerator.setup(trainer) @@ -262,7 +264,7 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr pass @override - def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]: + def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]: from torch.distributed.checkpoint.state_dict import StateDictOptions, get_optimizer_state_dict from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.distributed.fsdp import OptimStateKeyType From 76e206da50a6362aba0197d4cabbb23be15ab8df Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 15:53:35 +0200 Subject: [PATCH 3/8] import fix --- tests/tests_pytorch/strategies/test_model_parallel.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_model_parallel.py b/tests/tests_pytorch/strategies/test_model_parallel.py index 15e492882e254..4b9b0887c85bf 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel.py +++ b/tests/tests_pytorch/strategies/test_model_parallel.py @@ -20,7 +20,7 @@ import pytest import torch import torch.nn as nn -from lightning import LightningModule +from lightning.pytorch import LightningModule from lightning.pytorch.plugins.environments import LightningEnvironment from lightning.pytorch.strategies import ModelParallelStrategy From a7937b614126aca3e987aa0cc1306e21587b661e Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 16:17:46 +0200 Subject: [PATCH 4/8] fix torchscript errors --- src/lightning/pytorch/core/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index de9968f340346..b1ae2f5af3f6c 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -114,6 +114,7 @@ class LightningModule( "trainer", "fabric", "strict_loading", + "device_mesh" ] + _DeviceDtypeModuleMixin.__jit_unused_properties__ + HyperparametersMixin.__jit_unused_properties__ From 939d5e9a1acee16217503f5cfe4decfbf4e5271c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 May 2024 14:18:24 +0000 Subject: [PATCH 5/8] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/lightning/pytorch/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index b1ae2f5af3f6c..d653c26b5b19c 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -114,7 +114,7 @@ class LightningModule( "trainer", "fabric", "strict_loading", - "device_mesh" + "device_mesh", ] + _DeviceDtypeModuleMixin.__jit_unused_properties__ + 
HyperparametersMixin.__jit_unused_properties__ From 5e44dc11e7714d287c845821f8eb6e377d1e93df Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 16:37:22 +0200 Subject: [PATCH 6/8] fix docs issue --- src/lightning/pytorch/core/module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py index d653c26b5b19c..5a4f8d4e1bbb1 100644 --- a/src/lightning/pytorch/core/module.py +++ b/src/lightning/pytorch/core/module.py @@ -330,7 +330,7 @@ def loggers(self) -> Union[List[Logger], List[FabricLogger]]: @property def device_mesh(self) -> Optional["DeviceMesh"]: """Strategies like ``ModelParallelStrategy`` will create a device mesh that can be accessed in the - :meth:`configure_model` hook to parallelize the LightningModule.""" + :meth:`~lightning.pytorch.core.hooks.ModelHooks.configure_model` hook to parallelize the LightningModule.""" return self._device_mesh def _call_batch_hook(self, hook_name: str, *args: Any) -> Any: From 9b34c9923eda90a65a61cf5cf04c554e192961e6 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Fri, 17 May 2024 16:37:59 +0200 Subject: [PATCH 7/8] fix test execution --- .../tests_pytorch/strategies/test_model_parallel_integration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/strategies/test_model_parallel_integration.py b/tests/tests_pytorch/strategies/test_model_parallel_integration.py index 86a62a4e77dc6..bbac2a6078f9c 100644 --- a/tests/tests_pytorch/strategies/test_model_parallel_integration.py +++ b/tests/tests_pytorch/strategies/test_model_parallel_integration.py @@ -263,7 +263,7 @@ def training_step(self, batch): trainer.fit(model) -@RunIf(min_torch="2.3", min_cuda_gpus=2) +@RunIf(min_torch="2.3", min_cuda_gpus=2, standalone=True) def test_modules_without_parameters(tmp_path): """Test that TorchMetrics get moved to the device despite not having any parameters.""" From 8eebff970c879091ccded2961a7a4c1c59c1fdb0 Mon Sep 17 00:00:00 2001 From: Luca Antiga Date: Fri, 17 May 2024 18:15:44 -0400 Subject: [PATCH 8/8] Update src/lightning/pytorch/strategies/model_parallel.py --- src/lightning/pytorch/strategies/model_parallel.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/lightning/pytorch/strategies/model_parallel.py b/src/lightning/pytorch/strategies/model_parallel.py index 58a9d2e682fdb..304b9bc04fc2d 100644 --- a/src/lightning/pytorch/strategies/model_parallel.py +++ b/src/lightning/pytorch/strategies/model_parallel.py @@ -56,7 +56,8 @@ class ModelParallelStrategy(ParallelStrategy): Currently supports up to 2D parallelism. Specifically, it supports the combination of Fully Sharded Data-Parallel 2 (FSDP2) with Tensor Parallelism (DTensor). These PyTorch APIs are currently still - experimental in PyTorch. Requires PyTorch 2.3 or newer. + experimental in PyTorch (see https://pytorch.org/docs/stable/distributed.tensor.parallel.html). + Requires PyTorch 2.3 or newer. Arguments: data_parallel_size: The number of devices within a data-parallel group. Defaults to ``"auto"``, which
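The integration tests combine both mesh dimensions; below is a condensed sketch of that 2D setup (tensor parallelism across one mesh dimension, FSDP2 sharding across the data-parallel dimension), following `_parallelize_feed_forward_fsdp2_tp`. Both PyTorch APIs are experimental, so the import paths may shift across releases; the `model` is assumed to expose `w1`/`w2`/`w3` linear layers as in the test model.

```python
# Sketch: 2D (tensor parallel + FSDP2) parallelization for use in `configure_model()`.
from torch.distributed._composable.fsdp.fully_shard import fully_shard
from torch.distributed.tensor.parallel import ColwiseParallel, RowwiseParallel, parallelize_module


def parallelize_2d(model, device_mesh):
    # 1) Apply the tensor-parallel plan across the "tensor_parallel" mesh dimension
    tp_mesh = device_mesh["tensor_parallel"]
    plan = {"w1": ColwiseParallel(), "w2": ColwiseParallel(), "w3": RowwiseParallel()}
    parallelize_module(model, tp_mesh, plan)

    # 2) Fully shard each layer across the "data_parallel" mesh dimension (FSDP2)
    dp_mesh = device_mesh["data_parallel"]
    for layer in (model.w1, model.w2, model.w3):
        fully_shard(layer, mesh=dp_mesh)
    return model
```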