Skip to content

Commit

Permalink
fix 2.4.1 test (#3612)
Browse files Browse the repository at this point in the history
Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
bigning and mvpatel2000 authored Sep 10, 2024
1 parent d8236db commit a9cd768
Showing 1 changed file with 49 additions and 45 deletions.
94 changes: 49 additions & 45 deletions composer/trainer/_patch_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -946,51 +946,7 @@ def unshard_with_sync(self):
if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.1'):
# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
original_unshard = FlatParamHandle.unshard

@no_type_check
def unshard_with_sync(self):
"""Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`.
This prevents deadlocks when some ranks OOM after the alloc call and others do not.
This is a patched method from pytorch, meant to be called when automicrobatching
turns on hooks in its search process for the optimal non-OOMing microbatch size.
This includes all-gathering the flat parameter
and switching to using the unsharded flat parameter. If the handle does
not need unsharding, then this only switches to using the unsharded
flat parameter. For ``NO_SHARD``, this is a no-op.
If FSDP is in :meth:`summon_full_params` and the handle uses parameter
mixed precision, then the parameter is forced to full precision.
"""
if not self.needs_unshard():
# Even when not needing an unshard, we should switch to using
# the unsharded flat parameter
unsharded_flat_param = (
self._get_padded_unsharded_flat_param()
if self.uses_sharded_strategy
else self.flat_param
)
self._use_unsharded_flat_param(unsharded_flat_param)
return
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

# Check if any other rank hit an OOM
found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
found_cuda_oom = found_cuda_oom_tensor.item()
# Signal current rank is still in batch
all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

if found_cuda_oom == 1:
raise RuntimeError('CUDA out of memory encountered on a different rank')
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
self._use_unsharded_flat_param(padded_unsharded_flat_param)

# 2.4.0 only patch
# PyTorch issue: https://github.com/pytorch/pytorch/issues/133923
from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from typing import Mapping, Collection
Expand Down Expand Up @@ -1046,3 +1002,51 @@ def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:

for key, value in state_dict.items():
_traverse_obj((str(key),), value)

if version.parse(torch.__version__) >= version.parse('2.4.0') and version.parse(
torch.__version__,
) < version.parse('2.4.2'):
# Save original FlatParamHandle.unshard to revert back to when dropping automicrobatching hooks
from torch.distributed.fsdp._flat_param import FlatParamHandle
original_unshard = FlatParamHandle.unshard

@no_type_check
def unshard_with_sync(self):
"""Run the unshard logic, but with a sync after a :meth:`_alloc_padded_unsharded_flat_param`.
This prevents deadlocks when some ranks OOM after the alloc call and others do not.
This is a patched method from pytorch, meant to be called when automicrobatching
turns on hooks in its search process for the optimal non-OOMing microbatch size.
This includes all-gathering the flat parameter
and switching to using the unsharded flat parameter. If the handle does
not need unsharding, then this only switches to using the unsharded
flat parameter. For ``NO_SHARD``, this is a no-op.
If FSDP is in :meth:`summon_full_params` and the handle uses parameter
mixed precision, then the parameter is forced to full precision.
"""
if not self.needs_unshard():
# Even when not needing an unshard, we should switch to using
# the unsharded flat parameter
unsharded_flat_param = (
self._get_padded_unsharded_flat_param()
if self.uses_sharded_strategy
else self.flat_param
)
self._use_unsharded_flat_param(unsharded_flat_param)
return
unsharded_flat_param = self._alloc_padded_unsharded_flat_param()

# Check if any other rank hit an OOM
found_cuda_oom_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

dist.all_reduce(found_cuda_oom_tensor, reduce_operation='MAX')
found_cuda_oom = found_cuda_oom_tensor.item()
# Signal current rank is still in batch
all_ranks_finished_tensor = torch.tensor([0], dtype=torch.uint8).to(self.device, non_blocking=True)

dist.all_reduce(all_ranks_finished_tensor, reduce_operation='MIN')

if found_cuda_oom == 1:
raise RuntimeError('CUDA out of memory encountered on a different rank')
padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param)
self._use_unsharded_flat_param(padded_unsharded_flat_param)

0 comments on commit a9cd768

Please sign in to comment.