Skip to content

Commit

Permalink
Set weights_only=True when using torch.load() (vllm-project#12366)
Browse files Browse the repository at this point in the history
Signed-off-by: Russell Bryant <[email protected]>
  • Loading branch information
russellb authored and LucasWilkinson committed Jan 24, 2025
1 parent cb32ece commit 97003fe
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 6 deletions.
2 changes: 1 addition & 1 deletion vllm/assets/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def image_embeds(self) -> torch.Tensor:
"""
image_path = get_vllm_public_assets(filename=f"{self.name}.pt",
s3_prefix=VLM_IMAGES_DIR)
return torch.load(image_path, map_location="cpu")
return torch.load(image_path, map_location="cpu", weights_only=True)
3 changes: 2 additions & 1 deletion vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ def from_local_checkpoint(
new_embeddings_tensor_path)
elif os.path.isfile(new_embeddings_bin_file_path):
embeddings = torch.load(new_embeddings_bin_file_path,
map_location=device)
map_location=device,
weights_only=True)

return cls.from_lora_tensors(
lora_model_id=get_lora_id()
Expand Down
8 changes: 5 additions & 3 deletions vllm/model_executor/model_loader/weight_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def convert_bin_to_safetensor_file(
pt_filename: str,
sf_filename: str,
) -> None:
loaded = torch.load(pt_filename, map_location="cpu")
loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
if "state_dict" in loaded:
loaded = loaded["state_dict"]
shared = _shared_pointers(loaded)
Expand Down Expand Up @@ -381,7 +381,9 @@ def np_cache_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
state = torch.load(bin_file,
map_location="cpu",
weights_only=True)
for name, param in state.items():
param_path = os.path.join(np_folder, name)
with open(param_path, "wb") as f:
Expand Down Expand Up @@ -447,7 +449,7 @@ def pt_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu")
state = torch.load(bin_file, map_location="cpu", weights_only=True)
yield from state.items()
del state
torch.cuda.empty_cache()
Expand Down
3 changes: 2 additions & 1 deletion vllm/prompt_adapter/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ def load_peft_weights(model_id: str,
adapters_weights = safe_load_file(filename, device=device)
else:
adapters_weights = torch.load(filename,
map_location=torch.device(device))
map_location=torch.device(device),
weights_only=True)

return adapters_weights

0 comments on commit 97003fe

Please sign in to comment.