fix: explicitly set weights_only to False during checkpoint loading to support PyTorch 2.6 #56

Merged
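Background for the change: PyTorch 2.6 flipped the default of `torch.load`'s `weights_only` argument from `False` to `True`. The restricted unpickler used in that mode rejects the plain Python objects that these checkpoints store alongside tensors (for example the training args object read as `sd["args"]` in the hunks below), so every affected `torch.load` call now passes `weights_only=False` explicitly. A minimal sketch of the failure mode and the fix follows; the checkpoint path is hypothetical.

```python
import torch

# Hypothetical checkpoint path, used only for illustration.
ckpt_path = "checkpoints/iter_0001000/mp_rank_00/model_optim_rng.pt"

# Megatron checkpoints pickle plain Python objects (e.g. the args object)
# next to the tensor weights. Under PyTorch 2.6, torch.load defaults to
# weights_only=True, whose restricted unpickler rejects such objects, so
# the flag is passed explicitly. Only do this for checkpoints from a
# trusted source: weights_only=False allows arbitrary code execution
# during unpickling.
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=False)

print(list(state_dict.keys()))
```

An alternative not taken here would be to keep `weights_only=True` and allowlist the pickled classes via `torch.serialization.add_safe_globals`, but that requires enumerating every type stored by every checkpoint format touched in this PR, so setting `weights_only=False` at each call site is the smaller change.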
2 changes: 1 addition & 1 deletion examples/multimodal/combine_state_dicts.py
@@ -27,7 +27,7 @@ def combine(input_files, module_prefixes, output_files):
zip(current_input_files, current_module_prefixes)
):
# initialize the combined state dict using the first provided input file
current_state_dict = torch.load(input_file)
current_state_dict = torch.load(input_file, weights_only=False)
if i == 0:
combined_state_dict = current_state_dict.copy()
combined_state_dict["model"] = dict()
2 changes: 1 addition & 1 deletion examples/multimodal/dataloader_provider.py
@@ -130,7 +130,7 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples):
)
if os.path.exists(data_save_name):
try:
dataset_state_dict = torch.load(data_save_name, map_location="cpu")
dataset_state_dict = torch.load(data_save_name, map_location="cpu", weights_only=False)
train_dataloader.restore_state_rank(dataset_state_dict["dataloader_state_dict"])
print(f"restored dataset state from {data_save_name}")
except Exception as e:
4 changes: 2 additions & 2 deletions examples/multimodal/nvlm/pp_checkpoint_converter.py
@@ -15,7 +15,7 @@ def split(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_pe
"""Split pipeline parallel size = 1 checkpoint to pipeline parallel size N."""
for tp in range(num_tp):
path = os.path.join(input_dir, f"mp_rank_0{tp}", "model_optim_rng.pt")
sd = torch.load(path)
sd = torch.load(path, weights_only=False)

if num_layers_per_pp_rank is None:
num_layers = sd["args"].num_layers
@@ -84,7 +84,7 @@ def combine(input_dir, base_output_dir, input_pp, output_pp, num_tp, num_layers_

for pp in range(input_pp):
path = os.path.join(input_dir, f"mp_rank_0{tp}_00{pp}", "model_optim_rng.pt")
sd = torch.load(path)
sd = torch.load(path, weights_only=False)

if pp == 0:
new_sd = sd.copy()
6 changes: 3 additions & 3 deletions megatron/core/dist_checkpointing/strategies/common.py
@@ -69,7 +69,7 @@ def load_common(self, checkpoint_dir: Path):
"""
load_path = Path(checkpoint_dir) / COMMON_STATE_FNAME
try:
return torch.load(load_path, map_location='cpu')
return torch.load(load_path, map_location='cpu', weights_only=False)
except FileNotFoundError as e:
err_msg = f'Common file {load_path} does not exist'
ckpt_files = [f.name for f in checkpoint_dir.iterdir()]
@@ -95,12 +95,12 @@ def load_sharded_object(sh_obj: ShardedObject):
sh_obj.data = None
load_path = checkpoint_dir / f'{sh_obj.unique_key}.pt'
try:
loaded_obj = torch.load(load_path)
loaded_obj = torch.load(load_path, weights_only=False)
except FileNotFoundError as e:
# Backward compatible logic: previously the save format was incorrect
old_load_path = (checkpoint_dir / sh_obj.unique_key).with_suffix('.pt')
try:
loaded_obj = torch.load(old_load_path)
loaded_obj = torch.load(old_load_path, weights_only=False)
except FileNotFoundError:
err_msg = f'Object shard {load_path} not found'
obj_subdir = checkpoint_dir / sh_obj.key
2 changes: 1 addition & 1 deletion megatron/core/export/trtllm/trtllm_helper.py
@@ -212,7 +212,7 @@ def _load_scaling_factors(self, model_state_dict: dict) -> dict:
continue

val.seek(0)
extra_states = torch.load(val)
extra_states = torch.load(val, weights_only=False)

activation_scaling_factor_key = key.replace(mock_suffix, activation_scaling_suffix)
weight_scaling_factor_key = key.replace(mock_suffix, weight_scaling_suffix)
2 changes: 1 addition & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -930,7 +930,7 @@ def _decode_extra_state(self, state):
return pickle.loads(state.detach().cpu().numpy().tobytes())
elif isinstance(state, io.BytesIO):
state.seek(0)
return torch.load(state, map_location="cuda")
return torch.load(state, map_location="cuda", weights_only=False)
else:
raise RuntimeError("Unsupported checkpoint format.")

2 changes: 1 addition & 1 deletion megatron/core/optimizer/distrib_optimizer.py
@@ -1602,7 +1602,7 @@ def load_parameter_state(self, filename: str, *, update_legacy_format=False):
"""
state_dict = None
if torch.distributed.get_rank(self.data_parallel_group) == 0:
state_dict = torch.load(filename)
state_dict = torch.load(filename, weights_only=False)

self.load_parameter_state_from_dp_zero(
state_dict, update_legacy_format=update_legacy_format
2 changes: 1 addition & 1 deletion megatron/core/optimizer/optimizer.py
@@ -1049,7 +1049,7 @@ def load_parameter_state(self, filename: str, *, update_legacy_format: bool = Fa

# Lazy loading checkpoint, state dict is needed only when DP rank = 0.
if torch.distributed.get_rank(optimizer.data_parallel_group) == 0 and states is None:
states = torch.load(filename)
states = torch.load(filename, weights_only=False)

state_dict = states[idx] if states else None
optimizer.load_parameter_state_from_dp_zero(
4 changes: 2 additions & 2 deletions megatron/legacy/model/biencoder_model.py
@@ -200,7 +200,7 @@ def init_state_dict_from_bert(self):

# Load the checkpoint.
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler
# For backward compatibility.
@@ -209,7 +209,7 @@ def init_state_dict_from_bert(self):
'megatron.fp16_deprecated.loss_scaler']
sys.modules['megatron.fp16.loss_scaler'] = sys.modules[
'megatron.fp16_deprecated.loss_scaler']
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
except Exception:
2 changes: 1 addition & 1 deletion megatron/legacy/model/realm_model.py
@@ -131,7 +131,7 @@ def init_state_dict_from_bert(self):
torch.distributed.get_rank(), checkpoint_name))

try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except Exception:
raise ValueError("Could not load checkpoint")

2 changes: 1 addition & 1 deletion megatron/legacy/model/vision/esvit_swin_backbone.py
@@ -715,7 +715,7 @@ def flops(self):

def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
if os.path.isfile(pretrained):
pretrained_dict = torch.load(pretrained, map_location='cpu')
pretrained_dict = torch.load(pretrained, map_location='cpu', weights_only=False)
logging.info(f'=> loading pretrained model {pretrained}')
model_dict = self.state_dict()
pretrained_dict = {
6 changes: 3 additions & 3 deletions megatron/training/checkpointing.py
@@ -868,7 +868,7 @@ def _load_base_checkpoint(
else:
checkpoint_name = get_checkpoint_name(load_dir, iteration, release, return_base_dir=False)
try:
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
except ModuleNotFoundError:
from megatron.legacy.fp16_deprecated import loss_scaler

@@ -880,7 +880,7 @@ def _load_base_checkpoint(
'megatron.legacy.fp16_deprecated.loss_scaler'
]
sys.modules['megatron.model'] = sys.modules['megatron.legacy.model']
state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
sys.modules.pop('fp16.loss_scaler', None)
sys.modules.pop('megatron.fp16.loss_scaler', None)
sys.modules.pop('megatron.model', None)
@@ -1346,7 +1346,7 @@ def load_biencoder_checkpoint(model, only_query_model=False,
print('global rank {} is loading checkpoint {}'.format(
torch.distributed.get_rank(), checkpoint_name))

state_dict = torch.load(checkpoint_name, map_location='cpu')
state_dict = torch.load(checkpoint_name, map_location='cpu', weights_only=False)
ret_state_dict = state_dict['model']

if only_query_model:
2 changes: 1 addition & 1 deletion tasks/ensemble_classifier.py
@@ -13,7 +13,7 @@ def process_files(args):
for path in args.paths:
path = os.path.join(path, args.prediction_name)
try:
data = torch.load(path)
data = torch.load(path, weights_only=False)
for dataset in data:
name, d = dataset
predictions, labels, uid = d
2 changes: 1 addition & 1 deletion tasks/msdp/preprocessing.py
@@ -375,7 +375,7 @@ def prompt_selection_for_knowledge_generation(
print("> loading tokenizer and encoder")
tokenizer = DPRQuestionEncoderTokenizer.from_pretrained(
'facebook/dpr-question_encoder-single-nq-base')
encoder = torch.load(model_path).cuda()
encoder = torch.load(model_path, weights_only=False).cuda()

print("> getting dialog embeddings")
with torch.no_grad():
6 changes: 4 additions & 2 deletions tasks/vision/finetune_utils.py
@@ -241,13 +241,15 @@ def finetune(
elif args.pretrained_checkpoint_type == 'external':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
map_location="cpu",
weights_only=False)
unwrap_model[0].module.backbone.load_state_dict(state_dict,
strict=False)
elif args.pretrained_checkpoint_type == 'constrastive':
unwrap_model = utils.unwrap_model(model)
state_dict = torch.load(args.pretrained_checkpoint,
map_location="cpu")
map_location="cpu",
weights_only=False)
state_dict = state_dict["model"]
state_dict = {k.replace("teacher.backbone.", ""): v
for k, v in state_dict.items()
@@ -382,7 +382,7 @@ def test_te_grouped_linear_torch_native(self, tmp_path_dist_ckpt, ep_size):
torch.save(model.state_dict(), ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt")

# Load checkpoint
state_dict = torch.load(ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt")
state_dict = torch.load(ckpt_dir / f"model_ep{torch.distributed.get_rank()}.pt", weights_only=False)
model.load_state_dict(state_dict)

Utils.destroy_model_parallel()
2 changes: 1 addition & 1 deletion tests/unit_tests/dist_checkpointing/test_serialization.py
@@ -433,7 +433,7 @@ def test_sharded_object_serialization(self, tmp_path_dist_ckpt):
load_state_dict = load(state_dict, ckpt_dir)
assert 'other_key' in load_state_dict
load_state_dict['other_key'].seek(0)
loaded_state = torch.load(load_state_dict['other_key'])
loaded_state = torch.load(load_state_dict['other_key'], weights_only=False)

assert loaded_state == {'some': 'dict'}

4 changes: 2 additions & 2 deletions tools/checkpoint/hybrid_conversion.py
@@ -242,7 +242,7 @@ def main(args):

# load one of the model parallel ranks to get arguments
sample_model_file = os.path.join(input_model_dir, input_sub_models[0], "model_optim_rng.pt")
sample_model = torch.load(sample_model_file)
sample_model = torch.load(sample_model_file, weights_only=False)
print(f"Sample model {sample_model_file} is loaded.\n")

# input tensor and pipeline parallel size
Expand All @@ -261,7 +261,7 @@ def main(args):
dir_name += "_{:03d}".format(pp)
model_file = os.path.join(input_model_dir, dir_name, "model_optim_rng.pt")

tp_models.append(torch.load(model_file))
tp_models.append(torch.load(model_file, weights_only=False))
print(f"Model {model_file} is loaded.")

if input_tp_rank > 1:
4 changes: 2 additions & 2 deletions tools/checkpoint/loader_llama_mistral.py
@@ -152,11 +152,11 @@ def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
if num_shards == 1:
# Not sharded
# (The sharded implementation would also work, but this is simpler.)
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu")
loaded = torch.load(os.path.join(input_base_path, "consolidated.00.pth"), map_location="cpu", weights_only=False)
else:
# Sharded
loaded = [
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu")
torch.load(os.path.join(input_base_path, f"consolidated.{i:02d}.pth"), map_location="cpu", weights_only=False)
for i in range(num_shards)
]
param_count = 0