Commit 0335ca4
Merge branch 'main' into dgalvez/fix-tensor-devices
titu1994 authored May 30, 2024
2 parents 76ed854 + 2e39606 commit 0335ca4
Showing 1 changed file with 7 additions and 5 deletions.
@@ -843,9 +843,11 @@ def training_step(self, dataloader_iter):
 
         # when using sequence parallelism, the sequence parallel layernorm grads must be all-reduced
         if self.cfg.get('tensor_model_parallel_size', 1) > 1 and self.cfg.get('sequence_parallel', False):
-            self.megatron_timer_start('allreduce_sequence_parallel_gradients', log_level=1)
-            self.allreduce_sequence_parallel_gradients()
-            self.megatron_timer_stop('allreduce_sequence_parallel_gradients')
+            # Mcore DistOpt handles this, so we don't have to
+            if not self.use_mcore_dist_optim:
+                self.megatron_timer_start('allreduce_sequence_parallel_gradients', log_level=1)
+                self.allreduce_sequence_parallel_gradients()
+                self.megatron_timer_stop('allreduce_sequence_parallel_gradients')
 
         self.megatron_timer_start('gradient_allreduce', log_level=1)
         if self.use_fsdp:
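For context, the call guarded above all-reduces the gradients of sequence-parallel layernorm parameters across the tensor-parallel group; when Megatron core's distributed optimizer is in use, that reduction already happens inside the optimizer, so NeMo can skip it. A minimal sketch of the idea, not NeMo's implementation (the tp_group argument and the per-parameter sequence_parallel attribute are assumptions here):

    import torch
    import torch.distributed as dist

    def allreduce_sequence_parallel_grads_sketch(model: torch.nn.Module, tp_group: dist.ProcessGroup):
        # Collect grads of params marked as sequence-parallel (assumed attribute).
        grads = [
            p.grad.data
            for p in model.parameters()
            if getattr(p, 'sequence_parallel', False) and p.grad is not None
        ]
        if grads:
            # Flatten into one buffer so a single all-reduce covers every grad.
            buf = torch._utils._flatten_dense_tensors(grads)
            dist.all_reduce(buf, group=tp_group)
            # Copy the reduced values back into the original grad tensors.
            for g, synced in zip(grads, torch._utils._unflatten_dense_tensors(buf, grads)):
                g.copy_(synced)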
@@ -999,8 +1001,8 @@ def allreduce_fsdp_sharding_omitted_gradients(self):
         """All-reduce gradients of FSDP-sharding-omitted parameters in sharding domain (data-parallel domain)."""
         assert isinstance(self.model, torch.nn.Module)
         grads = []
-        for param in self.model.parameters():
-            if not isinstance(param, torch.distributed.fsdp.FlatParameter) and param.requires_grad:
+        for param in self.model._ignored_params:
+            if param.requires_grad and param.grad is not None:
                 grad = param.grad
                 grads.append(grad.data)
         if len(grads) > 0:
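The new loop walks only the parameters FSDP was told to ignore (FSDP's internal _ignored_params set) and skips any whose .grad is None, instead of filtering FlatParameter instances out of model.parameters(). Gradients of these unsharded parameters still have to be reduced across the data-parallel group by hand, since FSDP does not manage them. A rough sketch under assumptions (dp_group is the data-parallel process group; averaging is used as the reduction, which may differ from NeMo's actual choice):

    import torch
    import torch.distributed as dist

    def allreduce_unsharded_grads_sketch(ignored_params, dp_group: dist.ProcessGroup):
        # Gradients of params FSDP leaves unsharded must be synced manually.
        grads = [p.grad.data for p in ignored_params if p.requires_grad and p.grad is not None]
        for g in grads:
            g /= dist.get_world_size(group=dp_group)  # pre-divide so the sum becomes an average
            dist.all_reduce(g, group=dp_group)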
