Commit
add AutoSurrogateLikelihoodDAIS (#1434)
* initial commit of autosldais

* add usage example

* more docs

* tweak

* fix typo

* tweak docstring

* improve docstring

* clean-up test and add expose_types to block

* use expose_types

* fewer_steps

* lint
martinjankowiak authored Jun 21, 2022
1 parent 89ba76a commit 3f1a9d7
Showing 5 changed files with 290 additions and 1 deletion.
1 change: 1 addition & 0 deletions README.md
@@ -211,6 +211,7 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete latent variables
- [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#autodelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage.
- [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal) offer flexible variational distributions parameterized by normalizing flows.
- [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
- [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais) is a powerful variational inference algorithm that leverages HMC and supports data subsampling (see the sketch after this list).
- [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosemidais) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
- [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation.
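
The following minimal sketch shows how `AutoSurrogateLikelihoodDAIS` plugs into `SVI`. It mirrors the usage example added in this commit; the synthetic data, subset size, optimizer, and step count here are illustrative assumptions, not part of the commit:

```python
from functools import partial

import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoSurrogateLikelihoodDAIS
from numpyro.optim import Adam

def _model(X, Y):
    theta = numpyro.sample("theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1))
    with numpyro.plate("N", X.shape[0], subsample_size=10):
        X_batch = numpyro.subsample(X, event_dim=1)
        Y_batch = numpyro.subsample(Y, event_dim=0)
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

def _surrogate_model(X_surr, Y_surr):
    # same prior as `_model`; weighted likelihood on a fixed data subset
    theta = numpyro.sample("theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1))
    omegas = numpyro.param(
        "omegas", 5.0 * jnp.ones(X_surr.shape[0]), constraint=dist.constraints.positive
    )
    with numpyro.plate("N", X_surr.shape[0]), numpyro.handlers.scale(scale=omegas):
        numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

X = random.normal(random.PRNGKey(0), (100, 2))
Y = dist.Bernoulli(logits=X[:, 0]).sample(random.PRNGKey(1))

# bind the data with functools.partial, as in this commit's test, so that the
# model and the surrogate can see different data subsets
model = partial(_model, X, Y)
surrogate = partial(_surrogate_model, X[:20], Y[:20])
guide = AutoSurrogateLikelihoodDAIS(model, surrogate)
svi_result = SVI(model, guide, Adam(1e-3), Trace_ELBO()).run(random.PRNGKey(2), 2000)
samples = guide.sample_posterior(random.PRNGKey(3), svi_result.params)
```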

9 changes: 9 additions & 0 deletions docs/source/autoguide.rst
@@ -8,6 +8,7 @@ We provide a brief overview of the automatically generated guides available in NumPyro
* `AutoDelta <https://num.pyro.ai/en/latest/autoguide.html#autodelta>`_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here <https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101>`_ for example usage.
* `AutoBNAFNormal <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal>`_ and `AutoIAFNormal <https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal>`_ offer flexible variational distributions parameterized by normalizing flows.
* `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
* `AutoSurrogateLikelihoodDAIS <https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais>`_ is a powerful variational inference algorithm that leverages HMC and supports data subsampling.
* `AutoSemiDAIS <https://num.pyro.ai/en/latest/autoguide.html#autosemidais>`_ constructs a posterior approximation like `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
* `AutoLaplaceApproximation <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation>`_ can be used to compute a Laplace approximation.

@@ -108,3 +109,11 @@ AutoSemiDAIS
:undoc-members:
:show-inheritance:
:member-order: bysource

AutoSurrogateLikelihoodDAIS
---------------------------
.. autoclass:: numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
5 changes: 4 additions & 1 deletion numpyro/handlers.py
@@ -235,6 +235,7 @@ class block(Messenger):
:param callable hide_fn: function which when given a dictionary containing
site-level metadata returns whether it should be blocked.
:param list hide: list of site names to hide.
:param list expose_types: list of site types to expose, e.g. `['param']`.
**Example:**
@@ -259,11 +260,13 @@ class block(Messenger):
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None):
def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None):
if hide_fn is not None:
self.hide_fn = hide_fn
elif hide is not None:
self.hide_fn = lambda msg: msg.get("name") in hide
elif expose_types is not None:
self.hide_fn = lambda msg: msg.get("type") not in expose_types
else:
self.hide_fn = lambda msg: True
super(block, self).__init__(fn)
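
A quick sketch of how the new `expose_types` argument behaves, in the style of the doctest above (the model and site names are made up for illustration):

```python
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.handlers import block, seed, trace

def model():
    a = numpyro.param("a", jnp.array(1.0))            # a `param` site
    return numpyro.sample("b", dist.Normal(a, 1.0))   # a `sample` site

# expose_types=['param'] hides every site whose type is not 'param'
exposed = trace(block(seed(model, random.PRNGKey(0)), expose_types=["param"])).get_trace()
assert "a" in exposed
assert "b" not in exposed
```

This is exactly how `AutoSurrogateLikelihoodDAIS` uses `block` below: the surrogate model's sample sites are hidden from the guide's trace while its learnable parameters (the `omegas` weights) stay visible to SVI.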
203 changes: 203 additions & 0 deletions numpyro/infer/autoguide.py
@@ -64,6 +64,7 @@
"AutoIAFNormal",
"AutoDelta",
"AutoSemiDAIS",
"AutoSurrogateLikelihoodDAIS",
]


@@ -865,6 +866,208 @@ def _single_sample(_rng_key):
return _single_sample(rng_key)


class AutoSurrogateLikelihoodDAIS(AutoDAIS):
"""
This implementation of :class:`AutoSurrogateLikelihoodDAIS` provides a
mini-batchable family of variational distributions as described in [1].
It combines a user-provided surrogate likelihood with Differentiable Annealed
Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local
latent variables (see :class:`AutoSemiDAIS`), but unlike :class:`AutoDAIS`, it
*can* be used in conjunction with data subsampling.
**References:**
1. *Surrogate likelihoods for variational annealed importance sampling*,
Martin Jankowiak, Du Phan
2. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*,
Tomas Geffner, Justin Domke
3. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*,
Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
Usage::
# logistic regression model for data {X, Y}
def model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
)
with numpyro.plate("N", 100, subsample_size=10):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)
# surrogate model defined by prior and surrogate likelihood.
# a convenient choice for specifying the latter is to compute the likelihood on
# a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use
# handlers.scale to scale the log likelihood by a vector of learnable weights.
def surrogate_model(X_surr, Y_surr):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
)
omegas = numpyro.param(
"omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive
)
with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas):
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)
guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model)
svi = SVI(model, guide, ...)
:param callable model: A NumPyro model.
:param callable surrogate_model: A NumPyro model that is used as a surrogate model
for guiding the HMC dynamics that define the variational distribution. In particular
`surrogate_model` should contain the same prior as `model` but replace the exact
likelihood with a cheap-to-evaluate parametric ansatz. A simple ansatz for the latter
involves computing the likelihood for a fixed subset of the data and scaling the resulting
log likelihood by a learnable vector of positive weights. See the usage example above.
:param str prefix: A prefix that is prepended to the names of all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
:param str base_dist: Controls whether the base Normal variational distribution
is parameterized by a "diagonal" covariance matrix or a full-rank covariance
matrix parameterized by a lower-triangular "cholesky" factor. Defaults to "diagonal".
:param float eta_init: The initial value of the step size used in HMC. Defaults
to 0.01.
:param float eta_max: The maximum value of the learnable step size used in HMC.
Defaults to 0.1.
:param float gamma_init: The initial value of the learnable damping factor used
during partial momentum refreshments in HMC. Defaults to 0.9.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`init_strategy` section for available functions.
:param float init_scale: Initial scale for the standard deviation of
the base variational distribution for each (unconstrained transformed)
latent variable. Defaults to 0.1.
"""

def __init__(
self,
model,
surrogate_model,
*,
K=4,
eta_init=0.01,
eta_max=0.1,
gamma_init=0.9,
prefix="auto",
base_dist="diagonal",
init_loc_fn=init_to_uniform,
init_scale=0.1,
):
super().__init__(
model,
K=K,
eta_init=eta_init,
eta_max=eta_max,
gamma_init=gamma_init,
prefix=prefix,
init_loc_fn=init_loc_fn,
init_scale=init_scale,
base_dist=base_dist,
)

self.surrogate_model = surrogate_model

def _setup_prototype(self, *args, **kwargs):
AutoContinuous._setup_prototype(self, *args, **kwargs)

rng_key = numpyro.prng_key()

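# block everything: the surrogate model's sites set up here should not be
# visible to any enclosing handlers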
with numpyro.handlers.block():
(_, self._surrogate_potential_fn, _, _) = initialize_model(
rng_key,
self.surrogate_model,
init_strategy=self.init_loc_fn,
dynamic_args=False,
model_args=(),
model_kwargs={},
)

def _sample_latent(self, *args, **kwargs):
def blocked_surrogate_model(x):
x_unpack = self._unpack_latent(x)
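# hide the surrogate model's sample sites but keep its `param` sites
# (e.g. the learnable omegas) visible so that SVI optimizes them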
with numpyro.handlers.block(expose_types=["param"]):
return -self._surrogate_potential_fn(x_unpack)

eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
self.eta_init,
constraint=constraints.interval(0, self.eta_max),
)
eta_coeff = numpyro.param("{}_eta_coeff".format(self.prefix), 0.0)

gamma = numpyro.param(
"{}_gamma".format(self.prefix),
self.gamma_init,
constraint=constraints.interval(0, 1),
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones(self.K),
constraint=constraints.positive,
)
betas = jnp.cumsum(betas)
betas = betas / betas[-1] # K-dimensional with betas[-1] = 1

mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones(self.latent_dim),
constraint=constraints.positive,
)
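# the factor of 0.5 folds the leapfrog half-step position updates into the inverse mass matrix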
inv_mass_matrix = 0.5 / mass_matrix

init_z_loc = numpyro.param("{}_z_0_loc".format(self.prefix), self._init_latent)

if self.base_dist == "diagonal":
init_z_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
jnp.full(self.latent_dim, self._init_scale),
constraint=constraints.positive,
)
base_z_dist = dist.Normal(init_z_loc, init_z_scale).to_event()
else:
scale_tril = numpyro.param(
"{}_scale_tril".format(self.prefix),
jnp.identity(self.latent_dim) * self._init_scale,
constraint=constraints.scaled_unit_lower_cholesky,
)
base_z_dist = dist.MultivariateNormal(init_z_loc, scale_tril=scale_tril)

z_0 = numpyro.sample(
"{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True}
)

base_z_dist_log_prob = base_z_dist.log_prob

momentum_dist = dist.Normal(0, mass_matrix).to_event()
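# auxiliary noise for the K partial momentum refreshments; mask(False)
# removes its log density from the trace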
eps = numpyro.sample(
"{}_momentum".format(self.prefix),
momentum_dist.expand((self.K,)).to_event().mask(False),
infer={"is_auxiliary": True},
)

def scan_body(carry, eps_beta):
eps, beta = eps_beta
eta = eta0 + eta_coeff * beta
eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
z_prev, v_prev, log_factor = carry
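# half-step position update (the 1/2 lives in inv_mass_matrix)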
z_half = z_prev + v_prev * eta * inv_mass_matrix
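# gradient of the annealed target log density: (1 - beta) * log q_0(z) + beta * log p_surrogate(z)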
q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half)
p_grad = beta * grad(blocked_surrogate_model)(z_half)
v_hat = v_prev + eta * (q_grad + p_grad)
z = z_half + v_hat * eta * inv_mass_matrix
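# partial momentum refreshment: keep a fraction gamma of the momentum and inject fresh noise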
v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps
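# accumulate the incremental log importance weight from the change in kinetic energy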
delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat)
log_factor = log_factor + delta_ke
return (z, v, log_factor), None

v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1]
(z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas))

numpyro.factor("{}_factor".format(self.prefix), log_factor)

return z


def _subsample_model(model, *args, **kwargs):
data = kwargs.pop("_subsample_idx", {})
with handlers.substitute(data=data):
73 changes: 73 additions & 0 deletions test/infer/test_autoguide.py
@@ -41,6 +41,7 @@
AutoMultivariateNormal,
AutoNormal,
AutoSemiDAIS,
AutoSurrogateLikelihoodDAIS,
)
from numpyro.infer.initialization import (
init_to_feasible,
@@ -859,3 +860,75 @@ def model2():
svi = SVI(model2, guide, optim.Adam(0.01), Trace_ELBO())
with pytest.raises(RuntimeError, match="are no local variables"):
svi.run(random.PRNGKey(0), 10)


def test_autosldais(
N=64, subsample_size=48, num_surrogate=32, D=3, num_steps=40000, num_samples=2000
):
def _model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
with numpyro.plate("N", N, subsample_size=subsample_size):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

def _surrogate_model(X_surr, Y_surr):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
omegas = numpyro.param(
"omegas",
2.0 * jnp.ones(num_surrogate),
constraint=dist.constraints.positive,
)

with numpyro.plate("N", num_surrogate), numpyro.handlers.scale(scale=omegas):
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

X = RandomState(0).randn(N, D)
X[:, 2] = X[:, 0] + X[:, 1]
logits = X[:, 0] - 0.5 * X[:, 1]
Y = dist.Bernoulli(logits=logits).sample(random.PRNGKey(0))

model = partial(_model, X, Y)
surrogate_model = partial(_surrogate_model, X[:num_surrogate], Y[:num_surrogate])

def _get_optim():
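# note: optax's piecewise_constant_schedule treats the dict values as
# multiplicative scale factors applied at each boundary step, not as
# absolute learning rates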
scheduler = piecewise_constant_schedule(
1.0e-3, {15 * 1000: 1.0e-4, 30 * 1000: 1.0e-5}
)
return optax.chain(
optax.scale_by_adam(), optax.scale_by_schedule(scheduler), optax.scale(-1.0)
)

guide = AutoSurrogateLikelihoodDAIS(
model, surrogate_model, K=3, eta_max=0.25, eta_init=0.005
)
svi_result = SVI(model, guide, _get_optim(), Trace_ELBO()).run(
random.PRNGKey(1), num_steps
)

samples = guide.sample_posterior(random.PRNGKey(2), svi_result.params)
assert samples["theta"].shape == (D,)

dais_elbo = Trace_ELBO(num_particles=num_samples).loss(
random.PRNGKey(0), svi_result.params, model, guide
)
dais_elbo = -dais_elbo.item()

def create_plates():
return numpyro.plate("N", N, subsample_size=subsample_size)

mf_guide = AutoNormal(model, create_plates=create_plates)
mf_svi_result = SVI(model, mf_guide, _get_optim(), Trace_ELBO()).run(
random.PRNGKey(0), num_steps
)

mf_elbo = Trace_ELBO(num_particles=num_samples).loss(
random.PRNGKey(1), mf_svi_result.params, model, mf_guide
)
mf_elbo = -mf_elbo.item()

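# the DAIS guide should attain a strictly better ELBO than the mean-field baseline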
assert dais_elbo > mf_elbo + 0.1
