Skip to content

Commit

Permalink
Use weights_only for loading (#3427)
Browse files Browse the repository at this point in the history
Co-authored-by: Manuel Schmid <[email protected]>
  • Loading branch information
kit1980 and mashb1t authored Aug 3, 2024
1 parent 1a53e06 commit da3d4d0
Show file tree
Hide file tree
Showing 14 changed files with 21 additions and 21 deletions.
4 changes: 2 additions & 2 deletions extras/BLIP/models/blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def is_url(url_or_filename):
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
checkpoint = torch.load(cached_file, map_location='cpu', weights_only=True)
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
checkpoint = torch.load(url_or_filename, map_location='cpu', weights_only=True)
else:
raise RuntimeError('checkpoint url or path is invalid')

Expand Down
4 changes: 2 additions & 2 deletions extras/BLIP/models/blip_nlvr.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,9 @@ def blip_nlvr(pretrained='',**kwargs):
def load_checkpoint(model,url_or_filename):
if is_url(url_or_filename):
cached_file = download_cached_file(url_or_filename, check_hash=False, progress=True)
checkpoint = torch.load(cached_file, map_location='cpu')
checkpoint = torch.load(cached_file, map_location='cpu', weights_only=True)
elif os.path.isfile(url_or_filename):
checkpoint = torch.load(url_or_filename, map_location='cpu')
checkpoint = torch.load(url_or_filename, map_location='cpu', weights_only=True)
else:
raise RuntimeError('checkpoint url or path is invalid')
state_dict = checkpoint['model']
Expand Down
2 changes: 1 addition & 1 deletion extras/facexlib/detection/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def init_detection_model(model_name, half=False, device='cuda', model_rootpath=N
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)

# TODO: clean pretrained model
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
# remove unnecessary 'module.'
for k, v in deepcopy(load_net).items():
if k.startswith('module.'):
Expand Down
2 changes: 1 addition & 1 deletion extras/facexlib/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def init_parsing_model(model_name='bisenet', half=False, device='cuda', model_ro

model_path = load_file_from_url(
url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage)
load_net = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
model.load_state_dict(load_net, strict=True)
model.eval()
model = model.to(device)
Expand Down
2 changes: 1 addition & 1 deletion extras/ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
offload_device = torch.device('cpu')

use_fp16 = model_management.should_use_fp16(device=load_device)
ip_state_dict = torch.load(ip_adapter_path, map_location="cpu")
ip_state_dict = torch.load(ip_adapter_path, map_location="cpu", weights_only=True)
plus = "latents" in ip_state_dict["image_proj"]
cross_attention_dim = ip_state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[1]
sdxl = cross_attention_dim == 2048
Expand Down
2 changes: 1 addition & 1 deletion ldm_patched/ldm/modules/encoders/noise_aug_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, *args, clip_stats_path=None, timestep_dim=256, **kwargs):
if clip_stats_path is None:
clip_mean, clip_std = torch.zeros(timestep_dim), torch.ones(timestep_dim)
else:
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu")
clip_mean, clip_std = torch.load(clip_stats_path, map_location="cpu", weights_only=True)
self.register_buffer("data_mean", clip_mean[None, :], persistent=False)
self.register_buffer("data_std", clip_std[None, :], persistent=False)
self.time_embed = Timestep(timestep_dim)
Expand Down
2 changes: 1 addition & 1 deletion ldm_patched/modules/sd1_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,7 @@ def load_embed(embedding_name, embedding_directory, embedding_size, embed_key=No
except:
embed_out = safe_load_embed_zip(embed_path)
else:
embed = torch.load(embed_path, map_location="cpu")
embed = torch.load(embed_path, map_location="cpu", weights_only=True)
except Exception as e:
print(traceback.format_exc())
print()
Expand Down
6 changes: 3 additions & 3 deletions ldm_patched/pfn/architecture/face/codeformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,15 +377,15 @@ def __init__(
)

if model_path is not None:
chkpt = torch.load(model_path, map_location="cpu")
chkpt = torch.load(model_path, map_location="cpu", weights_only=True)
if "params_ema" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params_ema"]
torch.load(model_path, map_location="cpu", weights_only=True)["params_ema"]
)
logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
elif "params" in chkpt:
self.load_state_dict(
torch.load(model_path, map_location="cpu")["params"]
torch.load(model_path, map_location="cpu", weights_only=True)["params"]
)
logger.info(f"vqgan is loaded from: {model_path} [params]")
else:
Expand Down
4 changes: 2 additions & 2 deletions ldm_patched/pfn/architecture/face/gfpgan_bilinear_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def __init__(
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
decoder_load_path, map_location=lambda storage, loc: storage,
weights_only=True)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
Expand Down
4 changes: 2 additions & 2 deletions ldm_patched/pfn/architecture/face/gfpganv1_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,8 +373,8 @@ def __init__(
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
decoder_load_path, map_location=lambda storage, loc: storage,
weights_only=True)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
Expand Down
4 changes: 2 additions & 2 deletions ldm_patched/pfn/architecture/face/gfpganv1_clean_arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,8 +284,8 @@ def __init__(
if decoder_load_path:
self.stylegan_decoder.load_state_dict(
torch.load(
decoder_load_path, map_location=lambda storage, loc: storage
)["params_ema"]
decoder_load_path, map_location=lambda storage, loc: storage,
weights_only=True)["params_ema"]
)
# fix decoder without updating params
if fix_decoder:
Expand Down
2 changes: 1 addition & 1 deletion modules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def get_previewer(model):
if vae_approx_filename in VAE_approx_models:
VAE_approx_model = VAE_approx_models[vae_approx_filename]
else:
sd = torch.load(vae_approx_filename, map_location='cpu')
sd = torch.load(vae_approx_filename, map_location='cpu', weights_only=True)
VAE_approx_model = VAEApprox()
VAE_approx_model.load_state_dict(sd)
del sd
Expand Down
2 changes: 1 addition & 1 deletion modules/inpaint_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def patch(self, inpaint_head_model_path, inpaint_latent, inpaint_latent_mask, mo

if inpaint_head_model is None:
inpaint_head_model = InpaintHead()
sd = torch.load(inpaint_head_model_path, map_location='cpu')
sd = torch.load(inpaint_head_model_path, map_location='cpu', weights_only=True)
inpaint_head_model.load_state_dict(sd)

feed = torch.cat([
Expand Down
2 changes: 1 addition & 1 deletion modules/upscaler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def perform_upscale(img):

if model is None:
model_filename = downloading_upscale_model()
sd = torch.load(model_filename)
sd = torch.load(model_filename, weights_only=True)
sdo = OrderedDict()
for k, v in sd.items():
sdo[k.replace('residual_block_', 'RDB')] = v
Expand Down

0 comments on commit da3d4d0

Please sign in to comment.