(2/n) Support 2D Parallelism - Distributed Checkpoints #19852

Merged · 6 commits · May 15, 2024
3 changes: 2 additions & 1 deletion src/lightning/fabric/CHANGELOG.md
@@ -15,7 +15,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added support for PyTorch 2.3 ([#19708](https://github.com/Lightning-AI/pytorch-lightning/pull/19708))

-
- Added `ModelParallelStrategy` to support 2D parallelism ([#19846](https://github.com/Lightning-AI/pytorch-lightning/pull/19846), [#19852](https://github.com/Lightning-AI/pytorch-lightning/pull/19852))


### Changed

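Since the changelog entry above is the only user-facing mention of `ModelParallelStrategy` in this diff, here is a minimal sketch of how the new strategy and distributed checkpoints could fit together in Fabric. The constructor arguments (`parallelize_fn`, `data_parallel_size`, `tensor_parallel_size`), the `parallelize` callback, and the checkpoint path are assumptions for illustration, not taken from this PR.

```python
# Hedged sketch (assumed API), launched with e.g. `torchrun --nproc_per_node=4 train.py`.
import torch
import torch.nn as nn
from lightning.fabric import Fabric
from lightning.fabric.strategies import ModelParallelStrategy


def parallelize(model: nn.Module, device_mesh) -> nn.Module:
    # Assumption: the strategy passes a DeviceMesh and expects the model to be
    # sharded here (e.g. with fully_shard / parallelize_module) before training.
    return model


strategy = ModelParallelStrategy(
    parallelize_fn=parallelize,   # assumed parameter name
    data_parallel_size=2,         # assumed parameter name
    tensor_parallel_size=2,       # assumed parameter name
)
fabric = Fabric(devices=4, strategy=strategy)
fabric.launch()

model = fabric.setup_module(nn.Linear(32, 32))
optimizer = fabric.setup_optimizers(torch.optim.AdamW(model.parameters()))

# Distributed (sharded) checkpoints are written as a directory of per-rank files.
state = {"model": model, "optimizer": optimizer, "step": 0}
fabric.save("checkpoints/step_0", state)
fabric.load("checkpoints/step_0", state)
```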
38 changes: 8 additions & 30 deletions src/lightning/fabric/strategies/fsdp.py
@@ -393,7 +393,7 @@ def clip_gradients_norm(
# the root must be wrapped
raise TypeError(
"Gradient clipping with FSDP is only possible if the module passed to"
f" `{self.__class__.__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
f" `{type(self).__name__}.clip_gradients_norm` is wrapped in `FullyShardedDataParallel`."
f" Got: {module.__class__.__name__}."
)
self.precision.unscale_gradients(optimizer)
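For context on the check above: `clip_gradients_norm` is the strategy-level hook behind `Fabric.clip_gradients`, and it requires the module handed to it to be the FSDP-wrapped object returned by setup. A short sketch, reusing the `fabric`, `model`, `optimizer`, and an assumed `dataloader` from the setup sketched after the changelog hunk:

```python
# Hedged sketch: clip gradients by norm with Fabric; `max_norm` is illustrative.
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).sum()
    fabric.backward(loss)
    # Dispatches to the strategy's clip_gradients_norm shown above; the module
    # must be the FSDP-wrapped one returned by fabric.setup_module.
    fabric.clip_gradients(model, optimizer, max_norm=1.0)
    optimizer.step()
```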
@@ -506,12 +506,7 @@ def load_checkpoint(
state: Optional[Union[Module, Optimizer, Dict[str, Union[Module, Optimizer, Any]]]] = None,
strict: bool = True,
) -> Dict[str, Any]:
"""Load the contents from a checkpoint and restore the state of the given objects.

The strategy currently only supports saving and loading sharded checkpoints, which are stored in the form of
a directory of multiple files rather than a single file.

"""
"""Load the contents from a checkpoint and restore the state of the given objects."""
if not state:
raise ValueError(
f"Got FSDPStrategy.load_checkpoint(..., state={state!r}) but a state with at least "
@@ -522,6 +517,8 @@ def load_checkpoint(
path = Path(self.broadcast(path))

if isinstance(state, Module):
from lightning.fabric.strategies.model_parallel import _load_raw_module_state_from_path

_load_raw_module_state_from_path(path, module=state, world_size=self.world_size, strict=strict)
return {}
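The `isinstance(state, Module)` branch above is what allows a plain, single-file PyTorch checkpoint to be loaded straight into an FSDP model; the heavy lifting now lives in `lightning.fabric.strategies.model_parallel`. From user code this path is typically reached via `Fabric.load_raw`; a brief sketch (model class and file name are placeholders):

```python
# Hedged sketch: load a raw `torch.save(model.state_dict(), ...)` checkpoint into
# a model managed by the FSDP strategy. Internally this hits the
# `isinstance(state, Module)` branch and `_load_raw_module_state_from_path` above.
model = fabric.setup_module(MyModel())     # MyModel is a placeholder module
fabric.load_raw("pretrained.pt", model, strict=True)
```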

@@ -592,6 +589,9 @@ def load_checkpoint(

if _is_full_checkpoint(path):
checkpoint = _lazy_load(path)

from lightning.fabric.strategies.model_parallel import _load_raw_module_state

_load_raw_module_state(checkpoint.pop(module_key), module=module, world_size=self.world_size, strict=strict)

if isinstance(state, Module):
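The full-checkpoint branch above goes through Lightning's internal `_lazy_load` so that ranks do not each materialize a full in-memory copy of the checkpoint before the raw module state is loaded. A rough public-API approximation, assuming PyTorch >= 2.1 for the `mmap` flag and an illustrative file name and key:

```python
import torch

# Hedged sketch: memory-map the checkpoint so tensor storages are paged in lazily
# rather than fully loaded per rank; "model" stands in for the real `module_key`.
checkpoint = torch.load("full_checkpoint.pt", map_location="cpu", mmap=True)
state_dict = checkpoint["model"]
```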
@@ -755,7 +755,7 @@ def no_backward_sync(self, module: Module, enabled: bool) -> ContextManager:
# the root must be wrapped
raise TypeError(
"Blocking backward sync is only possible if the module passed to"
f" `{self.__class__.__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
f" `{type(self).__name__}.no_backward_sync` is wrapped in `FullyShardedDataParallel`."
f" Got: {module.__class__.__name__}."
)
return module.no_sync()
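As with gradient clipping, the check above guards the strategy hook behind `Fabric.no_backward_sync`, which users typically reach for gradient accumulation. A minimal sketch, again reusing the earlier Fabric objects; the accumulation factor is illustrative:

```python
# Hedged sketch: skip FSDP's gradient reduction on all but the last micro-batch.
accumulation = 4
for i, batch in enumerate(dataloader):
    is_last = (i + 1) % accumulation == 0
    with fabric.no_backward_sync(model, enabled=not is_last):
        loss = model(batch).sum() / accumulation
        fabric.backward(loss)
    if is_last:
        optimizer.step()
        optimizer.zero_grad()
```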
@@ -848,28 +848,6 @@ def _has_fsdp_modules(module: object) -> TypeGuard[Module]:
return isinstance(module, Module) and any(isinstance(m, FullyShardedDataParallel) for m in module.modules())


def _load_raw_module_state_from_path(path: Path, module: Module, world_size: int, strict: bool = True) -> None:
"""Loads the state dict from a file path into the FSDP module."""
if not _is_full_checkpoint(path):
raise ValueError(
"Failed to load checkpoint directly into the model. The given path must be a single file containing the"
f" full state dict: {path}"
)
# Use `lazy_load` instead of `torch.load` here to avoid storing a copy of the full checkpoint per rank
_load_raw_module_state(state_dict=_lazy_load(path), module=module, world_size=world_size, strict=strict)


def _load_raw_module_state(state_dict: Dict[str, Any], module: Module, world_size: int, strict: bool = True) -> None:
"""Loads the state dict into the module by gathering all weights first and then and writing back to each shard."""
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

if not isinstance(module, FSDP):
module.load_state_dict(state_dict, strict=strict)
else:
with _get_full_state_dict_context(module, world_size=world_size, rank0_only=False):
module.load_state_dict(state_dict, strict=strict)


def _move_torchmetrics_to_device(module: torch.nn.Module, device: torch.device) -> None:
# FSDP doesn't move modules without parameters (e.g. Metrics) to the device
# https://github.com/pytorch/pytorch/issues/113113
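The helpers deleted above (`_load_raw_module_state_from_path` and `_load_raw_module_state`) are not gone; they move to `lightning.fabric.strategies.model_parallel`, which the new imports earlier in this diff pull from. For reference, a rough equivalent of the gather-then-load path using only public FSDP APIs might look like this (the `FullStateDictConfig` values mirror the `rank0_only=False` behaviour of the removed code, but are otherwise assumptions):

```python
from typing import Any, Dict

import torch.nn as nn
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def load_full_state_dict(module: nn.Module, state_dict: Dict[str, Any], strict: bool = True) -> None:
    """Hedged re-implementation of the removed `_load_raw_module_state` with public FSDP APIs."""
    if not isinstance(module, FSDP):
        # Non-FSDP modules can consume the full state dict directly.
        module.load_state_dict(state_dict, strict=strict)
        return
    # Every rank holds the full (unsharded) state dict; FSDP scatters it back to the shards.
    config = FullStateDictConfig(offload_to_cpu=True, rank0_only=False)
    with FSDP.state_dict_type(module, StateDictType.FULL_STATE_DICT, config):
        module.load_state_dict(state_dict, strict=strict)
```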