Change defaults to make integrations easier. (#2092)
Summary:
Pull Request resolved: #2092

Change Semi-Sync default values to better reflect typical usage patterns.

Reviewed By: joshuadeng, spxuaaaaaa

Differential Revision: D58318941

fbshipit-source-id: df364eea35818bef911408fd39bedb5dc048e12c
dstaay-fb authored and facebook-github-bot committed Jun 10, 2024
1 parent b3f569d commit 5d320c6
Showing 2 changed files with 18 additions and 9 deletions.
19 changes: 14 additions & 5 deletions torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -275,11 +275,20 @@ def runner(
             TrainPipelineSemiSync,
             PrefetchTrainPipelineSparseDist,
         ]:
-            pipeline = pipeline_clazz(
-                model=sharded_model,
-                optimizer=optimizer,
-                device=ctx.device,
-            )
+            if pipeline_clazz == TrainPipelineSemiSync:
+                # pyre-ignore [28]
+                pipeline = pipeline_clazz(
+                    model=sharded_model,
+                    optimizer=optimizer,
+                    device=ctx.device,
+                    start_batch=0,
+                )
+            else:
+                pipeline = pipeline_clazz(
+                    model=sharded_model,
+                    optimizer=optimizer,
+                    device=ctx.device,
+                )
             pipeline.progress(iter(bench_inputs))
 
             def _func_to_benchmark(
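For readers skimming the benchmark change: start_batch is the batch index at which the pipeline switches from plain synchronous training to the semi-sync schedule (see the docstring update below). Pinning it to 0 here keeps the benchmark exercising the semi-sync path from the very first batch even after the default changes to 900; the "# pyre-ignore [28]" suppresses Pyre's unexpected-keyword-argument error, since pipeline_clazz is typed over all four pipeline classes rather than TrainPipelineSemiSync alone. Below is a minimal, self-contained sketch of what start_batch controls; semi_sync_schedule is illustrative only and not a TorchRec API.

from typing import List


def semi_sync_schedule(num_batches: int, start_batch: int) -> List[str]:
    # Illustrative only: the per-batch training mode implied by `start_batch`.
    modes: List[str] = []
    for batch_idx in range(num_batches):
        if batch_idx < start_batch:
            modes.append("sync")  # warm-up: ordinary synchronous step
        else:
            modes.append("semi-sync")  # batch B's embedding All-to-All overlaps batch B-1's forward
    return modes


print(semi_sync_schedule(num_batches=4, start_batch=0))  # benchmark setting: semi-sync from batch 0
print(semi_sync_schedule(num_batches=1000, start_batch=900).count("sync"))  # new default: 900 warm-up batches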
8 changes: 4 additions & 4 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -513,7 +513,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
-    where the model is still synchorous but each batch prediction is calculated
+    where the model is still synchronous but each batch prediction is calculated
     on parameters which were last updated B-2, instead of the batch prior (ie. B-1). This
     allows the Embedding All-to-All from B to be fully overlapped with forward pass of B-1; dramatically
     improving peak training performance.
@@ -527,7 +527,7 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after
             exhausting dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
-        start_batch (int): batch to begin semi-sync training.
+        start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
     """
@@ -539,8 +539,8 @@ def __init__(
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
-        start_batch: int = 0,
-        stash_gradients: bool = True,
+        start_batch: int = 900,
+        stash_gradients: bool = False,
     ) -> None:
         super().__init__(
             model=model,
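Taken together, the practical effect of the new defaults can be sketched as follows. This is a hedged usage sketch, not code from the repository: build_pipelines, its arguments, and the already-sharded model, optimizer, and device are assumptions standing in for an integration's own setup.

import torch

from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync


def build_pipelines(
    sharded_model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
) -> None:
    # With this commit, omitting the keyword arguments picks up the new defaults:
    # start_batch=900 (synchronous warm-up period) and stash_gradients=False (naive semi-sync).
    default_pipeline = TrainPipelineSemiSync(
        model=sharded_model,
        optimizer=optimizer,
        device=device,
    )

    # Integrations that relied on the previous defaults now need to request them explicitly.
    legacy_pipeline = TrainPipelineSemiSync(
        model=sharded_model,
        optimizer=optimizer,
        device=device,
        start_batch=0,         # old default: semi-sync from the first batch
        stash_gradients=True,  # old default: stash gradients for "true" semi-sync
    )

Nothing about the pipeline's mechanics changes in this commit; only callers who implicitly depended on the old default values are affected.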
