Enabling the computation of validation loss and other metrics when using sequence parallelism #3183

Merged: 13 commits, Apr 10, 2024
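For orientation before the diff: a minimal sketch of how an evaluator could be configured for sequence-parallel evaluation once these changes land. The only detail taken from this PR is that the eval dataloader exposes a `seq_parallel_world_size` attribute (see the `hasattr` checks below); the dataloader builder and metric name are placeholders, not Composer APIs.

```python
from composer.core import Evaluator

# Hypothetical eval dataloader; the assumption taken from this PR is only that it
# carries a `seq_parallel_world_size` attribute when sequence parallelism is enabled.
eval_dataloader = build_eval_dataloader()  # placeholder helper, not a Composer API
eval_dataloader.seq_parallel_world_size = 4

evaluator = Evaluator(
    label='eval',
    dataloader=eval_dataloader,
    metric_names=['LanguageCrossEntropy'],  # example metric name
    # One sample is sharded across the 4-way sequence parallel group, so each rank
    # sees a fractional microbatch; 'auto' is rejected by the new check below.
    device_eval_microbatch_size=1 / 4,
)
```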
composer/core/data_spec.py (2 changes: 1 addition & 1 deletion)
@@ -221,7 +221,7 @@ def __init__(
 world_size = dist.get_world_size()
 # Check for Distributed Sampler if not using IterableDataset on more than 1 GPU
 if world_size > 1 and not isinstance(dataloader.dataset, torch.utils.data.IterableDataset):
-    is_sampler_distributed = dataloader.sampler and isinstance(dataloader.sampler, DistributedSampler)
+    is_sampler_distributed = isinstance(dataloader.sampler, DistributedSampler)
     is_batch_sampler_distributed = dataloader.batch_sampler is not None and isinstance(
         dataloader.batch_sampler,
         DistributedSampler,
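The data_spec.py change is a small correctness fix: the old `x and isinstance(...)` form short-circuits to the sampler object itself (typically `None`) when no sampler is set, so the flag was not always a proper bool. `isinstance` already handles `None`, so the simpler form suffices. A minimal illustration:

```python
from torch.utils.data.distributed import DistributedSampler

sampler = None  # dataloader without a distributed sampler
old_flag = sampler and isinstance(sampler, DistributedSampler)  # evaluates to None
new_flag = isinstance(sampler, DistributedSampler)              # evaluates to False
```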
composer/core/evaluator.py (4 changes: 3 additions & 1 deletion)
@@ -94,6 +94,8 @@ def __init__(
 self._eval_interval = None
 self.eval_interval = eval_interval
 self.auto_microbatching = _is_auto_microbatching(device_eval_microbatch_size)
+if self.auto_microbatching and hasattr(self.dataloader, 'seq_parallel_world_size'):
+    raise ValueError('`device_eval_microbatch_size="auto"` is not compatible with sequence parallelism.')
 self.device_eval_microbatch_size = _get_initial_device_eval_microbatch_size(
     device_eval_microbatch_size,
     self.auto_microbatching,
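The first evaluator.py hunk rejects automatic microbatching when the eval dataloader advertises sequence parallelism: auto microbatching adjusts a whole-integer microbatch size on CUDA OOM, which cannot express the fractional per-rank microbatch used here. Reusing the placeholder `eval_dataloader` and `Evaluator` import from the sketch above, the incompatibility now surfaces at construction time:

```python
# The new guard raises immediately instead of failing later during evaluation.
try:
    Evaluator(
        label='eval',
        dataloader=eval_dataloader,           # has seq_parallel_world_size set
        device_eval_microbatch_size='auto',   # incompatible with sequence parallelism
    )
except ValueError as err:
    print(err)
```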
@@ -177,7 +179,7 @@ def _get_initial_device_eval_microbatch_size(
             ),
         ) from e
         return batch_size
-elif isinstance(device_eval_microbatch_size, int):
+elif isinstance(device_eval_microbatch_size, Union[int, float]):
     return device_eval_microbatch_size
 else:
     raise ValueError("device_eval_microbatch_size must be an int or ``'auto'``")
composer/trainer/trainer.py (2 changes: 1 addition & 1 deletion)
@@ -409,7 +409,7 @@ def _validate_evaluator(evaluator: Evaluator, device: Device):
 if hasattr(
     evaluator.dataloader,
     'seq_parallel_world_size',
-) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.batch_size * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
+) and evaluator.dataloader.seq_parallel_world_size > 1 and evaluator.dataloader.device_eval_batch_size * evaluator.dataloader.seq_parallel_world_size != 1: # type: ignore
     raise ValueError(
         'Sequence parallelism requires a microbatch size of 1 distributed over the sequence parallel group.',
     )
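The trainer.py change points the validation at `device_eval_batch_size` rather than `batch_size`, and the check requires that the fractional per-rank batch size times the sequence parallel world size equals exactly one full sample. A worked example of the arithmetic, assuming a 4-way sequence parallel group:

```python
seq_parallel_world_size = 4
device_eval_batch_size = 0.25  # one sample sharded across the 4-way group

# Passes the check: exactly one sample is distributed over the group.
assert device_eval_batch_size * seq_parallel_world_size == 1

# A per-rank batch size of 1 would give 4 != 1 and trigger the ValueError above.
```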