From a7eb2df70b4bed10a44fee421c1d66efcab129e4 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Thu, 2 Nov 2023 18:13:27 -0400
Subject: [PATCH] fix issue where BF16+O1 nemo ckpt converted to FP32 mcore ckpt

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 27 ++++++++++++-------
 1 file changed, 18 insertions(+), 9 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 5917ddf12a79..011bc0d5ae23 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -168,18 +168,25 @@ def load_model(model, state_dict):
 
 def restore_model(nemo_file, cpu_only=False):
     dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
-    if cpu_only:
-        map_location = torch.device('cpu')
-        model_config = MegatronGPTModel.restore_from(
-            nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location
-        )
-        model_config.use_cpu_initialization = True
-    else:
-        model_config, map_location = None, None
-    return MegatronGPTModel.restore_from(
+    map_location = torch.device('cpu') if cpu_only else None
+    model_config = MegatronGPTModel.restore_from(
+        nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location
+    )
+    model_config.use_cpu_initialization = cpu_only
+
+    # To copy weights in the original precision, we have to turn on O2.
+    orig_megatron_amp_O2_value = model_config.megatron_amp_O2
+
+    if model_config.precision in ['bf16', 'bf16-mixed']:
+        model_config.megatron_amp_O2 = True
+
+    model = MegatronGPTModel.restore_from(
         nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
     )
+    # restore O2 to the original value so mcore model has the same config
+    model.cfg.megatron_amp_O2 = orig_megatron_amp_O2_value
+    return model
 
 
 def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_only=False):
     if skip_if_output_exists and os.path.exists(output_nemo_file):
@@ -214,6 +221,8 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
     mcore_model.cfg.use_cpu_initialization = False
     mcore_model.save_to(output_nemo_file)
     logging.info(f"✅ Done. Model saved to {output_nemo_file}")
+    del mcore_model
+    del nemo_model
 
 
 def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
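
As a quick sanity check separate from the patch itself, one can inspect the dtypes of the tensors in the converted checkpoint. The sketch below is hypothetical: it assumes the output .nemo archive has been unpacked (for example with tar -xf) and contains a plain PyTorch state dict file named model_weights.ckpt, which is a common NeMo layout; the file name and packaging may differ. A BF16+O1 source model converted with the patched script should report torch.bfloat16 entries rather than torch.float32.

# Hypothetical post-conversion check, separate from the patch above.
# Assumes the converted .nemo archive has been unpacked and exposes a plain
# PyTorch state dict file named model_weights.ckpt (a common NeMo layout).
from collections import Counter

import torch


def summarize_dtypes(ckpt_path):
    """Count tensor dtypes in a PyTorch state dict file."""
    state_dict = torch.load(ckpt_path, map_location='cpu')
    return Counter(str(t.dtype) for t in state_dict.values() if torch.is_tensor(t))


if __name__ == '__main__':
    # After the fix, a BF16 source model should show torch.bfloat16 here,
    # not torch.float32.
    print(summarize_dtypes('model_weights.ckpt'))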