fix: explicitly set weights_only to False during checkpoint loading to support PyTorch 2.6 (#56)
mpashkovskii authored Feb 17, 2025
1 parent 3ec5701 commit f712ab8
Showing 19 changed files with 30 additions and 28 deletions.
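For context, a minimal sketch (not part of this commit) of the behaviour it works around: PyTorch 2.6 changed the default of `torch.load` from `weights_only=False` to `weights_only=True`, and the weights-only unpickler rejects arbitrary Python objects such as the `argparse.Namespace` that Megatron checkpoints store under `sd["args"]`. Passing `weights_only=False`, as every hunk below does, restores the old behaviour and should only be used with checkpoints from trusted sources. The checkpoint path below is illustrative.

```python
# Minimal sketch, not part of the diff: why the explicit flag is needed.
import argparse

import torch

# A toy checkpoint mimicking Megatron's layout: tensors plus a pickled Namespace.
ckpt = {"model": {"weight": torch.zeros(2, 2)}, "args": argparse.Namespace(num_layers=2)}
torch.save(ckpt, "/tmp/example_checkpoint.pt")  # hypothetical path

try:
    # On PyTorch >= 2.6 the default weights_only=True raises an UnpicklingError,
    # because argparse.Namespace is not an allowlisted class.
    torch.load("/tmp/example_checkpoint.pt")
except Exception as exc:
    print(f"default load failed: {exc}")

# The pattern applied throughout this commit: opt back into full unpickling.
state_dict = torch.load("/tmp/example_checkpoint.pt", map_location="cpu", weights_only=False)
print(state_dict["args"].num_layers)
```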
2 changes: 1 addition & 1 deletion examples/multimodal/combine_state_dicts.py
@@ -27,7 +27,7 @@ def combine(input_files, module_prefixes, output_files):
zip(current_input_files, current_module_prefixes)
):
# initialize the combined state dict using the first provided input file
current_state_dict = torch.load(input_file)
current_state_dict = torch.load(input_file, weights_only=False)
if i == 0:
combined_state_dict = current_state_dict.copy()
combined_state_dict["model"] = dict()
2 changes: 1 addition & 1 deletion examples/multimodal/dataloader_provider.py
@@ -130,7 +130,7 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
)
if os.path.exists(data_save_name):
try:
dataset_state_dict = torch.load(data_save_name, map_location="cpu")
dataset_state_dict = torch.load(data_save_name, map_location="cpu", weights_only=False)
train_dataloader.restore_state_rank(dataset_state_dict["dataloader_state_dict"])
print(f"restored dataset state from {data_save_name}")
except Exception as e:
4 changes: 2 additions & 2 deletions examples/multimodal/nvlm/pp_checkpoint_converter.py
@@ -15,7 +15,7 @@ def split(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_pe
"""Split pipeline parallel size = 1 checkpoint to pipeline parallel size N."""
for tp in range(num_tp):
path = os.path.join(input_dir, f"mp_rank_0{tp}", "model_optim_rng.pt")
sd = torch.load(path)
sd = torch.load(path, weights_only=False)

if num_layers_per_pp_rank is None:
num_layers = sd["args"].num_layers
@@ -84,7 +84,7 @@ def combine(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_

for pp in range(input_pp):
path = os.path.join(input_dir, f"mp_rank_0{tp}_00{pp}", "model_optim_rng.pt")
sd = torch.load(path)
sd = torch.load(path, weights_only=False)

if pp == 0:
new_sd = sd.copy()
6 changes: 3 additions & 3 deletions megatron/core/dist_checkpointing/strategies/common.py
@@ -69,7 +69,7 @@ def load_common(self, checkpoint_dir: Path):
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
@@ -95,12 +95,12 @@ def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path)
loaded_obj = torch.load(load_path, weights_only=False)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path)
loaded_obj = torch.load(old_load_path, weights_only=False)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
2 changes: 1 addition & 1 deletion megatron/core/export/trtllm/trtllm_helper.py
@@ -212,7 +212,7 @@ def _load_scaling_factors(self, model_state_dict: dict) -> dict:
continue

val.seek(0)
extra_states = torch.load(val)
extra_states = torch.load(val, weights_only=False)

activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix)
weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix)
2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -930,7 +930,7 @@ def _decode_extra_state(self, state):
return pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
return torch.load(state, map_location="cuda", weights_only=False)
else:
raise RuntimeError("Unsupported checkpoint format.")

2 changes: 1 addition & 1 deletion megatron/core/optimizer/distrib_optimizer.py
@@ -1602,7 +1602,7 @@ def load_parameter_state(self, filename: str, *, update_legacy_format=False):
"""
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
state_dict = torch.load(filename, weights_only=False)

self.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
2 changes: 1 addition & 1 deletion megatron/core/optimizer/optimizer.py
@@ -1049,7 +1049,7 @@ def load_parameter_state(self, filename: str, *, update_legacy_format: bool = Fa

# Lazy loading checkpoint, state dict is needed only when DP rank = 0.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
states = torch.load(filename)
states = torch.load(filename, weights_only=False)

state_dict = states[idx] if states else None
optimizer.load_parameter_state_from_dp_zero(
4 changes: 2 additions & 2 deletions megatron/legacy/model/biencoder_model.py
@@ -200,7 +200,7 @@ def init_state_dict_from_bert(self):

# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler
# For backward compatibility.
@@ -209,7 +209,7 @@ def init_state_dict_from_bert(self):
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except Exception:
2 changes: 1 addition & 1 deletion megatron/legacy/model/realm_model.py
@@ -131,7 +131,7 @@ def init_state_dict_from_bert(self):
torch.distributed.get_rank(), checkpoint_name))

try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except Exception:
raise ValueError("Could not load checkpoint")

2 changes: 1 addition & 1 deletion megatron/legacy/model/vision/esvit_swin_backbone.py
@@ -715,7 +715,7 @@ def flops(self):

def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
pretrained_dict = torch.load(pretrained, map_location='cpu', weights_only=False)
logging.info(f'=> loading pretrained model {pretrained}')
model_dict = self.state_dict()
pretrained_dict = {
6 changes: 3 additions & 3 deletions megatron/training/checkpointing.py
@@ -868,7 +868,7 @@ def _load_base_checkpoint(
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False)
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler

@@ -880,7 +880,7 @@ def _load_base_checkpoint(
'megatron.legacy.fp16_deprecated.loss_scaler'
]
sys.modules['megatron.model'] = sys.modules['megatron.legacy.model']
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
sys.modules.pop('megatron.model', None)
@@ -1346,7 +1346,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
ret_state_dict = state_dict['model']

if only_query_model:
2 changes: 1 addition & 1 deletion tasks/ensemble_classifier.py
@@ -13,7 +13,7 @@ def process_files(args):
for path in args.paths:
path = os.path.join(path, args.prediction_name)
try:
data = torch.load(path)
data = torch.load(path, weights_only=False)
for dataset in data:
name, d = dataset
predictions, labels, uid = d
2 changes: 1 addition & 1 deletion tasks/msdp/preprocessing.py
@@ -375,7 +375,7 @@ def prompt_selection_for_knowledge_generation(
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
encoder = torch.load(model_path, weights_only=False).cuda()

print("> getting dialog embeddings")
with torch.no_grad():
6 changes: 4 additions & 2 deletions tasks/vision/finetune_utils.py
@@ -241,13 +241,15 @@ def finetune(
elif args.pretrained_checkpoint_type == 'external':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
map_location="cpu",
weights_only=False)
unwrap_model[0].module.backbone.load_state_dict(state_dict,
strict=False)
elif args.pretrained_checkpoint_type == 'constrastive':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
map_location="cpu",
weights_only=False)
state_dict = state_dict["model"]
state_dict = {k.replace("teacher.backbone.", ""): v
for k, v in state_dict.items()
@@ -382,7 +382,7 @@ def test_te_grouped_linear_torch_native(self, tmp_path_dist_ckpt, ep_size):
torch.save(model.state_dict(), ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt")

# Load checkpoint
state_dict = torch.load(ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt")
state_dict = torch.load(ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt", weights_only=False)
model.load_state_dict(state_dict)

Utils.destroy_model_parallel()
2 changes: 1 addition & 1 deletion tests/unit_tests/dist_checkpointing/test_serialization.py
@@ -433,7 +433,7 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt):
load_state_dict = load(state_dict, ckpt_dir)
assert 'other_key' in load_state_dict
load_state_dict['other_key'].seek(0)
loaded_state = torch.load(load_state_dict['other_key'])
loaded_state = torch.load(load_state_dict['other_key'], weights_only=False)

assert loaded_state == {'some': 'dict'}

4 changes: 2 additions & 2 deletions tools/checkpoint/hybrid_conversion.py
@@ -242,7 +242,7 @@ def main(args):

# load one of the model parallel ranks to get arguments
sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt")
sample_model = torch.load(sample_model_file)
sample_model = torch.load(sample_model_file, weights_only=False)
print(f"Sample model {sample_model_file} is loaded.\n")

# input tensor and pipeline parallel size
@@ -261,7 +261,7 @@ def main(args):
dir_name += "_{:03d}".format(pp)
model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt")

tp_models.append(torch.load(model_file))
tp_models.append(torch.load(model_file, weights_only=False))
print(f"Model {model_file} is loaded.")

if input_tp_rank > 1:
4 changes: 2 additions & 2 deletions tools/checkpoint/loader_llama_mistral.py
@@ -152,11 +152,11 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu", weights_only=False)
else:
# Sharded
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=False)
for i in range(num_shards)
]
param_count = 0
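A side note rather than part of the change: on PyTorch 2.4 and later, a stricter alternative to `weights_only=False` is to keep the safe loader and allowlist the specific classes a trusted checkpoint is known to contain via `torch.serialization.add_safe_globals`. A rough sketch, with an illustrative path and class list rather than anything taken from this repository:

```python
# Sketch of the allowlisting alternative (assumed usage, not what this commit does).
import argparse

import torch

# Allowlist the non-tensor types the checkpoint is known to contain, then keep
# the PyTorch 2.6 default of weights-only loading.
torch.serialization.add_safe_globals([argparse.Namespace])
state_dict = torch.load("/tmp/example_checkpoint.pt", map_location="cpu", weights_only=True)
```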
