Skip to content

Commit

Permalink
Maintenance fixes (#3398)
Browse files Browse the repository at this point in the history
* TorchFix issues and attrgetter

* parameter name fix
  • Loading branch information
ordabayevy authored Sep 19, 2024

Verified

This commit was signed with the committer’s verified signature.
Robbepop Robin Freyler
1 parent e914e19 commit 88ae262
Showing 14 changed files with 55 additions and 83 deletions.
2 changes: 1 addition & 1 deletion pyro/distributions/omt_mvn.py
Original file line number Diff line number Diff line change
@@ -68,7 +68,7 @@ def backward(ctx, grad_output):
diff_L_ab = 0.5 * sum_leftmost(g_ja * epsilon_jb + g_R_inv * z_ja, -2)

Sigma_inv = torch.mm(R_inv, R_inv.t())
V, D, _ = torch.svd(Sigma_inv + jitter)
V, D, _ = torch.linalg.svd(Sigma_inv + jitter)
D_outer = D.unsqueeze(-1) + D.unsqueeze(0)

expand_tuple = tuple([-1] * (z.dim() - 1) + [dim, dim])
2 changes: 1 addition & 1 deletion pyro/distributions/transforms/householder.py
Original file line number Diff line number Diff line change
@@ -29,7 +29,7 @@ def __init__(self, u_unnormed=None):
# Construct normalized vectors for Householder transform
def u(self):
u_unnormed = self.u_unnormed() if callable(self.u_unnormed) else self.u_unnormed
norm = torch.norm(u_unnormed, p=2, dim=-1, keepdim=True)
norm = torch.linalg.norm(u_unnormed, ord=2, dim=-1, keepdim=True)
return torch.div(u_unnormed, norm)

def _call(self, x):
4 changes: 2 additions & 2 deletions pyro/distributions/transforms/sylvester.py
Original file line number Diff line number Diff line change
@@ -92,11 +92,11 @@ def Q(self, x):
u = self.u()
partial_Q = torch.eye(
self.input_dim, dtype=x.dtype, layout=x.layout, device=x.device
) - 2.0 * torch.ger(u[0], u[0])
) - 2.0 * torch.outer(u[0], u[0])

for idx in range(1, self.u_unnormed.size(-2)):
partial_Q = torch.matmul(
partial_Q, torch.eye(self.input_dim) - 2.0 * torch.ger(u[idx], u[idx])
partial_Q, torch.eye(self.input_dim) - 2.0 * torch.outer(u[idx], u[idx])
)

return partial_Q
17 changes: 9 additions & 8 deletions pyro/infer/autoguide/effect.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from operator import attrgetter
from typing import Callable, Optional, Tuple, Union

import torch
@@ -14,7 +15,7 @@
from pyro.poutine.runtime import get_plates

from .initialization import init_to_feasible, init_to_mean
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


