Skip to content

Commit

Permalink
fix: make torch.load safer with weights_only=True everywhere possible (...)
Browse files Browse the repository at this point in the history
  • Loading branch information
joanise authored Jan 16, 2025
1 parent 7d1403a commit 8a56acc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 9 deletions.
17 changes: 11 additions & 6 deletions hfgl/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,9 @@ def export(
from .utils import sizeof_fmt

orig_size = sizeof_fmt(os.path.getsize(model_path))
vocoder_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
vocoder_ckpt = torch.load(
model_path, map_location=torch.device("cpu"), weights_only=True
)
HiFiGAN.convert_ckpt_to_generator(vocoder_ckpt)
torch.save(vocoder_ckpt, output_path)
new_size = sizeof_fmt(os.path.getsize(output_path))
Expand Down Expand Up @@ -165,17 +167,20 @@ def synthesize(
from .utils import load_hifigan_from_checkpoint, synthesize_data

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
checkpoint = torch.load(generator_path, map_location=device)
data = torch.load(data_path, map_location=device)
checkpoint = torch.load(generator_path, map_location=device, weights_only=True)
# TODO figure out if we can convert our prepared data format to use weights only
data = torch.load(data_path, map_location=device, weights_only=False)
if time_oriented:
data = data.transpose(0, 1)
data_size = data.size()
config_n_mels = checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"]["n_mels"]
if (config_n_mels not in data_size):
config_n_mels = checkpoint["hyper_parameters"]["config"]["preprocessing"]["audio"][
"n_mels"
]
if config_n_mels not in data_size:
raise ValueError(
f"Your model expects a spectrogram of dimensions [K (Mel bands), T (frames)] where K == {config_n_mels} but you provided a tensor of size {data_size}"
)
if (data_size[0] != config_n_mels):
if data_size[0] != config_n_mels:
raise ValueError(
f"We expected the first dimension of your Mel spectrogram to correspond with the number of Mel bands declared by your model ({config_n_mels}). Instead, we found you model has the dimensions {data_size}. If your spectrogram is time-oriented, please re-run this command with the '--time-oriented' flag."
)
Expand Down
9 changes: 6 additions & 3 deletions hfgl/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ def __getitem__(self, index):
language,
f"spec-{self.output_sampling_rate}-{self.config.preprocessing.audio.spec_type}.pt",
]
)
),
weights_only=True,
) # [mel_bins, frames]
if self.finetune:
# If finetuning, use the synthesized spectral features
Expand All @@ -81,7 +82,8 @@ def __getitem__(self, index):
language,
f"spec-pred-{self.input_sampling_rate}-{self.config.preprocessing.audio.spec_type}.pt",
]
)
),
weights_only=True,
).transpose(0, 1)
else:
x = torch.load(
Expand All @@ -94,7 +96,8 @@ def __getitem__(self, index):
language,
f"spec-{self.input_sampling_rate}-{self.config.preprocessing.audio.spec_type}.pt",
]
)
),
weights_only=True,
) # [mel_bins, frames]
if self.use_segments:
x, y, y_mel = get_all_segments(
Expand Down

0 comments on commit 8a56acc

Please sign in to comment.