Skip to content

Commit

Permalink
add one more sanity check to make sure there is no unexpected keys in…
Browse files Browse the repository at this point in the history
… state dict

Signed-off-by: Chen Cui <[email protected]>
  • Loading branch information
cuichenx committed Oct 20, 2023
1 parent 68ca6ad commit 637b72d
Showing 1 changed file with 12 additions and 6 deletions.
18 changes: 12 additions & 6 deletions scripts/nlp_language_modeling/convert_nemo_gpt_to_mcore.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,10 @@ def print_mcore_parameter_names(restore_from_path):
print("*********")


def build_key_mapping(nemo_cfg, use_O2_prefix=None):
def build_key_mapping(nemo_cfg):
num_layers = nemo_cfg.num_layers
has_bias = nemo_cfg.get("bias", True)
if use_O2_prefix is None:
use_O2_prefix = nemo_cfg.get('megatron_amp_O2', False)
model_str = 'model.module' if use_O2_prefix else 'model'
model_str = 'model.module' if nemo_cfg.get('megatron_amp_O2', False) else 'model'

# For GPT there is a 1:1 mapping of keys
mcore_to_nemo_mapping = {
Expand Down Expand Up @@ -204,18 +202,26 @@ def run_sanity_checks(nemo_file, mcore_file):
# check weights match
mcore_state_dict = mcore_model.state_dict()
nemo_state_dict = nemo_model.state_dict()
for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg, use_O2_prefix=False).items():
nemo_model.cfg.megatron_amp_O2 = False # we want build_key_mapping in the next line to not use O2 prefix
for mcore_param, nemo_param in build_key_mapping(nemo_model.cfg).items():
try:
assert torch.allclose(
mcore_state_dict[mcore_param], nemo_state_dict[nemo_param]
mcore_state_dict.pop(mcore_param), nemo_state_dict.pop(nemo_param)
), f"❌ parameter {mcore_param} does not match"
except KeyError:
buffers = [k for k, v in mcore_model.named_buffers()]
assert (
mcore_param in buffers or mcore_param.replace('model.', 'model.module.', 1) in buffers
), f"❌ parameter {mcore_param} is not found in the state dict or named_buffers()"
nemo_state_dict.pop(nemo_param)

logging.info("✅ Weights match")

# check for unexpected weights in state dict
assert len(nemo_state_dict)==0, f"❌ unexpected items in nemo_state_dict: {nemo_state_dict}"
assert len([k for k in mcore_state_dict if not k.endswith('_extra_state')])==0, f"❌ unexpected items in mcore_state_dict: {mcore_state_dict}"
logging.info("✅ No unexpected weights in state dicts")


if __name__ == '__main__':
args = get_args()
Expand Down

0 comments on commit 637b72d

Please sign in to comment.