From 3653bf290c62e312f31ddc89df75bc18d0e163ad Mon Sep 17 00:00:00 2001
From: Xilun Wu <12968408+XilunWu@users.noreply.github.com>
Date: Wed, 30 Oct 2024 22:34:56 -0700
Subject: [PATCH] [BE] replace the extra DeviceMesh _flatten with mesh access (#666)

Stack from [ghstack](https://github.com/ezyang/ghstack) (oldest at bottom):
* __->__ #666

**Summary**
https://github.com/pytorch/pytorch/pull/138945 fixes DeviceMesh access on flattened meshes that are constructed from more than 2 meshes. Refer to that PR for details if interested. In #592 we avoided this issue by calling `_flatten` instead of directly accessing the flattened mesh. Now that the fix has been merged into PyTorch, we want to switch back to mesh access, which is more straightforward.
---
 torchtitan/parallelisms/parallelize_llama.py | 42 ++++++++++++--------
 torchtitan/parallelisms/utils.py             | 28 +++++++++++++
 2 files changed, 54 insertions(+), 16 deletions(-)
 create mode 100644 torchtitan/parallelisms/utils.py

diff --git a/torchtitan/parallelisms/parallelize_llama.py b/torchtitan/parallelisms/parallelize_llama.py
index 84666031a..2a66f4723 100644
--- a/torchtitan/parallelisms/parallelize_llama.py
+++ b/torchtitan/parallelisms/parallelize_llama.py
@@ -34,6 +34,7 @@
 from torchtitan.config_manager import JobConfig, TORCH_DTYPE_MAP
 from torchtitan.logging import logger
 from torchtitan.parallelisms.parallel_dims import ParallelDims
+from torchtitan.parallelisms.utils import check_if_feature_in_pytorch
 
 
 def parallelize_llama(
@@ -79,22 +80,31 @@ def parallelize_llama(
     if (
         parallel_dims.dp_shard_enabled
     ):  # apply FSDP or HSDP, potentially with Context Parallel
-
-        # TODO: instead of flattening the mesh twice, we could've done in a batter way:
-        #       dp_mesh = world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
-        #       However, this leads to an error in `DeviceMesh.__get_item__` which I believe is
-        #       a bug in DeviceMesh. We should fix it and then use the above line.
-        dp_mesh_dim_names = (
-            ("dp_replicate", "dp_shard")
-            if parallel_dims.dp_replicate_enabled
-            else ("dp",)
-        )
-        # note that mesh can only be flattened from the finest-grained mesh dimensions
-        dp_mesh = (
-            world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
-            if parallel_dims.cp_enabled
-            else world_mesh[dp_mesh_dim_names]
-        )
+        try:
+            dp_mesh = (
+                world_mesh["dp_cp"] if parallel_dims.cp_enabled else world_mesh["dp"]
+            )
+        except IndexError:
+            # note: this is a workaround for older PyTorch versions where
+            # https://github.com/pytorch/pytorch/pull/138945 is not included;
+            # log a warning to encourage users to upgrade to a newer PyTorch.
+            check_if_feature_in_pytorch(
+                "DeviceMesh flattening over 3D+ meshes",
+                "https://github.com/pytorch/pytorch/pull/138945",
+                "2.6.0.dev20241030",
+            )
+            # TODO: remove this workaround once PyTorch 2.6 is released
+            dp_mesh_dim_names = (
+                ("dp_replicate", "dp_shard")
+                if parallel_dims.dp_replicate_enabled
+                else ("dp",)
+            )
+            # note that mesh can only be flattened from the finest-grained mesh dimensions
+            dp_mesh = (
+                world_mesh[(*dp_mesh_dim_names, "cp")]._flatten("dp_cp")
+                if parallel_dims.cp_enabled
+                else world_mesh[dp_mesh_dim_names]
+            )
 
         apply_fsdp(
             model,
diff --git a/torchtitan/parallelisms/utils.py b/torchtitan/parallelisms/utils.py
new file mode 100644
index 000000000..a84af7981
--- /dev/null
+++ b/torchtitan/parallelisms/utils.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+from typing import Optional
+
+import torch
+from torchtitan.logging import logger
+
+
+def check_if_feature_in_pytorch(
+    feature_name: str,
+    pull_request: str,
+    min_nightly_version: Optional[str] = None,
+) -> None:
+    if "git" in torch.__version__:  # pytorch is built from source
+        # notify users to check if the pull request is included in their pytorch
+        logger.warning(
+            "Detected that PyTorch is built from source. Please make sure the PR "
+            f"({pull_request}) is included in PyTorch for correct {feature_name}."
+        )
+    elif min_nightly_version is not None and torch.__version__ < min_nightly_version:
+        logger.warning(
+            f"Detected that the PyTorch version {torch.__version__} is older than "
+            f"{min_nightly_version}. Please upgrade to a newer version to include "
+            f"the change in ({pull_request}) for correct {feature_name}."
+        )
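
For reference, a minimal usage sketch of the new helper added in `torchtitan/parallelisms/utils.py` (not part of the patch; it assumes `torch` and `torchtitan` are importable, and the arguments mirror the ones passed in `parallelize_llama.py` above):

```python
# Minimal sketch: call the helper directly to observe the warning behavior.
# No distributed initialization or DeviceMesh is required for this check.
from torchtitan.parallelisms.utils import check_if_feature_in_pytorch

check_if_feature_in_pytorch(
    feature_name="DeviceMesh flattening over 3D+ meshes",
    pull_request="https://github.com/pytorch/pytorch/pull/138945",
    min_nightly_version="2.6.0.dev20241030",
)
# Logs a warning only if torch.__version__ is a source build ("git" in the
# version string) or older than the given nightly; otherwise it is a no-op.
```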