From 5ce9fa4f16cea403e51eff92d46b788d57cc076e Mon Sep 17 00:00:00 2001 From: pseeth Date: Wed, 31 Aug 2022 14:07:43 -0700 Subject: [PATCH] Adding kwargs. --- audiotools/__init__.py | 2 +- audiotools/ml/layers/base.py | 3 ++- setup.py | 2 +- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/audiotools/__init__.py b/audiotools/__init__.py index 1df5af1e..779e2b29 100644 --- a/audiotools/__init__.py +++ b/audiotools/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.4.1" +__version__ = "0.4.2" from .core import AudioSignal, STFTParams, Meter, util from . import metrics from . import data diff --git a/audiotools/ml/layers/base.py b/audiotools/ml/layers/base.py index 196ea4f6..7f4d24f1 100644 --- a/audiotools/ml/layers/base.py +++ b/audiotools/ml/layers/base.py @@ -148,6 +148,7 @@ def load_from_folder( folder: Path, package: bool = True, strict: bool = False, + **kwargs, ): folder = Path(folder) / cls.__name__.lower() model_pth = "package.pth" if package else "weights.pth" @@ -158,6 +159,6 @@ def load_from_folder( excluded = ["package.pth", "weights.pth"] files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded] for f in files: - extra_data[f.name] = torch.load(folder / f) + extra_data[f.name] = torch.load(folder / f, **kwargs) return model, extra_data diff --git a/setup.py b/setup.py index bb410534..d0091fe8 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ setup( name="audiotools", - version="0.4.1", + version="0.4.2", classifiers=[ "Intended Audience :: Developers", "Intended Audience :: Education",