diff --git a/pyro/distributions/torch_patch.py b/pyro/distributions/torch_patch.py
index 5e5365a970..94a1f33b68 100644
--- a/pyro/distributions/torch_patch.py
+++ b/pyro/distributions/torch_patch.py
@@ -3,31 +3,40 @@
 
 import functools
 import math
+import warnings
 import weakref
 
 import torch
 
 
 def patch_dependency(target, root_module=torch):
-    parts = target.split(".")
-    assert parts[0] == root_module.__name__
-    module = root_module
-    for part in parts[1:-1]:
-        module = getattr(module, part)
-    name = parts[-1]
-    old_fn = getattr(module, name, None)
-    old_fn = getattr(old_fn, "_pyro_unpatched", old_fn)  # ensure patching is idempotent
-
-    def decorator(new_fn):
-        try:
-            functools.update_wrapper(new_fn, old_fn)
-        except Exception:
-            for attr in functools.WRAPPER_ASSIGNMENTS:
-                if hasattr(old_fn, attr):
-                    setattr(new_fn, attr, getattr(old_fn, attr))
-        new_fn._pyro_unpatched = old_fn
-        setattr(module, name, new_fn)
-        return new_fn
+    try:
+        parts = target.split(".")
+        assert parts[0] == root_module.__name__
+        module = root_module
+        for part in parts[1:-1]:
+            module = getattr(module, part)
+        name = parts[-1]
+        old_fn = getattr(module, name, None)
+        # Ensure patching is idempotent.
+        old_fn = getattr(old_fn, "_pyro_unpatched", old_fn)
+
+        def decorator(new_fn):
+            try:
+                functools.update_wrapper(new_fn, old_fn)
+            except Exception:
+                for attr in functools.WRAPPER_ASSIGNMENTS:
+                    if hasattr(old_fn, attr):
+                        setattr(new_fn, attr, getattr(old_fn, attr))
+            new_fn._pyro_unpatched = old_fn
+            setattr(module, name, new_fn)
+            return new_fn
+
+    except AttributeError:
+        warnings.warn(f"pyro patch_dependency target is stale: {target}")
+
+        def decorator(new_fn):
+            return new_fn
 
     return decorator
 
diff --git a/pyro/optim/pytorch_optimizers.py b/pyro/optim/pytorch_optimizers.py
index 05dfa4dbeb..42412a233d 100644
--- a/pyro/optim/pytorch_optimizers.py
+++ b/pyro/optim/pytorch_optimizers.py
@@ -36,9 +36,9 @@
 # Load all schedulers from PyTorch
 # breaking change in torch >= 1.14: LRScheduler is new base class
 if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
-    _torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler
+    _torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler  # type: ignore
 else:  # for torch < 1.13, _LRScheduler is base class
-    _torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler
+    _torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler  # type: ignore
 
 for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
     if not isinstance(_Optim, type):