Skip to content

Commit

Permalink
Fix issue where a BF16+O1 NeMo checkpoint was converted to an FP32 mcore checkpoint
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Cui <[email protected]>
  • Loading branch information
cuichenx committed Nov 2, 2023
1 parent c6b4e54 commit a7eb2df
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,18 +168,25 @@ def load_model(model, state_dict):

def restore_model(nemo_file, cpu_only=False):
    """Restore a MegatronGPTModel from a .nemo file for conversion.

    The checkpoint's config is read first so it can be adjusted before the
    weights are loaded: for bf16 checkpoints ``megatron_amp_O2`` is switched
    on temporarily so the weights are copied in their original precision,
    then the flag is put back to its stored value on the returned model.

    Args:
        nemo_file: Path to the .nemo checkpoint to restore.
        cpu_only: If True, map tensors to the CPU; the value is also written
            to ``use_cpu_initialization`` in the model config.

    Returns:
        The restored MegatronGPTModel, whose ``cfg.megatron_amp_O2`` equals
        the value stored in the checkpoint's config.
    """
    dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
    map_location = torch.device('cpu') if cpu_only else None

    # First pass: fetch only the config so it can be overridden below.
    model_config = MegatronGPTModel.restore_from(
        nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location
    )
    model_config.use_cpu_initialization = cpu_only

    # To copy weights in the original precision, we have to turn on O2.
    orig_megatron_amp_O2_value = model_config.megatron_amp_O2
    if model_config.precision in ['bf16', 'bf16-mixed']:
        model_config.megatron_amp_O2 = True

    # Second pass: build the model with the adjusted config.
    model = MegatronGPTModel.restore_from(
        nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
    )

    # Restore O2 to the original value so the mcore model has the same config.
    model.cfg.megatron_amp_O2 = orig_megatron_amp_O2_value
    return model

def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_only=False):
if skip_if_output_exists and os.path.exists(output_nemo_file):
Expand Down Expand Up @@ -214,6 +221,8 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
mcore_model.cfg.use_cpu_initialization = False
mcore_model.save_to(output_nemo_file)
logging.info(f"✅ Done. Model saved to {output_nemo_file}")
del mcore_model
del nemo_model


def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
Expand Down

0 comments on commit a7eb2df

Please sign in to comment.