Skip to content

Commit

Permalink
Add option to change batch size if needed (#11268)
Browse files Browse the repository at this point in the history
* Add option to change batch size if needed

* Apply isort and black reformatting

Signed-off-by: BoxiangW <[email protected]>

---------

Signed-off-by: BoxiangW <[email protected]>
Co-authored-by: BoxiangW <[email protected]>
  • Loading branch information
BoxiangW and BoxiangW authored Jan 8, 2025
1 parent ee12d88 commit c0d7f2d
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
2 changes: 2 additions & 0 deletions nemo/collections/nlp/modules/common/megatron/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down
2 changes: 2 additions & 0 deletions nemo/lightning/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def setup_microbatch_calculator(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -121,6 +122,7 @@ def setup_microbatch_calculator(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down
2 changes: 2 additions & 0 deletions nemo/lightning/megatron_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand All @@ -201,6 +202,7 @@ def initialize_model_parallel_for_nemo(
micro_batch_size=micro_batch_size,
data_parallel_size=app_state.data_parallel_size,
rampup_batch_size=rampup_batch_size,
decrease_batch_size_if_needed=False,
)
else:
if isinstance(_GLOBAL_NUM_MICROBATCHES_CALCULATOR, ConstantNumMicroBatchesCalculator):
Expand Down

0 comments on commit c0d7f2d

Please sign in to comment.