All unshard streams wait on computation every step #2823

Merged: 29 commits, Jan 8, 2024

Commits (29)
976b8ac  patched torch (snarayan21, Dec 28, 2023)
c449b52  fixed torch imports (snarayan21, Dec 28, 2023)
a159b92  fixed torch imports (snarayan21, Dec 28, 2023)
47ad05f  fixed torch imports (snarayan21, Dec 28, 2023)
b37b503  Merge branch 'dev' of https://github.com/snarayan21/saaketh-composer … (snarayan21, Jan 2, 2024)
aff2d16  patching through composer (snarayan21, Jan 2, 2024)
9641da0  patching through composer (snarayan21, Jan 2, 2024)
a1e5952  patching typingr (snarayan21, Jan 2, 2024)
a5c94c0  comment added (snarayan21, Jan 2, 2024)
2bd4a15  don't patch torch 2.1.0 (snarayan21, Jan 2, 2024)
ddb8749  patch torch 2.1.1 and 2.2.0 (snarayan21, Jan 2, 2024)
501e3da  linting fix (snarayan21, Jan 2, 2024)
2c970f5  Merge branch 'dev' of https://github.com/snarayan21/saaketh-composer … (snarayan21, Jan 2, 2024)
2b4a66a  waiting on computation stream from unshard stream (snarayan21, Jan 3, 2024)
233225f  waiting on computation stream from unshard stream (snarayan21, Jan 3, 2024)
333ee66  less waiting (snarayan21, Jan 3, 2024)
9991ab7  no waiting (snarayan21, Jan 3, 2024)
8635897  all unshard streams wait on computation stream now (snarayan21, Jan 3, 2024)
3b8397f  merged main (snarayan21, Jan 4, 2024)
04045bd  2.2.0 dev change (snarayan21, Jan 4, 2024)
3517bb1  Merge branch 'dev' of https://github.com/snarayan21/saaketh-composer … (snarayan21, Jan 7, 2024)
1113c2e  correct waiting on computation stream (snarayan21, Jan 7, 2024)
af01b98  fsdp state typiung (snarayan21, Jan 7, 2024)
2de4354  patching root pre forward (snarayan21, Jan 7, 2024)
517e6a6  patching root pre forward (snarayan21, Jan 7, 2024)
723175a  fsdp state typing (snarayan21, Jan 7, 2024)
4f57ca5  patch forward (snarayan21, Jan 7, 2024)
96239b8  correct waiting (snarayan21, Jan 8, 2024)
3a66b5f  linting (snarayan21, Jan 8, 2024)
composer/trainer/mosaic_fsdp.py (13 additions, 3 deletions)

@@ -9,7 +9,6 @@
 import torch
 from packaging import version
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.distributed.fsdp import FullyShardedDataParallel

 from composer.trainer.mosaic_fsdp_utils import (_sharded_pre_load_state_dict_hook, build_metadata,
                                                 custom_auto_wrap_t1p13p1)
@@ -62,14 +61,25 @@ def patch_pytorch():
         _runtime_utils._validate_and_get_hybrid_shard_state = lambda *args, **kwargs: None

         # Better overlap communication and computation
-        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p1
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
+
+        from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p1,
+                                                        _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p1
+        _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+        _runtime_utils._root_pre_forward = _root_pre_forward
+        FullyShardedDataParallel.forward = forward

     elif version.parse(torch.__version__) < version.parse('2.2.1'):
         # Monkey patch for torch < 2.2.1 ie torch == 2.2.0

         # Better overlap communication and computation
         from torch.distributed.fsdp import _runtime_utils
+        from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel

-        from composer.trainer.mosaic_fsdp_utils import _share_state_and_init_handle_attrs_t2p2
+        from composer.trainer.mosaic_fsdp_utils import (_root_pre_forward, _share_state_and_init_handle_attrs_t2p2,
+                                                        _wait_for_computation_stream, forward)
         _runtime_utils._share_state_and_init_handle_attrs = _share_state_and_init_handle_attrs_t2p2
+        _runtime_utils._wait_for_computation_stream = _wait_for_computation_stream
+        _runtime_utils._root_pre_forward = _root_pre_forward
+        FullyShardedDataParallel.forward = forward
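
The hunk above is Composer's version-gated monkey patching: patch_pytorch() picks a branch based on the installed torch release and rebinds torch's FSDP internals (_wait_for_computation_stream, _root_pre_forward) and FullyShardedDataParallel.forward to the patched versions defined in mosaic_fsdp_utils.py. The sketch below illustrates the same pattern in isolation; it is not the real patch_pytorch(): it collapses the per-version branches into a single gate, and apply_fsdp_overlap_patches and the patched_* parameter names are placeholders invented for this example.

import torch
from packaging import version
from torch.distributed.fsdp import _runtime_utils
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel


def apply_fsdp_overlap_patches(patched_wait, patched_root_pre_forward, patched_forward):
    """Illustrative sketch: rebind FSDP internals to patched callables, gated on torch version."""
    torch_version = version.parse(torch.__version__)
    # The real patch_pytorch() has one branch per supported torch release
    # (e.g. 2.1.1/2.1.2 and 2.2.0), each importing its own patched helpers.
    if version.parse('2.1.1') <= torch_version < version.parse('2.2.1'):
        # Rebind the module-level helpers inside torch.distributed.fsdp._runtime_utils
        # so torch's own internal callers pick up the patched behavior.
        _runtime_utils._wait_for_computation_stream = patched_wait
        _runtime_utils._root_pre_forward = patched_root_pre_forward
        # forward is rebound on the class itself, so already-constructed FSDP
        # instances also route through the patched _root_pre_forward.
        FullyShardedDataParallel.forward = patched_forward
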
composer/trainer/mosaic_fsdp_utils.py (149 additions, 18 deletions)

@@ -788,6 +788,153 @@ def fsdp_state_pg_ranks(state: '_FSDPState') -> Tuple[int, ...]:
     return tuple(get_process_group_ranks(state.process_group))


+def _wait_for_computation_stream(
+    computation_stream: torch.Stream,
+    root_state: '_FSDPState',
+    pre_unshard_stream: torch.Stream,
+):
+    """Unshard and pre-unshard streams wait for computation stream.
+
+    Has the unshard and pre-unshard streams wait for the computation stream.
+    For example, this should be called in the FSDP root's pre-forward to
+    respect optimizer step computation.
+    """
+    # Tracing does not need to wait
+    if torch.distributed._functional_collectives.is_torchdynamo_compiling():
+        return
+    # Ensure all unshard streams wait for the computation stream.
+    unshard_streams = set()
+    for fsdp_state in root_state._all_fsdp_states:
+        unshard_streams.add(fsdp_state._unshard_stream)
+    for unshard_stream in unshard_streams:
+        unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+    # Having the pre-all-gather stream wait for the current stream even if we
+    # do not leverage the pre-all-gather stream is tolerable since this only
+    # runs once per iteration
+    pre_unshard_stream.wait_stream(computation_stream)  # type: ignore[attr-defined]
+
+
+@no_type_check
+def _root_pre_forward(
+    state: '_FSDPState',
+    module: nn.Module,
+    args,
+    kwargs,
+) -> None:
+    """Runs pre-forward logic specific to the root FSDP instance.
+
+    This should run before any individual module's pre-forward. This starts
+    with an attempt at lazy initialization (which only runs non-vacuously once).
+    Otherwise, if this is called on a non-root FSDP instance, then it returns
+    directly.
+    """
+    from torch.distributed.fsdp._common_utils import _is_composable
+    from torch.distributed.fsdp._runtime_utils import (_cast_buffers_to_dtype_and_device,
+                                                       _get_buffers_and_dtypes_for_computation, _lazy_init,
+                                                       _reset_flat_param_grad_info_if_needed, _root_cast_forward_input)
+    from torch.distributed.utils import _p_assert, _to_kwargs
+    with torch.profiler.record_function('FullyShardedDataParallel._root_pre_forward'):
+        _lazy_init(state, module)
+        _p_assert(state._is_root is not None, 'Expects a root FSDP to have been set')
+        if not state._is_root:
+            # Always cast forward inputs in the root of this local FSDP unit for mixed
+            # precision, as this is where mixed precision could be configed.
+            # This is more useful for auto wrapping that is recommended in composable path.
+            # For manual wrapping, cast forward inputs on each local FSDP unit root will
+            # increase some overhead, so not turned on for model wrapper path right now where
+            # manual wrapping is more broadly used.
+            if _is_composable(state):
+                return _root_cast_forward_input(state, module, args, kwargs)
+            return args, kwargs
+
+        # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers
+        # are in full precision and if we should cast them back to lower precision, which happens when
+        # exiting eval() mode.
+        handle = state._handle
+        if handle:
+            should_cast_buffers_to_full_prec = handle._force_full_precision
+        else:
+            should_cast_buffers_to_full_prec = True
+
+        if should_cast_buffers_to_full_prec:
+            _cast_buffers_to_dtype_and_device(
+                buffers=dict(module.named_buffers()).values(),
+                buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()),
+                device=state.compute_device,
+            )
+            # This flag is only set when we cast buffers to full precision, to avoid the
+            # CPU overhead that can stem from retrieving all buffers and their types in the
+            # following else branch.
+            state._needs_buffer_dtype_restore_check = True
+        elif getattr(state, '_needs_buffer_dtype_restore_check', False):
+            # Check if buffers are in full precision and we need to cast them
+            # back down.
+            (
+                buffers,
+                buffer_dtypes_for_computation,
+            ) = _get_buffers_and_dtypes_for_computation(state, module)
+            if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0:
+                if any(buffer.dtype != buffer_dtype_for_computation
+                       for buffer, buffer_dtype_for_computation in zip(buffers, buffer_dtypes_for_computation)):
+                    # Assume we have to cast everything if there is one mismatch
+                    _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes_for_computation, state.compute_device)
+                # We don't have to check this again until we cast buffers to full precision again.
+                state._needs_buffer_dtype_restore_check = False
+
+        if state.forward_prefetch:
+            handles = []
+            for fsdp_state in state._all_fsdp_states:
+                if fsdp_state._handle:
+                    handles.append(fsdp_state._handle)
+            for handle in handles:
+                handle._needs_pre_forward_unshard = True
+                handle._prefetched = False
+
+        _wait_for_computation_stream(
+            state._device_handle.current_stream(),
+            state,
+            state._pre_unshard_stream,
+        )
+        _reset_flat_param_grad_info_if_needed(state._all_handles)
+
+        # Prepares the forward inputs by moving them to ``compute_device``
+        # TODO: Do not use the side stream for tensor copies for now; investigate
+        # the perf with/without it.
+        with torch.profiler.record_function('FullyShardedDataParallel._to_kwargs'):
+            args_tuple, kwargs_tuple = _to_kwargs(args, kwargs, state.compute_device, False)
+        args = args_tuple[0]
+        kwargs = kwargs_tuple[0]
+
+        return _root_cast_forward_input(state, module, args, kwargs)
+
+
+def forward(self, *args: Any, **kwargs: Any) -> Any:
+    """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic."""
+    from torch.distributed.fsdp._runtime_utils import (_post_forward, _post_forward_reshard, _pre_forward,
+                                                       _pre_forward_unshard)
+    from torch.distributed.utils import _p_assert
+    handle = self._handle
+    with torch.autograd.profiler.record_function('FullyShardedDataParallel.forward'):
+        args, kwargs = _root_pre_forward(self, self, args, kwargs)
+        unused = None
+        args, kwargs = _pre_forward(
+            self,
+            handle,
+            _pre_forward_unshard,
+            self._fsdp_wrapped_module,
+            args,
+            kwargs,
+        )
+        if handle:
+            _p_assert(
+                handle.flat_param.device == self.compute_device,
+                'Expected `FlatParameter` to be on the compute device '
+                f'{self.compute_device} but got {handle.flat_param.device}',
+            )
+        output = self._fsdp_wrapped_module(*args, **kwargs)
+        return _post_forward(self, handle, _post_forward_reshard, self, unused, output)
+
+
 @no_type_check
 def _share_state_and_init_handle_attrs_t2p1(
     root_state: '_FSDPState',
@@ -801,8 +948,7 @@ def _share_state_and_init_handle_attrs_t2p1(
     been modified to assign a different unshard stream to each process group.
     """
     from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _init_device_mesh,
-                                                        _validate_and_get_hybrid_shard_state,
-                                                        _wait_for_computation_stream)
+                                                        _validate_and_get_hybrid_shard_state)
     from torch.distributed.utils import _p_assert

     handle = root_state._handle
@@ -875,13 +1021,6 @@ def _share_state_and_init_handle_attrs_t2p1(
         handle = fsdp_state._handle
         if handle:
             handle.init_flat_param_attributes()
-    # Ensure that all unshard streams wait on the default computation stream
-    for pg_unshard_stream in fsdp_pg_unshard_streams.values():
-        _wait_for_computation_stream(
-            root_state._device_handle.current_stream(),
-            pg_unshard_stream,
-            root_state._pre_unshard_stream,
-        )
     for attr_name, attr_values in attr_name_to_values.items():
         if len(attr_values) != 1:
             raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
@@ -899,8 +1038,7 @@ def _share_state_and_init_handle_attrs_t2p2(
     done together to require a single loop over the states. This function has
     been modified to assign a different unshard stream to each process group.
     """
-    from torch.distributed.fsdp._runtime_utils import (HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state,
-                                                        _wait_for_computation_stream)
+    from torch.distributed.fsdp._runtime_utils import HOMOGENEOUS_ATTR_NAMES, _validate_and_get_hybrid_shard_state
     from torch.distributed.utils import _p_assert

     handle = root_state._handle
@@ -973,13 +1111,6 @@ def _share_state_and_init_handle_attrs_t2p2(
         handle = fsdp_state._handle
         if handle:
             handle.init_flat_param_attributes()
-    # Ensure that all unshard streams wait on the default computation stream
-    for pg_unshard_stream in fsdp_pg_unshard_streams.values():
-        _wait_for_computation_stream(
-            root_state._device_handle.current_stream(),
-            pg_unshard_stream,
-            root_state._pre_unshard_stream,
-        )
     for attr_name, attr_values in attr_name_to_values.items():
         if len(attr_values) != 1:
             raise ValueError(f'Expects one homogeneous value for {attr_name} but got {attr_values}')
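
Net effect of the two files: the one-time wait that _share_state_and_init_handle_attrs_t2p1/_t2p2 used to issue per process group at lazy initialization is removed, and the patched _root_pre_forward instead calls _wait_for_computation_stream at the start of every forward, so every FSDP state's unshard stream (plus the pre-unshard stream) waits on the current computation stream on every step. The toy sketch below only illustrates what that wait_stream ordering means between CUDA streams; the tensors and stream names are invented for the example and are not part of this PR.

import torch


def demo_unshard_waits_on_computation():
    """Toy sketch (assumes a CUDA device): order side-stream work after the computation stream."""
    computation_stream = torch.cuda.current_stream()
    unshard_stream = torch.cuda.Stream()
    pre_unshard_stream = torch.cuda.Stream()

    x = torch.randn(4096, 4096, device='cuda')
    y = x @ x  # stands in for computation (e.g. the optimizer step) on the default stream

    # Mirrors the loop in _wait_for_computation_stream: each unshard stream and the
    # pre-unshard stream wait on the computation stream once per iteration.
    for stream in (unshard_stream, pre_unshard_stream):
        stream.wait_stream(computation_stream)

    with torch.cuda.stream(unshard_stream):
        # Work queued here (an all-gather in real FSDP) cannot start before y = x @ x finishes.
        z = y.sum()

    torch.cuda.synchronize()
    return z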