From 5d320c6a56514eb9a098c5d4d38142ccd2df113f Mon Sep 17 00:00:00 2001
From: Dennis van der Staay
Date: Mon, 10 Jun 2024 15:50:51 -0700
Subject: [PATCH] Change defaults to make integrations easier. (#2092)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/2092

Change Semi-Sync default values to better reflect typical usage patterns.

Reviewed By: joshuadeng, spxuaaaaaa

Differential Revision: D58318941

fbshipit-source-id: df364eea35818bef911408fd39bedb5dc048e12c
---
 .../tests/pipeline_benchmarks.py              | 19 ++++++++++++++-----
 .../train_pipeline/train_pipelines.py         |  8 ++++----
 2 files changed, 18 insertions(+), 9 deletions(-)

diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index 82e5bc209..f92cf82c9 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -275,11 +275,20 @@ def runner(
             TrainPipelineSemiSync,
             PrefetchTrainPipelineSparseDist,
         ]:
-            pipeline = pipeline_clazz(
-                model=sharded_model,
-                optimizer=optimizer,
-                device=ctx.device,
-            )
+            if pipeline_clazz == TrainPipelineSemiSync:
+                # pyre-ignore [28]
+                pipeline = pipeline_clazz(
+                    model=sharded_model,
+                    optimizer=optimizer,
+                    device=ctx.device,
+                    start_batch=0,
+                )
+            else:
+                pipeline = pipeline_clazz(
+                    model=sharded_model,
+                    optimizer=optimizer,
+                    device=ctx.device,
+                )
             pipeline.progress(iter(bench_inputs))
 
         def _func_to_benchmark(
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index c4b49465d..90dcb0aaf 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -513,7 +513,7 @@ def _fill_pipeline(self, dataloader_iter: Iterator[In]) -> None:
 class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
     """
     Novel method for RecSys model training by leveraging "Semi-Synchronous" training,
-    where the model is still synchorous but each batch prediction is calculated
+    where the model is still synchronous but each batch prediction is calculated
     on parameters which were last updated B-2, instead of the batch prior (ie. B-1).  This
     allows the Embedding All-to-All from B to be fully overlapped with forward pass of B-1;
     dramatically improving peak training performance.
@@ -527,7 +527,7 @@ class TrainPipelineSemiSync(TrainPipelineSparseDist[In, Out]):
         execute_all_batches (bool): executes remaining batches in pipeline after exhausting
             dataloader iterator.
         apply_jit (bool): apply torch.jit.script to non-pipelined (unsharded) modules.
-        start_batch (int): batch to begin semi-sync training.
+        start_batch (int): batch to begin semi-sync training. Typically small period of synchronous training reduces early stage NEX.
         stash_gradients (bool): if True, will store gradients for each parameter to insure true "Semi-Sync"
             training. If False, will update dense optimizer as soon as gradients available (naive "Semi-Sync)
     """
@@ -539,8 +539,8 @@ def __init__(
         self,
         model: torch.nn.Module,
         optimizer: torch.optim.Optimizer,
         device: torch.device,
         execute_all_batches: bool = True,
         apply_jit: bool = False,
-        start_batch: int = 0,
-        stash_gradients: bool = True,
+        start_batch: int = 900,
+        stash_gradients: bool = False,
     ) -> None:
         super().__init__(
             model=model,
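
For context on what the new defaults mean for callers, here is a minimal usage sketch; it is not part of the patch. It assumes a sharded model, dense optimizer, device, and dataloader have already been built (for example the way pipeline_benchmarks.py builds them), and uses only the TrainPipelineSemiSync constructor arguments shown in the diff.

from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSemiSync

# `sharded_model`, `optimizer`, `device`, `dataloader`, and `num_batches` are
# hypothetical placeholders assumed to exist; they are not defined by this patch.

# After this patch the defaults are start_batch=900 and stash_gradients=False:
# the first 900 batches train without the semi-sync overlap, then the pipeline
# switches to semi-sync and updates the dense optimizer as soon as gradients
# are available (the "naive" semi-sync mode described in the docstring).
pipeline = TrainPipelineSemiSync(
    model=sharded_model,
    optimizer=optimizer,
    device=device,
)

# The previous behavior (semi-sync from batch 0 with stashed gradients) can
# still be requested explicitly:
# pipeline = TrainPipelineSemiSync(
#     model=sharded_model,
#     optimizer=optimizer,
#     device=device,
#     start_batch=0,
#     stash_gradients=True,
# )

# Each progress() call consumes one batch from the iterator and runs one
# forward/backward/optimizer step, as in the benchmark above.
dataloader_iter = iter(dataloader)
for _ in range(num_batches):
    out = pipeline.progress(dataloader_iter)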