add AutoSurrogateLikelihoodDAIS #1434

Merged: 11 commits, Jun 21, 2022
Changes from 7 commits
1 change: 1 addition & 0 deletions README.md
@@ -211,6 +211,7 @@ Like HMC/NUTS, all remaining MCMC algorithms support enumeration over discrete l
- [AutoDelta](https://num.pyro.ai/en/latest/autoguide.html#autodelta) is used for computing point estimates via MAP (maximum a posteriori estimation). See [here](https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101) for example usage.
- [AutoBNAFNormal](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal) and [AutoIAFNormal](https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal) offer flexible variational distributions parameterized by normalizing flows.
- [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
- [AutoSurrogateLikelihoodDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais) is a powerful variational inference algorithm that leverages HMC and supports data subsampling.
- [AutoSemiDAIS](https://num.pyro.ai/en/latest/autoguide.html#autosemidais) constructs a posterior approximation like [AutoDAIS](https://num.pyro.ai/en/latest/autoguide.html#autodais) for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
- [AutoLaplaceApproximation](https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation) can be used to compute a Laplace approximation.

9 changes: 9 additions & 0 deletions docs/source/autoguide.rst
@@ -8,6 +8,7 @@ We provide a brief overview of the automatically generated guides available in N
* `AutoDelta <https://num.pyro.ai/en/latest/autoguide.html#autodelta>`_ is used for computing point estimates via MAP (maximum a posteriori estimation). See `here <https://github.com/pyro-ppl/numpyro/blob/bbe1f879eede79eebfdd16dfc49c77c4d1fc727c/examples/zero_inflated_poisson.py#L101>`_ for example usage.
* `AutoBNAFNormal <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoBNAFNormal>`_ and `AutoIAFNormal <https://num.pyro.ai/en/latest/autoguide.html#autoiafnormal>`_ offer flexible variational distributions parameterized by normalizing flows.
* `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ is a powerful variational inference algorithm that leverages HMC. It can be a good choice for dealing with highly correlated posteriors but may be computationally expensive depending on the nature of the model.
* `AutoSurrogateLikelihoodDAIS <https://num.pyro.ai/en/latest/autoguide.html#autosurrogatelikelihooddais>`_ is a powerful variational inference algorithm that leverages HMC and supports data subsampling.
* `AutoSemiDAIS <https://num.pyro.ai/en/latest/autoguide.html#autosemidais>`_ constructs a posterior approximation like `AutoDAIS <https://num.pyro.ai/en/latest/autoguide.html#autodais>`_ for local latent variables but provides support for data subsampling during ELBO training by utilizing a parametric guide for global latent variables.
* `AutoLaplaceApproximation <https://num.pyro.ai/en/latest/autoguide.html#numpyro.infer.autoguide.AutoLaplaceApproximation>`_ can be used to compute a Laplace approximation.

@@ -108,3 +109,11 @@ AutoSemiDAIS
:undoc-members:
:show-inheritance:
:member-order: bysource

AutoSurrogateLikelihoodDAIS
---------------------------
.. autoclass:: numpyro.infer.autoguide.AutoSurrogateLikelihoodDAIS
:members:
:undoc-members:
:show-inheritance:
:member-order: bysource
203 changes: 203 additions & 0 deletions numpyro/infer/autoguide.py
@@ -64,6 +64,7 @@
"AutoIAFNormal",
"AutoDelta",
"AutoSemiDAIS",
"AutoSurrogateLikelihoodDAIS",
]


@@ -865,6 +866,208 @@ def _single_sample(_rng_key):
return _single_sample(rng_key)


class AutoSurrogateLikelihoodDAIS(AutoDAIS):
"""
This implementation of :class:`AutoSurrogateLikelihoodDAIS` provides a
mini-batchable family of variational distributions as described in [1].
It combines a user-provided surrogate likelihood with Differentiable Annealed
Importance Sampling (DAIS) [2, 3]. It is not applicable to models with local
latent variables (see :class:`AutoSemiDAIS`), but unlike :class:`AutoDAIS`, it
*can* be used in conjunction with data subsampling.

    **References:**

1. *Surrogate likelihoods for variational annealed importance sampling*,
Martin Jankowiak, Du Phan
2. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*,
Tomas Geffner, Justin Domke
3. *Differentiable Annealed Importance Sampling and the Perils of Gradient Noise*,
Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse

Usage::

# logistic regression model for data {X, Y}
def model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
)
with numpyro.plate("N", 100, subsample_size=10):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

# surrogate model defined by prior and surrogate likelihood.
# a convenient choice for specifying the latter is to compute the likelihood on
# a randomly chosen data subset (here {X_surr, Y_surr} of size 20) and then use
# handlers.scale to scale the log likelihood by a vector of learnable weights.
def surrogate_model(X_surr, Y_surr):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(2), jnp.ones(2)).to_event(1)
)
omegas = numpyro.param(
"omegas", 5.0 * jnp.ones(20), constraint=dist.constraints.positive
)
with numpyro.plate("N", 20), numpyro.handlers.scale(scale=omegas):
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

guide = AutoSurrogateLikelihoodDAIS(model, surrogate_model)
svi = SVI(model, guide, ...)

:param callable model: A NumPyro model.
:param callable surrogate_model: A NumPyro model that is used as a surrogate model
for guiding the HMC dynamics that define the variational distribution. In particular
`surrogate_model` should contain the same prior as `model` but should contain a
cheap-to-evaluate parametric ansatz for the likelihood. A simple ansatz for the latter
involves computing the likelihood for a fixed subset of the data and scaling the resulting
log likelihood by a learnable vector of positive weights. See the usage example above.
    :param str prefix: A prefix that is prepended to the names of all internal sites.
:param int K: A positive integer that controls the number of HMC steps used.
Defaults to 4.
:param str base_dist: Controls whether the base Normal variational distribution
is parameterized by a "diagonal" covariance matrix or a full-rank covariance
matrix parameterized by a lower-triangular "cholesky" factor. Defaults to "diagonal".
:param float eta_init: The initial value of the step size used in HMC. Defaults
to 0.01.
:param float eta_max: The maximum value of the learnable step size used in HMC.
Defaults to 0.1.
:param float gamma_init: The initial value of the learnable damping factor used
during partial momentum refreshments in HMC. Defaults to 0.9.
:param callable init_loc_fn: A per-site initialization function.
See :ref:`init_strategy` section for available functions.
:param float init_scale: Initial scale for the standard deviation of
the base variational distribution for each (unconstrained transformed)
latent variable. Defaults to 0.1.
"""

def __init__(
self,
model,
surrogate_model,
*,
K=4,
eta_init=0.01,
eta_max=0.1,
gamma_init=0.9,
prefix="auto",
base_dist="diagonal",
init_loc_fn=init_to_uniform,
init_scale=0.1,
):
super().__init__(
model,
K=K,
eta_init=eta_init,
eta_max=eta_max,
gamma_init=gamma_init,
prefix=prefix,
init_loc_fn=init_loc_fn,
init_scale=init_scale,
base_dist=base_dist,
)

self.surrogate_model = surrogate_model

def _setup_prototype(self, *args, **kwargs):
AutoContinuous._setup_prototype(self, *args, **kwargs)

rng_key = numpyro.prng_key()

with numpyro.handlers.block():
(_, self._surrogate_potential_fn, _, _) = initialize_model(
rng_key,
self.surrogate_model,
init_strategy=self.init_loc_fn,
dynamic_args=False,
model_args=(),
model_kwargs={},
)

def _sample_latent(self, *args, **kwargs):
def blocked_surrogate_model(x):
x_unpack = self._unpack_latent(x)
with numpyro.handlers.block(hide_fn=lambda site: site["type"] != "param"):
return -self._surrogate_potential_fn(x_unpack)

eta0 = numpyro.param(
"{}_eta0".format(self.prefix),
self.eta_init,
constraint=constraints.interval(0, self.eta_max),
)
eta_coeff = numpyro.param("{}_eta_coeff".format(self.prefix), 0.0)

gamma = numpyro.param(
"{}_gamma".format(self.prefix),
self.gamma_init,
constraint=constraints.interval(0, 1),
)
betas = numpyro.param(
"{}_beta_increments".format(self.prefix),
jnp.ones(self.K),
constraint=constraints.positive,
)
betas = jnp.cumsum(betas)
betas = betas / betas[-1] # K-dimensional with betas[-1] = 1

mass_matrix = numpyro.param(
"{}_mass_matrix".format(self.prefix),
jnp.ones(self.latent_dim),
constraint=constraints.positive,
)
inv_mass_matrix = 0.5 / mass_matrix

init_z_loc = numpyro.param("{}_z_0_loc".format(self.prefix), self._init_latent)

if self.base_dist == "diagonal":
init_z_scale = numpyro.param(
"{}_z_0_scale".format(self.prefix),
jnp.full(self.latent_dim, self._init_scale),
constraint=constraints.positive,
)
base_z_dist = dist.Normal(init_z_loc, init_z_scale).to_event()
else:
scale_tril = numpyro.param(
"{}_scale_tril".format(self.prefix),
jnp.identity(self.latent_dim) * self._init_scale,
constraint=constraints.scaled_unit_lower_cholesky,
)
base_z_dist = dist.MultivariateNormal(init_z_loc, scale_tril=scale_tril)

z_0 = numpyro.sample(
"{}_z_0".format(self.prefix), base_z_dist, infer={"is_auxiliary": True}
)

base_z_dist_log_prob = base_z_dist.log_prob

momentum_dist = dist.Normal(0, mass_matrix).to_event()
eps = numpyro.sample(
"{}_momentum".format(self.prefix),
momentum_dist.expand((self.K,)).to_event().mask(False),
infer={"is_auxiliary": True},
)

def scan_body(carry, eps_beta):
eps, beta = eps_beta
eta = eta0 + eta_coeff * beta
eta = jnp.clip(eta, a_min=0.0, a_max=self.eta_max)
z_prev, v_prev, log_factor = carry
z_half = z_prev + v_prev * eta * inv_mass_matrix
q_grad = (1.0 - beta) * grad(base_z_dist_log_prob)(z_half)
p_grad = beta * grad(blocked_surrogate_model)(z_half)
v_hat = v_prev + eta * (q_grad + p_grad)
z = z_half + v_hat * eta * inv_mass_matrix
v = gamma * v_hat + jnp.sqrt(1 - gamma**2) * eps
delta_ke = momentum_dist.log_prob(v_prev) - momentum_dist.log_prob(v_hat)
log_factor = log_factor + delta_ke
return (z, v, log_factor), None

v_0 = eps[-1] # note the return value of scan doesn't depend on eps[-1]
(z, _, log_factor), _ = jax.lax.scan(scan_body, (z_0, v_0, 0.0), (eps, betas))

numpyro.factor("{}_factor".format(self.prefix), log_factor)

return z


def _subsample_model(model, *args, **kwargs):
data = kwargs.pop("_subsample_idx", {})
with handlers.substitute(data=data):
71 changes: 71 additions & 0 deletions test/infer/test_autoguide.py
@@ -41,6 +41,7 @@
AutoMultivariateNormal,
AutoNormal,
AutoSemiDAIS,
AutoSurrogateLikelihoodDAIS,
)
from numpyro.infer.initialization import (
init_to_feasible,
@@ -859,3 +860,73 @@ def model2():
svi = SVI(model2, guide, optim.Adam(0.01), Trace_ELBO())
with pytest.raises(RuntimeError, match="are no local variables"):
svi.run(random.PRNGKey(0), 10)


def test_autosldais(N=64, D=3, num_steps=45000, num_samples=2000):
def _model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
with numpyro.plate("N", N, subsample_size=2 * N // 3):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

def _surrogate_model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
omegas = numpyro.param(
"omegas", 2.0 * jnp.ones(N // 2), constraint=dist.constraints.positive
)

with numpyro.plate("N", N // 2), numpyro.handlers.scale(scale=omegas):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

X = RandomState(0).randn(N, D)
X[:, 2] = X[:, 0] + X[:, 1]
logits = X[:, 0] - 0.5 * X[:, 1]
Y = dist.Bernoulli(logits=logits).sample(random.PRNGKey(0))

model = partial(_model, X, Y)
surrogate_model = partial(_surrogate_model, X[::2], Y[::2])

def _get_optim():
scheduler = piecewise_constant_schedule(
1.0e-3, {15 * 1000: 1.0e-4, 30 * 1000: 1.0e-5}
)
return optax.chain(
optax.scale_by_adam(), optax.scale_by_schedule(scheduler), optax.scale(-1.0)
)

guide = AutoSurrogateLikelihoodDAIS(
model, surrogate_model, K=3, eta_max=0.25, eta_init=0.005
)
svi_result = SVI(model, guide, _get_optim(), Trace_ELBO()).run(
random.PRNGKey(1), num_steps
)

samples = guide.sample_posterior(random.PRNGKey(2), svi_result.params)
assert samples["theta"].shape == (D,)

dais_elbo = Trace_ELBO(num_particles=num_samples).loss(
random.PRNGKey(0), svi_result.params, model, guide
)
dais_elbo = -dais_elbo.item()

def create_plates():
return numpyro.plate("N", N, subsample_size=2 * N // 3)

mf_guide = AutoNormal(model, create_plates=create_plates)
mf_svi_result = SVI(model, mf_guide, _get_optim(), Trace_ELBO()).run(
random.PRNGKey(0), num_steps
)

mf_elbo = Trace_ELBO(num_particles=num_samples).loss(
random.PRNGKey(0), mf_svi_result.params, model, mf_guide
)
mf_elbo = -mf_elbo.item()

assert dais_elbo > mf_elbo + 0.1