diff --git a/opacus/utils/module_utils.py b/opacus/utils/module_utils.py index 28146cef..ee0437a9 100644 --- a/opacus/utils/module_utils.py +++ b/opacus/utils/module_utils.py @@ -89,6 +89,8 @@ def clone_module(module: nn.Module) -> nn.Module: """ Handy utility to clone an nn.Module. PyTorch doesn't always support copy.deepcopy(), so it is just easier to serialize the model to a BytesIO and read it from there. + When ``weights_only=False``, ``torch.load()`` uses "pickle" module implicity, which is known to be insecure. + Only load the model you trust. Args: module: The module to clone @@ -99,7 +101,7 @@ def clone_module(module: nn.Module) -> nn.Module: with io.BytesIO() as bytesio: torch.save(module, bytesio) bytesio.seek(0) - module_copy = torch.load(bytesio) + module_copy = torch.load(bytesio, weights_only=False) next_param = next( module.parameters(), None ) # Eg, InstanceNorm with affine=False has no params