Support Sync start and 'Naive' semi-sync
Summary:
Expand the API to support:
  stash_gradients: bool, True -> delay the dense optimizer step so it matches the sparse optimizer (B-2); False -> apply the dense optimizer immediately, which uses less memory

  start_batch: int -> batch number at which to start the semi-sync strategy

Intuition to date (ymmv):

Ads uses stash_gradients = False, but on MRS I saw NE regressions with that setting.
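
For illustration, a minimal usage sketch of the expanded constructor, assuming an already-sharded model and its optimizer; sharded_model, optim, and dataloader below are placeholder names, not part of this diff:

  import torch

  from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync

  pipeline = TrainPipelineSemiSync(
      model=sharded_model,      # placeholder: DMP-wrapped sharded model
      optimizer=optim,          # placeholder: optimizer for the sharded model
      device=torch.device("cuda"),
      execute_all_batches=True,
      start_batch=6,            # begin semi-sync behavior at batch index 6
      stash_gradients=True,     # True -> dense optimizer step is delayed to match the sparse optimizer (B-2)
  )

  dataloader_iter = iter(dataloader)  # placeholder dataloader iterator
  while True:
      try:
          # progress() is assumed to raise StopIteration once the dataloader is
          # exhausted, as in TrainPipelineSparseDist.
          pred = pipeline.progress(dataloader_iter)
      except StopIteration:
          break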

Reviewed By: henrylhtsang

Differential Revision: D57640178

Privacy Context Container: 1203980333745195

fbshipit-source-id: de00a81233e2f24af08b3e130a94a769504447ab
dstaay-fb authored and facebook-github-bot committed May 28, 2024
1 parent 3ca8e8b commit b4f7649
Showing 2 changed files with 65 additions and 17 deletions.
44 changes: 34 additions & 10 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -472,11 +472,12 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
@settings(max_examples=4, deadline=None)
# pyre-ignore[56]
@given(
start_batch=st.sampled_from([0, 6]),
stash_gradients=st.booleans(),
sharding_type=st.sampled_from(
[
ShardingType.TABLE_WISE.value,
ShardingType.ROW_WISE.value,
ShardingType.COLUMN_WISE.value,
]
),
kernel_type=st.sampled_from(
@@ -487,6 +488,8 @@ class EmbeddingTrainPipelineTest(TrainPipelineSparseDistTestBase):
)
def test_equal_to_non_pipelined(
self,
start_batch: int,
stash_gradients: bool,
sharding_type: str,
kernel_type: str,
) -> None:
@@ -527,29 +530,38 @@ def test_equal_to_non_pipelined(
optimizer=optim_pipelined,
device=self.device,
execute_all_batches=True,
start_batch=start_batch,
stash_gradients=stash_gradients,
)

prior_sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(
data[0].to(self.device)
)
prior_batch = data[0].to(self.device)
prior_stashed_grads = None
batch_index = 0
sparse_out = None
for batch in data[1:]:
batch_index += 1
# Forward + backward w/o pipelining
batch = batch.to(self.device)

loss, pred = sharded_model._dmp_wrapped_module.dense_forward(
prior_batch, prior_sparse_out
)
sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)
# stash grads:
if batch_index - 1 >= start_batch:
sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)

loss.backward()
stashed_grads = []
for param in optim.param_groups[0]["params"]:
stashed_grads.append(
param.grad.clone() if param.grad is not None else None
)
param.grad = None

stashed_grads = None
if batch_index - 1 >= start_batch and stash_gradients:
stashed_grads = []
for param in optim.param_groups[0]["params"]:
stashed_grads.append(
param.grad.clone() if param.grad is not None else None
)
param.grad = None

if prior_stashed_grads is not None:
for param, stashed_grad in zip(
@@ -559,13 +571,25 @@ def test_equal_to_non_pipelined(
optim.step()
optim.zero_grad()

if batch_index - 1 < start_batch:
sparse_out = sharded_model._dmp_wrapped_module.sparse_forward(batch)

prior_stashed_grads = stashed_grads
prior_batch = batch
prior_sparse_out = sparse_out
# Forward + backward w/ pipelining
pred_pipeline = pipeline.progress(dataloader)

self.assertTrue(torch.equal(pred, pred_pipeline))
if batch_index >= start_batch:
self.assertTrue(
pipeline.is_semi_sync(), msg="pipeline is not semi_sync"
)
else:
self.assertFalse(pipeline.is_semi_sync(), msg="pipeline is semi_sync")
self.assertTrue(
torch.equal(pred, pred_pipeline),
msg=f"batch {batch_index} doesn't match",
)

# one more batch
pred_pipeline = pipeline.progress(dataloader)
38 changes: 31 additions & 7 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -461,6 +461,9 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
execute_all_batches (bool): executes remaining batches in pipeline after
exhausting dataloader iterator.
apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
start_batch (int): batch to begin semi-sync training.
stash_gradients (bool): if True, stores gradients for each parameter to ensure true "Semi-Sync"
training. If False, updates the dense optimizer as soon as gradients are available (naive "Semi-Sync").
"""

def __init__(
@@ -470,6 +473,8 @@ def __init__(
device: torch.device,
execute_all_batches: bool = True,
apply_jit: bool = False,
start_batch: int = 0,
stash_gradients: bool = True,
) -> None:
super().__init__(
model=model,
@@ -479,6 +484,8 @@ def __init__(
apply_jit=apply_jit,
context_type=EmbeddingTrainPipelineContext,
)
self._start_batch = start_batch
self._stash_gradients = stash_gradients

# use two data streams to support two concurrent batches
self._embedding_odd_stream: Optional[torch.cuda.streams.Stream] = (
@@ -533,6 +540,22 @@ def fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
if not self.enqueue_batch(dataloader_iter):
return

def is_semi_sync(self) -> bool:
if len(self.batches) >= 1:
# pyre-ignore [58]
return self.contexts[0].index >= self._start_batch
return False

def _mlp_optimizer_step(self) -> None:
# special case: not all optimizers support optim.step() on null gradients
if (
len(self.batches) >= 1
and self.contexts[0].index == self._start_batch
and self._stash_gradients
):
return
self._optimizer.step()

def progress(self, dataloader_iter: Iterator[In]) -> Out:
self.fill_pipeline(dataloader_iter)
if not self.batches:
@@ -549,7 +572,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
# batch i+3
self.enqueue_batch(dataloader_iter)

if len(self.batches) >= 2:
if len(self.batches) >= 2 and self.is_semi_sync():
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])

@@ -562,24 +585,25 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
with record_function(
f"## optimizer {cast(int, self.contexts[0].index) - 1} ##"
):
self._grad_swap()
self._optimizer.step()
if self.is_semi_sync() and self._stash_gradients:
self._grad_swap()
self._mlp_optimizer_step()

with record_function(
f"## zero_grad {cast(int, self.contexts[0].index) - 1} ##"
):
self._optimizer.zero_grad()

if len(self.batches) >= 2 and not self.is_semi_sync():
# pyre-ignore [6]
self.start_embedding_lookup(self.batches[1], self.contexts[1])

self.dequeue_batch()
return output

def _mlp_forward(
self, batch: In, context: TrainPipelineContext
) -> Tuple[torch.Tensor, Out]:
if self._model.training:
with record_function(f"## zero_grad {context.index} ##"):
self._optimizer.zero_grad()

with record_function(f"## forward {context.index} ##"):
with torch.cuda.stream(self._overarch_stream):
_wait_for_event(batch, context.event)