Add reconfigure microbatch calculator before inference and update GBS, MBS for inference (#7763)

* Add reconfigure microbatch calculator before inference

Signed-off-by: Abhishree <[email protected]>

* Add missing import for reconfigure microbatch calculator

Signed-off-by: Abhishree <[email protected]>

* Add comment for reconfiguring microbatch calculator during inference

Signed-off-by: Abhishree Thittenamane <[email protected]>

* Update comment

Signed-off-by: Abhishree Thittenamane <[email protected]>

* Fix typo

Signed-off-by: Abhishree Thittenamane <[email protected]>

---------

Signed-off-by: Abhishree <[email protected]>
Signed-off-by: Abhishree Thittenamane <[email protected]>
athitten authored Oct 20, 2023
1 parent 21d89d3 commit 37c7c50
Showing 1 changed file with 12 additions and 0 deletions.
12 changes: 12 additions & 0 deletions tutorials/nlp/lora.ipynb
@@ -1372,6 +1372,8 @@
"source": [
"with open_dict(peft_model_cfg):\n",
" # update the model config of the trained model with params we want to set at inference time.\n",
" peft_model_cfg.global_batch_size = config_eval.model.global_batch_size\n",
" peft_model_cfg.micro_batch_size = config_eval.model.micro_batch_size\n",
" peft_model_cfg.precision = config_eval.trainer.precision\n",
" peft_model_cfg.data.test_ds = config_eval.model.data.test_ds\n",
" peft_model_cfg.activations_checkpoint_granularity = None\n",
@@ -1475,6 +1477,16 @@
}
],
"source": [
"from apex.transformer.pipeline_parallel.utils import _reconfigure_microbatch_calculator\n",
"# Reset the microbatch calculator with eval batch sizes, required while doing both training\n",
"# and inference\n",
"_reconfigure_microbatch_calculator(\n",
" rank=0,\n",
" rampup_batch_size=None,\n",
" global_batch_size=config_eval.model.global_batch_size,\n",
" micro_batch_size=config_eval.model.micro_batch_size,\n",
" data_parallel_size=1,\n",
")\n",
"save_restore_connector = PEFTSaveRestoreConnector(\n",
" peft_model_nemo_path=config_eval.model.peft.restore_from_path, peft_model_ckpt_path=None,\n",
")\n",
