Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make patch_dependency safer under dependency drift #3204

Merged
merged 1 commit into from
May 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 28 additions & 19 deletions pyro/distributions/torch_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,31 +3,40 @@

import functools
import math
import warnings
import weakref

import torch


def patch_dependency(target, root_module=torch):
parts = target.split(".")
assert parts[0] == root_module.__name__
module = root_module
for part in parts[1:-1]:
module = getattr(module, part)
name = parts[-1]
old_fn = getattr(module, name, None)
old_fn = getattr(old_fn, "_pyro_unpatched", old_fn) # ensure patching is idempotent

def decorator(new_fn):
try:
functools.update_wrapper(new_fn, old_fn)
except Exception:
for attr in functools.WRAPPER_ASSIGNMENTS:
if hasattr(old_fn, attr):
setattr(new_fn, attr, getattr(old_fn, attr))
new_fn._pyro_unpatched = old_fn
setattr(module, name, new_fn)
return new_fn
try:
parts = target.split(".")
assert parts[0] == root_module.__name__
module = root_module
for part in parts[1:-1]:
module = getattr(module, part)
name = parts[-1]
old_fn = getattr(module, name, None)
# Ensure patching is idempotent.
old_fn = getattr(old_fn, "_pyro_unpatched", old_fn)

def decorator(new_fn):
try:
functools.update_wrapper(new_fn, old_fn)
except Exception:
for attr in functools.WRAPPER_ASSIGNMENTS:
if hasattr(old_fn, attr):
setattr(new_fn, attr, getattr(old_fn, attr))
new_fn._pyro_unpatched = old_fn
setattr(module, name, new_fn)
return new_fn

except AttributeError:
warnings.warn(f"pyro patch_dependency target is stale: {target}")

def decorator(new_fn):
return new_fn

return decorator

Expand Down
4 changes: 2 additions & 2 deletions pyro/optim/pytorch_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
# Load all schedulers from PyTorch
# breaking change in torch >= 1.14: LRScheduler is new base class
if hasattr(torch.optim.lr_scheduler, "LRScheduler"):
_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler
_torch_scheduler_base = torch.optim.lr_scheduler.LRScheduler # type: ignore
else: # for torch < 1.13, _LRScheduler is base class
_torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler
_torch_scheduler_base = torch.optim.lr_scheduler._LRScheduler # type: ignore

for _name, _Optim in torch.optim.lr_scheduler.__dict__.items():
if not isinstance(_Optim, type):
Expand Down