class AutoMessengerMeta(type(GuideMessenger), type(PyroModule)):
@@ -175,8 +176,8 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
return loc, scale
except AttributeError:
pass
@@ -287,10 +288,10 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
if (self._hierarchical_sites is None) or (name in self._hierarchical_sites):
weight = deep_getattr(self.weights, name)
weight = attrgetter(name)(self.weights)
return loc, scale, weight
else:
return loc, scale
@@ -427,8 +428,8 @@ def get_posterior(

def _get_params(self, name: str, prior: Distribution):
try:
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
return loc, scale
except AttributeError:
pass
17 changes: 9 additions & 8 deletions pyro/infer/autoguide/gaussian.py
Original file line number Diff line number Diff line change
@@ -5,6 +5,7 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from operator import attrgetter
from types import SimpleNamespace
from typing import Callable, Dict, Optional, Set, Tuple, Union

@@ -23,7 +24,7 @@

from .guides import AutoGuide
from .initialization import InitMessenger, init_to_feasible
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


# Helper to dispatch to concrete subclasses of AutoGaussian, e.g.
@@ -287,8 +288,8 @@ def _transform_values(
for name, site in self._factors.items():
if site["is_observed"]:
continue
loc = deep_getattr(self.locs, name)
scale = deep_getattr(self.scales, name)
loc = attrgetter(name)(self.locs)
scale = attrgetter(name)(self.scales)
unconstrained = aux_values[name] * scale + loc

# Transform to constrained space.
@@ -335,7 +336,7 @@ def _setup_prototype(self, *args, **kwargs):
# Create sparse -> dense precision scatter indices.
self._dense_scatter = {}
for d, site in self._factors.items():
prec_sqrt_shape = deep_getattr(self.prec_sqrts, d).shape
prec_sqrt_shape = attrgetter(d)(self.prec_sqrts).shape
info_vec_shape = prec_sqrt_shape[:-1]
precision_shape = prec_sqrt_shape[:-1] + prec_sqrt_shape[-2:-1]
index1 = torch.zeros(info_vec_shape, dtype=torch.long)
@@ -425,8 +426,8 @@ def _dense_get_mvn(self):
flat_info_vec = torch.zeros(self._dense_size)
flat_precision = torch.zeros(self._dense_size**2)
for d, (index1, index2) in self._dense_scatter.items():
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
white_vec = attrgetter(d)(self.white_vecs)
prec_sqrt = attrgetter(d)(self.prec_sqrts)
info_vec = (prec_sqrt @ white_vec[..., None])[..., 0]
precision = prec_sqrt @ prec_sqrt.transpose(-1, -2)
flat_info_vec.scatter_add_(0, index1, info_vec.reshape(-1))
@@ -505,8 +506,8 @@ def _sample_aux_values(self, *, temperature: float) -> Dict[str, torch.Tensor]:
batch_shape = torch.Size(
p.size for p in sorted(self._plates[d], key=lambda p: p.dim)
)
white_vec = deep_getattr(self.white_vecs, d)
prec_sqrt = deep_getattr(self.prec_sqrts, d)
white_vec = attrgetter(d)(self.white_vecs)
prec_sqrt = attrgetter(d)(self.prec_sqrts)
factors[d] = funsor.gaussian.Gaussian(
white_vec=white_vec.reshape(batch_shape + white_vec.shape[-1:]),
prec_sqrt=prec_sqrt.reshape(batch_shape + prec_sqrt.shape[-2:]),
7 changes: 4 additions & 3 deletions pyro/infer/autoguide/guides.py
Original file line number Diff line number Diff line change
@@ -20,6 +20,7 @@ def model():
import warnings
import weakref
from contextlib import ExitStack
from operator import attrgetter

import torch
from torch import nn
@@ -38,7 +39,7 @@ def model():
from pyro.poutine.util import site_is_subsample

from .initialization import InitMessenger, init_to_feasible, init_to_median
from .utils import _product, deep_getattr, deep_setattr, helpful_support_errors
from .utils import _product, deep_setattr, helpful_support_errors


def prototype_hide_fn(msg):
@@ -491,8 +492,8 @@ def _setup_prototype(self, *args, **kwargs):
)

def _get_loc_and_scale(self, name):
site_loc = deep_getattr(self.locs, name)
site_scale = deep_getattr(self.scales, name)
site_loc = attrgetter(name)(self.locs)
site_scale = attrgetter(name)(self.scales)
return site_loc, site_scale

def forward(self, *args, **kwargs):
19 changes: 10 additions & 9 deletions pyro/infer/autoguide/structured.py
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@

from collections import OrderedDict, defaultdict
from contextlib import ExitStack
from operator import attrgetter
from types import SimpleNamespace
from typing import Callable, Dict, Optional, Union

@@ -19,7 +20,7 @@

from .guides import AutoGuide
from .initialization import InitMessenger, init_to_feasible
from .utils import deep_getattr, deep_setattr, helpful_support_errors
from .utils import deep_setattr, helpful_support_errors


def _config_auxiliary(msg):
@@ -274,11 +275,11 @@ def get_deltas(self, save_params=None):

# Sample zero-mean blockwise independent Delta/Normal/MVN.
log_density = 0.0
loc = deep_getattr(self.locs, name)
loc = attrgetter(name)(self.locs)
zero = torch.zeros_like(loc)
conditional = self.conditionals[name]
if callable(conditional):
aux_value = deep_getattr(self.conds, name)()
aux_value = attrgetter(name)(self.conds)()
elif conditional == "delta":
aux_value = zero
elif conditional == "normal":
@@ -287,7 +288,7 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = deep_getattr(self.scales, name)
scale = attrgetter(name)(self.scales)
aux_value = aux_value * scale
if compute_density:
log_density = (-scale.log()).expand_as(aux_value)
@@ -299,8 +300,8 @@ def get_deltas(self, save_params=None):
dist.Normal(zero, 1).to_event(1),
infer={"is_auxiliary": True},
)
scale = deep_getattr(self.scales, name)
scale_tril = deep_getattr(self.scale_trils, name)
scale = attrgetter(name)(self.scales)
scale_tril = attrgetter(name)(self.scale_trils)
aux_value = aux_value @ scale_tril.T * scale
if compute_density:
log_density = (
@@ -318,9 +319,9 @@ def get_deltas(self, save_params=None):
# Note: these shear transforms have no effect on the Jacobian
# determinant, and can therefore be excluded from the log_density
# computation below, even for nonlinear dep().
deps = deep_getattr(self.deps, name)
deps = attrgetter(name)(self.deps)
for upstream in self.dependencies.get(name, {}):
dep = deep_getattr(deps, upstream)
dep = attrgetter(upstream)(deps)
aux_value = aux_value + dep(aux_values[upstream])
aux_values[name] = aux_value

@@ -368,7 +369,7 @@ def forward(self, *args, **kwargs):
def median(self, *args, **kwargs):
result = {}
for name, site in self._sorted_sites:
loc = deep_getattr(self.locs, name).detach()
loc = attrgetter(name)(self.locs).detach()
shape = self._batch_shapes[name] + self._unconstrained_event_shapes[name]
loc = loc.reshape(shape)
result[name] = biject_to(site["fn"].support)(loc)
6 changes: 0 additions & 6 deletions pyro/infer/autoguide/utils.py
Original file line number Diff line number Diff line change
@@ -18,12 +18,6 @@ def _product(shape):
return result


def deep_getattr(obj, key):
for part in key.split("."):
obj = getattr(obj, part)
return obj


def deep_setattr(obj, key, val):
"""
Set an attribute `key` on the object. If any of the prefix attributes do
4 changes: 2 additions & 2 deletions pyro/ops/welford.py
Original file line number Diff line number Diff line change
@@ -33,7 +33,7 @@ def update(self, sample):
if self.diagonal:
self._m2 += delta_pre * delta_post
else:
self._m2 += torch.ger(delta_post, delta_pre)
self._m2 += torch.outer(delta_post, delta_pre)

def get_covariance(self, regularize=True):
if self.n_samples < 2:
@@ -72,7 +72,7 @@ def update(self, sample):
self._mean = self._mean + delta_pre / self.n_samples
delta_post = sample - self._mean
if self.head_size > 0:
self._m2_top = self._m2_top + torch.ger(
self._m2_top = self._m2_top + torch.outer(
delta_post[: self.head_size], delta_pre
)
else:
5 changes: 3 additions & 2 deletions pyro/primitives.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from collections import OrderedDict
from contextlib import ExitStack, contextmanager
from inspect import isclass
from operator import attrgetter
from typing import Callable, Iterator, Optional, Sequence, Union

import torch
@@ -28,7 +29,7 @@
effectful,
)
from pyro.poutine.subsample_messenger import SubsampleMessenger
from pyro.util import deep_getattr, set_rng_seed # noqa: F401
from pyro.util import set_rng_seed # noqa: F401


def get_param_store() -> ParamStoreDict:
@@ -493,7 +494,7 @@ def module(
mod_name = _name
if _name in target_state_dict.keys():
if not is_param:
deep_getattr(nn_module, mod_name)._parameters[param_name] = (
attrgetter(mod_name)(nn_module)._parameters[param_name] = (
target_state_dict[_name]
)
else:
9 changes: 0 additions & 9 deletions pyro/util.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright (c) 2017-2019 Uber Technologies, Inc.
# SPDX-License-Identifier: Apache-2.0

import functools
import math
import numbers
import random
@@ -704,14 +703,6 @@ def ignore_experimental_warning():
yield


def deep_getattr(obj: object, name: str) -> Any:
"""
Python getattr() for arbitrarily deep attributes
Throws an AttributeError if bad attribute
"""
return functools.reduce(getattr, name.split("."), obj)


class timed:
def __enter__(self, timer=timeit.default_timer):
self.start = timer()
16 changes: 12 additions & 4 deletions tests/infer/test_sampling.py
Original file line number Diff line number Diff line change
@@ -78,9 +78,13 @@ def test_importance_guide(self):
self.model, guide=self.guide, num_samples=5000
).run()
marginal = EmpiricalMarginal(posterior)
assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01)
assert_equal(
0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1
0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01
)
assert_equal(
0,
torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(),
prec=0.1,
)

@pytest.mark.init(rng_seed=0)
@@ -89,7 +93,11 @@ def test_importance_prior(self):
self.model, guide=None, num_samples=10000
).run()
marginal = EmpiricalMarginal(posterior)
assert_equal(0, torch.norm(marginal.mean - self.loc_mean).item(), prec=0.01)
assert_equal(
0, torch.norm(marginal.variance.sqrt() - self.loc_stddev).item(), prec=0.1
0, torch.linalg.norm(marginal.mean - self.loc_mean).item(), prec=0.01
)
assert_equal(
0,
torch.linalg.norm(marginal.variance.sqrt() - self.loc_stddev).item(),
prec=0.1,
)
28 changes: 1 addition & 27 deletions tests/ops/test_linalg.py
Original file line number Diff line number Diff line change
@@ -5,7 +5,7 @@
import torch

from pyro.ops.linalg import rinverse
from tests.common import assert_close, assert_equal
from tests.common import assert_equal


@pytest.mark.parametrize(
@@ -35,29 +35,3 @@ def test_sym_rinverse(A, use_sym):
batched_A = A.unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
expected_A = torch.inverse(A).unsqueeze(0).unsqueeze(0).expand(5, 4, d, d)
assert_equal(rinverse(batched_A, sym=use_sym), expected_A, prec=1e-8)


# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular
@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated")
@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"])
def test_triangular_solve(upper):
b = torch.randn(5, 6)
A = torch.randn(5, 5)
expected = torch.triangular_solve(b, A, upper=upper).solution
actual = torch.linalg.solve_triangular(A, b, upper=upper)
assert_close(actual, expected)
A = A.triu() if upper else A.tril()
assert_close(A @ actual, b)


# Tests migration from torch.triangular_solve -> torch.linalg.solve_triangular
@pytest.mark.filterwarnings("ignore:torch.triangular_solve is deprecated")
@pytest.mark.parametrize("upper", [False, True], ids=["lower", "upper"])
def test_triangular_solve_transpose(upper):
b = torch.randn(5, 6)
A = torch.randn(5, 5)
expected = torch.triangular_solve(b, A, upper=upper, transpose=True).solution
actual = torch.linalg.solve_triangular(A.T, b, upper=not upper)
assert_close(actual, expected)
A = A.triu() if upper else A.tril()
assert_close(A.T @ actual, b)
Loading

0 comments on commit 88ae262

Please sign in to comment.