From 86e46c067b782d2baef2dd9d058290f1aae04a75 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 16 Oct 2023 12:38:08 -0700
Subject: [PATCH 01/13] add conversion script

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 209 ++++++++++++++++++
 1 file changed, 209 insertions(+)
 create mode 100644 scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
new file mode 100644
index 000000000000..51bc2b6e4baf
--- /dev/null
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -0,0 +1,209 @@
+# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from argparse import ArgumentParser
+from collections import OrderedDict
+
+import torch
+from omegaconf import OmegaConf
+from pytorch_lightning.trainer.trainer import Trainer
+from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
+from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
+from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
+from nemo.utils import logging, AppState
+
+r"""
+Script to convert a legacy (non-mcore path) nemo checkpoint into mcore-path checkpoint for GPT models.
+
+*Important* Before running this script, please first
+1) convert your legacy checkpoint to TP1 PP1 format:
+    python examples/nlp/language_modeling/megatron_change_num_partitions.py \
+     \
+    --target_tensor_model_parallel_size=1 \
+    --target_pipeline_model_parallel_size=1
+2) extract your checkpoint to a folder with
+    tar -xvf your_ckpt.nemo
+
+Then, run this conversion script:
+python convert_nemo_gpt_to_mcore.py \
+ --in-file \
+ --out-file
+"""
+
+
+def get_args():
+    parser = ArgumentParser()
+    parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to extracted, TP1 PP1 NeMo GPT checkpoint.",)
+    parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output mcore weights file (ends in .nemo).")
+    args = parser.parse_args()
+    return args
+
+def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):
+    model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
+    model_cfg.tokenizer.vocab_file = None
+    model_cfg.tokenizer.merge_file = None
+    model_cfg.mcore_gpt = True
+
+    logging.info("*** initializing mcore model with the following config")
+    logging.info(OmegaConf.to_yaml(model_cfg))
+    trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
+
+    app_state = AppState()
+    if os.path.isdir(nemo_restore_from_path):
+        app_state.nemo_file_folder = nemo_restore_from_path
+    else:
+        logging.warning("`nemo_file_folder` is NOT set because checkpoint is not pre-extracted. Subsequent operations may fail.")
+    mcore_model = MegatronGPTModel(model_cfg, trainer=trainer)
+    return mcore_model
+
+def print_mcore_parameter_names(restore_from_path):
+    mcore_model = get_mcore_model_from_nemo_ckpt(restore_from_path)
+
+    print("*********")
+    print('\n'.join(sorted([k+'###'+str(v.shape) for k, v in mcore_model.named_parameters()])))
+    print("*********")
+
+
+def build_key_mapping(nemo_cfg, use_O2_prefix=None):
+    num_layers = nemo_cfg.num_layers
+    has_bias = nemo_cfg.get("bias", True)
+    if use_O2_prefix is None:
+        use_O2_prefix = nemo_cfg.get('megatron_amp_O2', False)
+    model_str = 'model.module' if use_O2_prefix else 'model'
+
+    # For GPT there is a 1:1 mapping of keys
+    mcore_to_nemo_mapping = {
+        f"{model_str}.embedding.word_embeddings.weight": "model.language_model.embedding.word_embeddings.weight",
+        f"{model_str}.decoder.final_layernorm.bias": "model.language_model.encoder.final_layernorm.bias",
+        f"{model_str}.decoder.final_layernorm.weight": "model.language_model.encoder.final_layernorm.weight",
+    }
+    if not nemo_cfg.get("share_embeddings_and_output_weights", True):
+        mcore_to_nemo_mapping[f"{model_str}.output_layer.weight"] = "model.language_model.output_layer.weight"
+
+    if nemo_cfg.get("position_embedding_type", 'learned_absolute') == 'rope':
+        mcore_to_nemo_mapping[f"{model_str}.rotary_pos_emb.inv_freq"] = "model.language_model.rotary_pos_emb.inv_freq"
+    else:
+        mcore_to_nemo_mapping[f"{model_str}.embedding.position_embeddings.weight"] = "model.language_model.embedding.position_embeddings.weight"
+
+    nemo_prefix = "model.language_model.encoder.layers"
+    mcore_prefix = f"{model_str}.decoder.layers"
+    for i in range(num_layers):
+        for wb in ('weight', 'bias') if has_bias else ('weight',):
+            mcore_to_nemo_mapping.update({
+                f"{mcore_prefix}.{i}.mlp.linear_fc2.{wb}": f"{nemo_prefix}.{i}.mlp.dense_4h_to_h.{wb}",
+                f"{mcore_prefix}.{i}.mlp.linear_fc1.{wb}": f"{nemo_prefix}.{i}.mlp.dense_h_to_4h.{wb}",
+                f"{mcore_prefix}.{i}.self_attention.linear_proj.{wb}": f"{nemo_prefix}.{i}.self_attention.dense.{wb}",
+                f"{mcore_prefix}.{i}.self_attention.linear_qkv.{wb}": f"{nemo_prefix}.{i}.self_attention.query_key_value.{wb}",
+            })
+        # layernorm layers always have bias!
+        for wb in ('weight', 'bias'):
+            mcore_to_nemo_mapping.update({
+                f"{mcore_prefix}.{i}.self_attention.linear_qkv.layer_norm_{wb}": f"{nemo_prefix}.{i}.input_layernorm.{wb}",
+                f"{mcore_prefix}.{i}.mlp.linear_fc1.layer_norm_{wb}": f"{nemo_prefix}.{i}.post_attention_layernorm.{wb}",
+            })
+
+    return mcore_to_nemo_mapping
+
+
+def load_model(model, state_dict):
+    # try:
+    for name, module in model.named_parameters():
+        if name in state_dict:
+            module.data = state_dict.pop(name)
+        else:
+            raise RuntimeError(f"Unexpected key: {name} not in state_dict but in model.")
+
+
+    for name, buffer in model.named_buffers():
+        if name in state_dict:
+            buffer.data = state_dict.pop(name)
+
+    if len(state_dict.keys()) != 0:
+        raise RuntimeError(f"Additional keys: {state_dict.keys()} in state_dict but not in model.")
+
+    return model
+
+
+def convert(input_ckpt_file, output_ckpt_file, skip_if_output_exists=True):
+    if skip_if_output_exists and os.path.exists(output_ckpt_file):
+        logging.info(f"Output file already exists ({output_ckpt_file}), skipping conversion...")
+        return
+    dummy_trainer = Trainer(devices=1, accelerator='cpu')
+
+    nemo_model = MegatronGPTModel.restore_from(input_ckpt_file, trainer=dummy_trainer)
+    nemo_tokenizer_model = nemo_model.cfg.tokenizer.model
+    nemo_state_dict = nemo_model.state_dict()
+    mcore_state_dict = OrderedDict()
+    for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
+        mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
+
+    mcore_model = get_mcore_model_from_nemo_ckpt(input_ckpt_file)
+    mcore_model = load_model(mcore_model, mcore_state_dict)
+
+    if nemo_model.cfg.tokenizer.model is not None:
+        logging.info("registering artifact: tokenizer.model = " + nemo_tokenizer_model)
+        mcore_model.register_artifact("tokenizer.model", nemo_tokenizer_model)
+
+    mcore_model.save_to(output_ckpt_file)
+    logging.info(f"Done. Model saved to {output_ckpt_file}")
+
+
+def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
+    cfg = OmegaConf.load(
+        os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml')
+    )
+
+    cfg.trainer.precision = 'bf16'  # change me
+    dtype = torch.bfloat16
+    trainer = MegatronTrainerBuilder(cfg).create_trainer()
+    nemo_model = MegatronGPTModel.restore_from(nemo_ckpt_file, trainer=trainer).eval().to(dtype)
+    mcore_model = MegatronGPTModel.restore_from(mcore_ckpt_file, trainer=trainer).eval().to(dtype)
+
+    logging.debug("*** Mcore model restored config")
+    logging.debug(OmegaConf.to_yaml(mcore_model.cfg))
+
+    nemo_summary = nemo_model.summarize()
+    mcore_summary = mcore_model.summarize()
+
+    logging.info("Sanity checks:")
+
+    # check num weights match
+    assert nemo_summary.total_parameters == mcore_summary.total_parameters, "❌ total parameters do not match"
+    assert nemo_summary.model_size == mcore_summary.model_size, "❌ model sizes do not match"
+    logging.info("✅ Number of weights match")
+
+    # check weights match
+    mcore_state_dict = mcore_model.state_dict()
+    nemo_state_dict = nemo_model.state_dict()
+    for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg, use_O2_prefix=False).items():
+        try:
+            assert torch.allclose(mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]), f"❌ parameter {mcore_param} does not match"
+        except KeyError:
+            buffers = [k for k, v in mcore_model.named_buffers()]
+            assert mcore_param in buffers or mcore_param.replace('model.', 'model.module.', 1) in buffers, \
+                f"❌ parameter {mcore_param} is not found in the state dict or named_buffers()"
+    logging.info("✅ Weights match")
+
+
+if __name__ == '__main__':
+    args = get_args()
+
+    input_ckpt = args.in_file
+    output_ckpt = args.out_file
+    os.makedirs(os.path.dirname(output_ckpt), exist_ok=True)
+    convert(input_ckpt, output_ckpt, skip_if_output_exists=True)
+    torch.cuda.empty_cache()
+    run_sanity_checks(input_ckpt, output_ckpt)

From 49547805f4655659bc839090166b85f265b545c8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 16 Oct 2023 19:40:53 +0000
Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../convert_nemo_gpt_to_mcore.py              | 63 ++++++++++++-------
 1 file changed, 41 insertions(+), 22 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 51bc2b6e4baf..8c9387c20408 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -19,10 +19,11 @@
 import torch
 from omegaconf import OmegaConf
 from pytorch_lightning.trainer.trainer import Trainer
+
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
 from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
-from nemo.utils import logging, AppState
+from nemo.utils import AppState, logging
 
 r"""
 Script to convert a legacy (non-mcore path) nemo checkpoint into mcore-path checkpoint for GPT models.
@@ -45,11 +46,16 @@
 
 def get_args():
     parser = ArgumentParser()
-    parser.add_argument("--in-file", type=str, default=None, required=True, help="Path to extracted, TP1 PP1 NeMo GPT checkpoint.",)
-    parser.add_argument("--out-file", type=str, default=None, required=True, help="Path to output mcore weights file (ends in .nemo).")
+    parser.add_argument(
+        "--in-file", type=str, default=None, required=True, help="Path to extracted, TP1 PP1 NeMo GPT checkpoint.",
+    )
+    parser.add_argument(
+        "--out-file", type=str, default=None, required=True, help="Path to output mcore weights file (ends in .nemo)."
+    )
     args = parser.parse_args()
     return args
 
+
 def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):
     model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
     model_cfg.tokenizer.vocab_file = None
@@ -64,15 +70,18 @@ def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):
     if os.path.isdir(nemo_restore_from_path):
         app_state.nemo_file_folder = nemo_restore_from_path
     else:
-        logging.warning("`nemo_file_folder` is NOT set because checkpoint is not pre-extracted. Subsequent operations may fail.")
+        logging.warning(
+            "`nemo_file_folder` is NOT set because checkpoint is not pre-extracted. Subsequent operations may fail."
+        )
     mcore_model = MegatronGPTModel(model_cfg, trainer=trainer)
     return mcore_model
 
+
 def print_mcore_parameter_names(restore_from_path):
     mcore_model = get_mcore_model_from_nemo_ckpt(restore_from_path)
 
     print("*********")
-    print('\n'.join(sorted([k+'###'+str(v.shape) for k, v in mcore_model.named_parameters()])))
+    print('\n'.join(sorted([k + '###' + str(v.shape) for k, v in mcore_model.named_parameters()])))
     print("*********")
 
 
@@ -95,24 +104,30 @@ def build_key_mapping(nemo_cfg, use_O2_prefix=None):
     if nemo_cfg.get("position_embedding_type", 'learned_absolute') == 'rope':
         mcore_to_nemo_mapping[f"{model_str}.rotary_pos_emb.inv_freq"] = "model.language_model.rotary_pos_emb.inv_freq"
     else:
-        mcore_to_nemo_mapping[f"{model_str}.embedding.position_embeddings.weight"] = "model.language_model.embedding.position_embeddings.weight"
+        mcore_to_nemo_mapping[
+            f"{model_str}.embedding.position_embeddings.weight"
+        ] = "model.language_model.embedding.position_embeddings.weight"
 
     nemo_prefix = "model.language_model.encoder.layers"
     mcore_prefix = f"{model_str}.decoder.layers"
     for i in range(num_layers):
         for wb in ('weight', 'bias') if has_bias else ('weight',):
-            mcore_to_nemo_mapping.update({
-                f"{mcore_prefix}.{i}.mlp.linear_fc2.{wb}": f"{nemo_prefix}.{i}.mlp.dense_4h_to_h.{wb}",
-                f"{mcore_prefix}.{i}.mlp.linear_fc1.{wb}": f"{nemo_prefix}.{i}.mlp.dense_h_to_4h.{wb}",
-                f"{mcore_prefix}.{i}.self_attention.linear_proj.{wb}": f"{nemo_prefix}.{i}.self_attention.dense.{wb}",
-                f"{mcore_prefix}.{i}.self_attention.linear_qkv.{wb}": f"{nemo_prefix}.{i}.self_attention.query_key_value.{wb}",
-            })
+            mcore_to_nemo_mapping.update(
+                {
+                    f"{mcore_prefix}.{i}.mlp.linear_fc2.{wb}": f"{nemo_prefix}.{i}.mlp.dense_4h_to_h.{wb}",
+                    f"{mcore_prefix}.{i}.mlp.linear_fc1.{wb}": f"{nemo_prefix}.{i}.mlp.dense_h_to_4h.{wb}",
+                    f"{mcore_prefix}.{i}.self_attention.linear_proj.{wb}": f"{nemo_prefix}.{i}.self_attention.dense.{wb}",
+                    f"{mcore_prefix}.{i}.self_attention.linear_qkv.{wb}": f"{nemo_prefix}.{i}.self_attention.query_key_value.{wb}",
+                }
+            )
         # layernorm layers always have bias!
         for wb in ('weight', 'bias'):
-            mcore_to_nemo_mapping.update({
-                f"{mcore_prefix}.{i}.self_attention.linear_qkv.layer_norm_{wb}": f"{nemo_prefix}.{i}.input_layernorm.{wb}",
-                f"{mcore_prefix}.{i}.mlp.linear_fc1.layer_norm_{wb}": f"{nemo_prefix}.{i}.post_attention_layernorm.{wb}",
-            })
+            mcore_to_nemo_mapping.update(
+                {
+                    f"{mcore_prefix}.{i}.self_attention.linear_qkv.layer_norm_{wb}": f"{nemo_prefix}.{i}.input_layernorm.{wb}",
+                    f"{mcore_prefix}.{i}.mlp.linear_fc1.layer_norm_{wb}": f"{nemo_prefix}.{i}.post_attention_layernorm.{wb}",
+                }
+            )
 
     return mcore_to_nemo_mapping
 
@@ -125,7 +140,6 @@ def load_model(model, state_dict):
         else:
             raise RuntimeError(f"Unexpected key: {name} not in state_dict but in model.")
 
-
     for name, buffer in model.named_buffers():
         if name in state_dict:
             buffer.data = state_dict.pop(name)
@@ -162,7 +176,10 @@ def convert(input_ckpt_file, output_ckpt_file, skip_if_output_exists=True):
 
 def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
     cfg = OmegaConf.load(
-        os.path.join(os.path.dirname(__file__), '../../examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml')
+        os.path.join(
+            os.path.dirname(__file__),
+            '../../examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml',
+        )
     )
 
     cfg.trainer.precision = 'bf16'  # change me
@@ -189,11 +206,14 @@ def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
     nemo_state_dict = nemo_model.state_dict()
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg, use_O2_prefix=False).items():
         try:
-            assert torch.allclose(mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]), f"❌ parameter {mcore_param} does not match"
+            assert torch.allclose(
+                mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]
+            ), f"❌ parameter {mcore_param} does not match"
         except KeyError:
             buffers = [k for k, v in mcore_model.named_buffers()]
-            assert mcore_param in buffers or mcore_param.replace('model.', 'model.module.', 1) in buffers, \
-                f"❌ parameter {mcore_param} is not found in the state dict or named_buffers()"
+            assert (
+                mcore_param in buffers or mcore_param.replace('model.', 'model.module.', 1) in buffers
+            ), f"❌ parameter {mcore_param} is not found in the state dict or named_buffers()"
     logging.info("✅ Weights match")
 
 
@@ -206,4 +226,3 @@ def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
     convert(input_ckpt, output_ckpt, skip_if_output_exists=True)
     torch.cuda.empty_cache()
     run_sanity_checks(input_ckpt, output_ckpt)
-

From 68ca6add333b98563e4f0f4eab9d6bc92b420291 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 20 Oct 2023 09:41:44 -0700
Subject: [PATCH 03/13] remove references to 'ckpt'

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 38 +++++++++----------
 1 file changed, 19 insertions(+), 19 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 8c9387c20408..7119f0f8de10 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -34,8 +34,8 @@
     \
     --target_tensor_model_parallel_size=1 \
     --target_pipeline_model_parallel_size=1
-2) extract your checkpoint to a folder with
-    tar -xvf your_ckpt.nemo
+2) extract your nemo file to a folder with
+    tar -xvf filename.nemo
 
 Then, run this conversion script:
 python convert_nemo_gpt_to_mcore.py \
  --in-file \
  --out-file
@@ -56,7 +56,7 @@ def get_args():
     return args
 
 
-def get_mcore_model_from_nemo_ckpt(nemo_restore_from_path):
+def get_mcore_model_from_nemo_file(nemo_restore_from_path):
     model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
     model_cfg.tokenizer.vocab_file = None
     model_cfg.tokenizer.merge_file = None
@@ -78,7 +78,7 @@ def print_mcore_parameter_names(restore_from_path):
 
 
 def print_mcore_parameter_names(restore_from_path):
-    mcore_model = get_mcore_model_from_nemo_ckpt(restore_from_path)
+    mcore_model = get_mcore_model_from_nemo_file(restore_from_path)
 
     print("*********")
     print('\n'.join(sorted([k + '###' + str(v.shape) for k, v in mcore_model.named_parameters()])))
@@ -150,31 +150,31 @@ def load_model(model, state_dict):
     return model
 
 
-def convert(input_ckpt_file, output_ckpt_file, skip_if_output_exists=True):
-    if skip_if_output_exists and os.path.exists(output_ckpt_file):
-        logging.info(f"Output file already exists ({output_ckpt_file}), skipping conversion...")
+def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True):
+    if skip_if_output_exists and os.path.exists(output_nemo_file):
+        logging.info(f"Output file already exists ({output_nemo_file}), skipping conversion...")
         return
     dummy_trainer = Trainer(devices=1, accelerator='cpu')
 
-    nemo_model = MegatronGPTModel.restore_from(input_ckpt_file, trainer=dummy_trainer)
+    nemo_model = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer)
     nemo_tokenizer_model = nemo_model.cfg.tokenizer.model
     nemo_state_dict = nemo_model.state_dict()
     mcore_state_dict = OrderedDict()
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
         mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
 
-    mcore_model = get_mcore_model_from_nemo_ckpt(input_ckpt_file)
+    mcore_model = get_mcore_model_from_nemo_file(input_nemo_file)
     mcore_model = load_model(mcore_model, mcore_state_dict)
 
     if nemo_model.cfg.tokenizer.model is not None:
         logging.info("registering artifact: tokenizer.model = " + nemo_tokenizer_model)
         mcore_model.register_artifact("tokenizer.model", nemo_tokenizer_model)
 
-    mcore_model.save_to(output_ckpt_file)
-    logging.info(f"Done. Model saved to {output_ckpt_file}")
+    mcore_model.save_to(output_nemo_file)
+    logging.info(f"Done. Model saved to {output_nemo_file}")
 
 
-def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
+def run_sanity_checks(nemo_file, mcore_file):
     cfg = OmegaConf.load(
         os.path.join(
             os.path.dirname(__file__),
@@ -185,8 +185,8 @@ def run_sanity_checks(nemo_file, mcore_file):
     cfg.trainer.precision = 'bf16'  # change me
     dtype = torch.bfloat16
     trainer = MegatronTrainerBuilder(cfg).create_trainer()
-    nemo_model = MegatronGPTModel.restore_from(nemo_ckpt_file, trainer=trainer).eval().to(dtype)
-    mcore_model = MegatronGPTModel.restore_from(mcore_ckpt_file, trainer=trainer).eval().to(dtype)
+    nemo_model = MegatronGPTModel.restore_from(nemo_file, trainer=trainer).eval().to(dtype)
+    mcore_model = MegatronGPTModel.restore_from(mcore_file, trainer=trainer).eval().to(dtype)
 
     logging.debug("*** Mcore model restored config")
     logging.debug(OmegaConf.to_yaml(mcore_model.cfg))
@@ -220,9 +220,9 @@ def run_sanity_checks(nemo_ckpt_file, mcore_ckpt_file):
 if __name__ == '__main__':
     args = get_args()
 
-    input_ckpt = args.in_file
-    output_ckpt = args.out_file
-    os.makedirs(os.path.dirname(output_ckpt), exist_ok=True)
-    convert(input_ckpt, output_ckpt, skip_if_output_exists=True)
+    input_nemo_file = args.in_file
+    output_nemo_file = args.out_file
+    os.makedirs(os.path.dirname(output_nemo_file), exist_ok=True)
+    convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True)
     torch.cuda.empty_cache()
-    run_sanity_checks(input_ckpt, output_ckpt)
+    run_sanity_checks(input_nemo_file, output_nemo_file)

From 637b72d82c9421a7b9b8939c21b40c22ff641e40 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 20 Oct 2023 13:47:48 -0700
Subject: [PATCH 04/13] add one more sanity check to make sure there are no
 unexpected keys in state dict

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 18 ++++++++++++------
 1 file changed, 12 insertions(+), 6 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 7119f0f8de10..7e4e21545d1a 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -85,12 +85,10 @@ def print_mcore_parameter_names(restore_from_path):
     print("*********")
 
 
-def build_key_mapping(nemo_cfg, use_O2_prefix=None):
+def build_key_mapping(nemo_cfg):
     num_layers = nemo_cfg.num_layers
     has_bias = nemo_cfg.get("bias", True)
-    if use_O2_prefix is None:
-        use_O2_prefix = nemo_cfg.get('megatron_amp_O2', False)
-    model_str = 'model.module' if use_O2_prefix else 'model'
+    model_str = 'model.module' if nemo_cfg.get('megatron_amp_O2', False) else 'model'
 
     # For GPT there is a 1:1 mapping of keys
     mcore_to_nemo_mapping = {
@@ -204,18 +202,26 @@ def run_sanity_checks(nemo_file, mcore_file):
     # check weights match
     mcore_state_dict = mcore_model.state_dict()
     nemo_state_dict = nemo_model.state_dict()
-    for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg, use_O2_prefix=False).items():
+    nemo_model.cfg.megatron_amp_O2 = False  # we want build_key_mapping in the next line to not use O2 prefix
+    for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
         try:
             assert torch.allclose(
                 mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]
             ), f"❌ parameter {mcore_param} does not match"
         except KeyError:
             buffers = [k for k, v in mcore_model.named_buffers()]
             assert (
                 mcore_param in buffers or mcore_param.replace('model.', 'model.module.', 1) in buffers
             ), f"❌ parameter {mcore_param} is not found in the state dict or named_buffers()"
+            nemo_state_dict.pop(nemo_param)
+
     logging.info("✅ Weights match")
 
+    # check for unexpected weights in state dict
+    assert len(nemo_state_dict)==0, f"❌ unexpected items in nemo_state_dict: {nemo_state_dict}"
+    assert len([k for k in mcore_state_dict if not k.endswith('_extra_state')])==0, f"❌ unexpected items in mcore_state_dict: {mcore_state_dict}"
+    logging.info("✅ No unexpected weights in state dicts")
+
 
 if __name__ == '__main__':
     args = get_args()

From 790fc677de925c50ebe0af1015dd5af4cd64d68f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 20 Oct 2023 21:59:17 +0000
Subject: [PATCH 05/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 7e4e21545d1a..a67149cd421f 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -218,8 +218,10 @@ def run_sanity_checks(nemo_file, mcore_file):
     logging.info("✅ Weights match")
 
     # check for unexpected weights in state dict
-    assert len(nemo_state_dict)==0, f"❌ unexpected items in nemo_state_dict: {nemo_state_dict}"
-    assert len([k for k in mcore_state_dict if not k.endswith('_extra_state')])==0, f"❌ unexpected items in mcore_state_dict: {mcore_state_dict}"
+    assert len(nemo_state_dict) == 0, f"❌ unexpected items in nemo_state_dict: {nemo_state_dict}"
+    assert (
+        len([k for k in mcore_state_dict if not k.endswith('_extra_state')]) == 0
+    ), f"❌ unexpected items in mcore_state_dict: {mcore_state_dict}"
     logging.info("✅ No unexpected weights in state dicts")
 
 

From 5a2a330b0d49f062c81e0a080b56ad963ea03dbd Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 20 Oct 2023 15:57:49 -0700
Subject: [PATCH 06/13] make cpu loading work

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 35 ++++++++++++++-----
 1 file changed, 26 insertions(+), 9 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 7e4e21545d1a..9f4a53d86a6d 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -52,6 +52,10 @@ def get_args():
     parser.add_argument(
         "--out-file", type=str, default=None, required=True, help="Path to output mcore weights file (ends in .nemo)."
     )
+    parser.add_argument(
+        "--cpu-only", action="store_true", help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
+        "but this option makes the conversion script significantly slower."
+    )
     args = parser.parse_args()
     return args
 
@@ -61,6 +65,7 @@ def get_mcore_model_from_nemo_file(nemo_restore_from_path):
     model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
     model_cfg.tokenizer.vocab_file = None
    model_cfg.tokenizer.merge_file = None
     model_cfg.mcore_gpt = True
+    model_cfg.use_cpu_initialization = True
 
     logging.info("*** initializing mcore model with the following config")
     logging.info(OmegaConf.to_yaml(model_cfg))
@@ -147,14 +152,23 @@ def load_model(model, state_dict):
     return model
 
+def restore_model(nemo_file, cpu_only=False):
+    # dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
+    dummy_trainer = Trainer(devices=1, accelerator='cpu')
+    if cpu_only:
+        map_location = torch.device('cpu')
+        model_config = MegatronGPTModel.restore_from(nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location)
+        model_config.use_cpu_initialization = True
+    else:
+        model_config, map_location = None, None
+    return MegatronGPTModel.restore_from(nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location)
 
-def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True):
+def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_only=False):
     if skip_if_output_exists and os.path.exists(output_nemo_file):
         logging.info(f"Output file already exists ({output_nemo_file}), skipping conversion...")
         return
-    dummy_trainer = Trainer(devices=1, accelerator='cpu')
+    nemo_model = restore_model(input_nemo_file, cpu_only=cpu_only)
 
-    nemo_model = MegatronGPTModel.restore_from(input_nemo_file, trainer=dummy_trainer)
     nemo_tokenizer_model = nemo_model.cfg.tokenizer.model
     nemo_state_dict = nemo_model.state_dict()
     mcore_state_dict = OrderedDict()
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
@@ -168,11 +182,12 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True):
         logging.info("registering artifact: tokenizer.model = " + nemo_tokenizer_model)
         mcore_model.register_artifact("tokenizer.model", nemo_tokenizer_model)
 
+    mcore_model.cfg.use_cpu_initialization = False
     mcore_model.save_to(output_nemo_file)
     logging.info(f"Done. Model saved to {output_nemo_file}")
 
 
-def run_sanity_checks(nemo_file, mcore_file):
+def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
     cfg = OmegaConf.load(
         os.path.join(
             os.path.dirname(__file__),
@@ -182,9 +197,9 @@ def run_sanity_checks(nemo_file, mcore_file):
 
     cfg.trainer.precision = 'bf16'  # change me
     dtype = torch.bfloat16
-    trainer = MegatronTrainerBuilder(cfg).create_trainer()
-    nemo_model = MegatronGPTModel.restore_from(nemo_file, trainer=trainer).eval().to(dtype)
-    mcore_model = MegatronGPTModel.restore_from(mcore_file, trainer=trainer).eval().to(dtype)
+
+    nemo_model = restore_model(nemo_file, cpu_only=cpu_only).eval().to(dtype)
+    mcore_model = restore_model(mcore_file, cpu_only=cpu_only).eval().to(dtype)
 
     logging.debug("*** Mcore model restored config")
     logging.debug(OmegaConf.to_yaml(mcore_model.cfg))
@@ -228,7 +243,9 @@ def run_sanity_checks(nemo_file, mcore_file):
 
     input_nemo_file = args.in_file
     output_nemo_file = args.out_file
+    cpu_only = args.cpu_only
+
     os.makedirs(os.path.dirname(output_nemo_file), exist_ok=True)
-    convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True)
+    convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_only=cpu_only)
     torch.cuda.empty_cache()
-    run_sanity_checks(input_nemo_file, output_nemo_file)
+    run_sanity_checks(input_nemo_file, output_nemo_file, cpu_only=cpu_only)

From a9a4e59d9d3e92394e937432238f438199926bf9 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 20 Oct 2023 16:20:43 -0700
Subject: [PATCH 07/13] make script work for llama2 models

Signed-off-by: Chen Cui
---
 .../nlp_language_modeling/convert_nemo_gpt_to_mcore.py | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 9f4a53d86a6d..8579483a8d50 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -93,14 +93,17 @@ def print_mcore_parameter_names(restore_from_path):
 def build_key_mapping(nemo_cfg):
     num_layers = nemo_cfg.num_layers
     has_bias = nemo_cfg.get("bias", True)
+    has_layernorm_bias = nemo_cfg.get("normalization", "layernorm") != "rmsnorm"  # llama model uses rmsnorm which does not have bias
     model_str = 'model.module' if nemo_cfg.get('megatron_amp_O2', False) else 'model'
 
     # For GPT there is a 1:1 mapping of keys
     mcore_to_nemo_mapping = {
         f"{model_str}.embedding.word_embeddings.weight": "model.language_model.embedding.word_embeddings.weight",
-        f"{model_str}.decoder.final_layernorm.bias": "model.language_model.encoder.final_layernorm.bias",
         f"{model_str}.decoder.final_layernorm.weight": "model.language_model.encoder.final_layernorm.weight",
     }
+    if has_layernorm_bias:
+        mcore_to_nemo_mapping[f"{model_str}.decoder.final_layernorm.bias"] = "model.language_model.encoder.final_layernorm.bias"
+
     if not nemo_cfg.get("share_embeddings_and_output_weights", True):
         mcore_to_nemo_mapping[f"{model_str}.output_layer.weight"] = "model.language_model.output_layer.weight"
 
@@ -123,8 +126,8 @@ def build_key_mapping(nemo_cfg):
                     f"{mcore_prefix}.{i}.self_attention.linear_qkv.{wb}": f"{nemo_prefix}.{i}.self_attention.query_key_value.{wb}",
                 }
             )
-        # layernorm layers always have bias!
-        for wb in ('weight', 'bias'):
+        # layernorm layers always have bias, but llama model uses rmsnorm which does not have bias
+        for wb in ('weight', 'bias') if has_layernorm_bias else ('weight',):
             mcore_to_nemo_mapping.update(
                 {
                     f"{mcore_prefix}.{i}.self_attention.linear_qkv.layer_norm_{wb}": f"{nemo_prefix}.{i}.input_layernorm.{wb}",

From 0718d53ceec11ce0b9c5ae0e50851729ba2128a7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 20 Oct 2023 23:39:23 +0000
Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../convert_nemo_gpt_to_mcore.py              | 24 ++++++++++++++-----
 1 file changed, 18 insertions(+), 6 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 80765f347158..69f105c9a7e5 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -53,8 +53,10 @@ def get_args():
         "--out-file", type=str, default=None, required=True, help="Path to output mcore weights file (ends in .nemo)."
     )
     parser.add_argument(
-        "--cpu-only", action="store_true", help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
-        "but this option makes the conversion script significantly slower."
+        "--cpu-only",
+        action="store_true",
+        help="Load model in cpu only. Useful if the model cannot fit in GPU memory, "
+        "but this option makes the conversion script significantly slower.",
     )
     args = parser.parse_args()
     return args
@@ -93,7 +95,9 @@ def print_mcore_parameter_names(restore_from_path):
 def build_key_mapping(nemo_cfg):
     num_layers = nemo_cfg.num_layers
     has_bias = nemo_cfg.get("bias", True)
-    has_layernorm_bias = nemo_cfg.get("normalization", "layernorm") != "rmsnorm"  # llama model uses rmsnorm which does not have bias
+    has_layernorm_bias = (
+        nemo_cfg.get("normalization", "layernorm") != "rmsnorm"
+    )  # llama model uses rmsnorm which does not have bias
     model_str = 'model.module' if nemo_cfg.get('megatron_amp_O2', False) else 'model'
 
     # For GPT there is a 1:1 mapping of keys
@@ -102,7 +106,9 @@ def build_key_mapping(nemo_cfg):
         f"{model_str}.decoder.final_layernorm.weight": "model.language_model.encoder.final_layernorm.weight",
     }
     if has_layernorm_bias:
-        mcore_to_nemo_mapping[f"{model_str}.decoder.final_layernorm.bias"] = "model.language_model.encoder.final_layernorm.bias"
+        mcore_to_nemo_mapping[
+            f"{model_str}.decoder.final_layernorm.bias"
+        ] = "model.language_model.encoder.final_layernorm.bias"
 
     if not nemo_cfg.get("share_embeddings_and_output_weights", True):
         mcore_to_nemo_mapping[f"{model_str}.output_layer.weight"] = "model.language_model.output_layer.weight"
@@ -155,16 +161,22 @@ def load_model(model, state_dict):
     return model
 
+
 def restore_model(nemo_file, cpu_only=False):
     # dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
     dummy_trainer = Trainer(devices=1, accelerator='cpu')
     if cpu_only:
         map_location = torch.device('cpu')
-        model_config = MegatronGPTModel.restore_from(nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location)
+        model_config = MegatronGPTModel.restore_from(
+            nemo_file, trainer=dummy_trainer, return_config=True, map_location=map_location
+        )
         model_config.use_cpu_initialization = True
     else:
         model_config, map_location = None, None
-    return MegatronGPTModel.restore_from(nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location)
+    return MegatronGPTModel.restore_from(
+        nemo_file, trainer=dummy_trainer, override_config_path=model_config, map_location=map_location
+    )
+
 
 def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_only=False):
     if skip_if_output_exists and os.path.exists(output_nemo_file):
         logging.info(f"Output file already exists ({output_nemo_file}), skipping conversion...")

From 84c066eb2c12fcc676ee2600db3b12f2814fa3ba Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 20 Oct 2023 16:49:33 -0700
Subject: [PATCH 09/13] address code check

Signed-off-by: Chen Cui
---
 scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 69f105c9a7e5..faedab23ed57 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -236,8 +236,10 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
         try:
             assert torch.allclose(
-                mcore_state_dict.pop(mcore_param), nemo_state_dict.pop(nemo_param)
+                mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]
             ), f"❌ parameter {mcore_param} does not match"
+            mcore_state_dict.pop(mcore_param)
+            nemo_state_dict.pop(nemo_param)
         except KeyError:
             buffers = [k for k, v in mcore_model.named_buffers()]
             assert (

From 84b4484946a45c883c1372876bd4de061735b5e5 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Sat, 21 Oct 2023 09:15:53 -0700
Subject: [PATCH 10/13] remove trainer precision (was for old sanity check)

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 17 ++++++-----------
 1 file changed, 6 insertions(+), 11 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index faedab23ed57..5e4e6de54789 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -21,7 +21,6 @@
 from pytorch_lightning.trainer.trainer import Trainer
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
-from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder
 from nemo.collections.nlp.parts.nlp_overrides import NLPDDPStrategy
 from nemo.utils import AppState, logging
 
@@ -62,12 +61,12 @@ def get_args():
     return args
 
 
-def get_mcore_model_from_nemo_file(nemo_restore_from_path):
+def get_mcore_model_from_nemo_file(nemo_restore_from_path, cpu_only=False):
     model_cfg = MegatronGPTModel.restore_from(nemo_restore_from_path, return_config=True)
     model_cfg.tokenizer.vocab_file = None
     model_cfg.tokenizer.merge_file = None
     model_cfg.mcore_gpt = True
-    model_cfg.use_cpu_initialization = True
+    model_cfg.use_cpu_initialization = cpu_only
 
     logging.info("*** initializing mcore model with the following config")
     logging.info(OmegaConf.to_yaml(model_cfg))
@@ -163,8 +162,7 @@ def load_model(model, state_dict):
 
 
 def restore_model(nemo_file, cpu_only=False):
-    # dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
-    dummy_trainer = Trainer(devices=1, accelerator='cpu')
+    dummy_trainer = Trainer(devices=1, accelerator='cpu', strategy=NLPDDPStrategy())
     if cpu_only:
         map_location = torch.device('cpu')
         model_config = MegatronGPTModel.restore_from(
@@ -190,7 +188,7 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
         mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
 
-    mcore_model = get_mcore_model_from_nemo_file(input_nemo_file)
+    mcore_model = get_mcore_model_from_nemo_file(input_nemo_file, cpu_only=cpu_only)
     mcore_model = load_model(mcore_model, mcore_state_dict)
 
     if nemo_model.cfg.tokenizer.model is not None:
@@ -210,11 +208,8 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
         )
     )
 
-    cfg.trainer.precision = 'bf16'  # change me
-    dtype = torch.bfloat16
-
-    nemo_model = restore_model(nemo_file, cpu_only=cpu_only).eval().to(dtype)
-    mcore_model = restore_model(mcore_file, cpu_only=cpu_only).eval().to(dtype)
+    nemo_model = restore_model(nemo_file, cpu_only=cpu_only).eval()
+    mcore_model = restore_model(mcore_file, cpu_only=cpu_only).eval()
 
     logging.debug("*** Mcore model restored config")
     logging.debug(OmegaConf.to_yaml(mcore_model.cfg))

From da8290dca369cc534008790c98ae79230136621c Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 23 Oct 2023 11:08:22 -0400
Subject: [PATCH 11/13] fix script for llama2 model

Signed-off-by: Chen Cui
---
 .../convert_nemo_gpt_to_mcore.py              | 37 ++++++++++++-------
 1 file changed, 23 insertions(+), 14 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 5e4e6de54789..d4e5d1432768 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -17,7 +17,7 @@
 from collections import OrderedDict
 
 import torch
-from omegaconf import OmegaConf
+from omegaconf import OmegaConf, open_dict
 from pytorch_lightning.trainer.trainer import Trainer
 
 from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel
@@ -186,7 +186,13 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
     nemo_state_dict = nemo_model.state_dict()
     mcore_state_dict = OrderedDict()
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
-        mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
+        if mcore_param.endswith("linear_fc1.weight"):
+            # in llama models, need to concat dense_h_to_4h.weight and dense_h_to_4h_2.weight for the corresponding linear_fc1.weight
+            second_param = nemo_param.replace("dense_h_to_4h.weight", "dense_h_to_4h_2.weight")
+            if second_param in nemo_state_dict:
+                mcore_state_dict[mcore_param] = torch.cat([nemo_state_dict[nemo_param], nemo_state_dict[second_param]], dim=0)
+            else:
+                mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
 
     mcore_model = get_mcore_model_from_nemo_file(input_nemo_file, cpu_only=cpu_only)
     mcore_model = load_model(mcore_model, mcore_state_dict)
@@ -207,12 +213,6 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
 
 
 def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
-    cfg = OmegaConf.load(
-        os.path.join(
-            os.path.dirname(__file__),
-            '../../examples/nlp/language_modeling/tuning/conf/megatron_gpt_peft_tuning_config.yaml',
-        )
-    )
 
     nemo_model = restore_model(nemo_file, cpu_only=cpu_only).eval()
     mcore_model = restore_model(mcore_file, cpu_only=cpu_only).eval()
@@ -227,14 +227,23 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
     # check weights match
     mcore_state_dict = mcore_model.state_dict()
     nemo_state_dict = nemo_model.state_dict()
-    nemo_model.cfg.megatron_amp_O2 = False  # we want build_key_mapping in the next line to not use O2 prefix
+    with open_dict(nemo_model.cfg):
+        nemo_model.cfg.megatron_amp_O2 = False  # we want build_key_mapping in the next line to not use O2 prefix
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
+        # if nemo_param.endswith("dense_h_to_4h.weight"):
+        #     # in llama models, need to concat dense_h_to_4h.weight and dense_h_to_4h_2.weight for the corresponding linear_fc1.weight
+        #     second_param = nemo_param.replace("dense_h_to_4h.weight", "dense_h_to_4h_2.weight")
+        #     if second_param in nemo_state_dict:
+        #         mcore_state_dict[mcore_param] = torch.cat([nemo_state_dict[nemo_param], nemo_state_dict[second_param]], dim=0)
         try:
-            assert torch.allclose(
-                mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]
-            ), f"❌ parameter {mcore_param} does not match"
-            mcore_state_dict.pop(mcore_param)
-            nemo_state_dict.pop(nemo_param)
+            mcore_weight = mcore_state_dict.pop(mcore_param)
+            nemo_weight = nemo_state_dict.pop(nemo_param)
+            if mcore_param.endswith("linear_fc1.weight"):
+                # linear_fc1.weight should map to concat(dense_h_to_4h.weight, dense_h_to_4h_2.weight)
+                # but build_key_mapping only maps it to dense_h_to_4h.weight, so we handle the concat here.
+                second_param = nemo_param.replace("dense_h_to_4h.weight", "dense_h_to_4h_2.weight")
+                nemo_weight = torch.cat([nemo_weight, nemo_state_dict.pop(second_param)])
+            assert torch.allclose(mcore_weight, nemo_weight), f"❌ parameter {mcore_param} does not match"
         except KeyError:
             buffers = [k for k, v in mcore_model.named_buffers()]
             assert (

From 4dea37a71445d7f8b9431f2563b45f8d9d5fb960 Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Mon, 23 Oct 2023 11:09:51 -0400
Subject: [PATCH 12/13] remove commented code

Signed-off-by: Chen Cui
---
 scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py | 5 -----
 1 file changed, 5 deletions(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index d4e5d1432768..8e50c0d5f439 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -230,11 +230,6 @@ def run_sanity_checks(nemo_file, mcore_file, cpu_only=False):
     with open_dict(nemo_model.cfg):
         nemo_model.cfg.megatron_amp_O2 = False  # we want build_key_mapping in the next line to not use O2 prefix
     for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
-        # if nemo_param.endswith("dense_h_to_4h.weight"):
-        #     # in llama models, need to concat dense_h_to_4h.weight and dense_h_to_4h_2.weight for the corresponding linear_fc1.weight
-        #     second_param = nemo_param.replace("dense_h_to_4h.weight", "dense_h_to_4h_2.weight")
-        #     if second_param in nemo_state_dict:
-        #         mcore_state_dict[mcore_param] = torch.cat([nemo_state_dict[nemo_param], nemo_state_dict[second_param]], dim=0)
         try:
             mcore_weight = mcore_state_dict.pop(mcore_param)
             nemo_weight = nemo_state_dict.pop(nemo_param)

From 7a8a0957e7ca579d846a3adbdf7cff811edeb104 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 23 Oct 2023 15:11:19 +0000
Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
index 8e50c0d5f439..23111f441101 100644
--- a/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
+++ b/scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
@@ -190,7 +190,9 @@ def convert(input_nemo_file, output_nemo_file, skip_if_output_exists=True, cpu_o
             # in llama models, need to concat dense_h_to_4h.weight and dense_h_to_4h_2.weight for the corresponding linear_fc1.weight
             second_param = nemo_param.replace("dense_h_to_4h.weight", "dense_h_to_4h_2.weight")
             if second_param in nemo_state_dict:
-                mcore_state_dict[mcore_param] = torch.cat([nemo_state_dict[nemo_param], nemo_state_dict[second_param]], dim=0)
+                mcore_state_dict[mcore_param] = torch.cat(
+                    [nemo_state_dict[nemo_param], nemo_state_dict[second_param]], dim=0
+                )
             else:
                 mcore_state_dict[mcore_param] = nemo_state_dict[nemo_param]
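
Usage sketch (supplementary, not part of the patch series): with the series applied, the end-to-end flow described in the script's docstring can be exercised roughly as follows. This is a minimal sketch; the checkpoint paths are illustrative placeholders, and step 1's arguments are left elided exactly as in the docstring (see megatron_change_num_partitions.py for its full argument list).

    # 1) repartition the legacy checkpoint to TP1 PP1 with
    #    examples/nlp/language_modeling/megatron_change_num_partitions.py
    #    (arguments elided; see that script's usage)

    # 2) extract the .nemo archive so `nemo_file_folder` can be set
    mkdir -p /tmp/legacy_tp1pp1
    tar -xvf legacy_tp1pp1.nemo -C /tmp/legacy_tp1pp1

    # 3) convert; --cpu-only trades speed for lower GPU memory use,
    #    and run_sanity_checks() runs automatically after conversion
    python scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py \
        --in-file /tmp/legacy_tp1pp1 \
        --out-file /tmp/results/mcore_gpt.nemo \
        --cpu-only

Passing the extracted folder as --in-file is what lets get_mcore_model_from_nemo_file() set app_state.nemo_file_folder; pointing --in-file at a packed .nemo file instead triggers the warning added in PATCH 01.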