Skip to content

Commit

Permalink
clean-up test and add expose_types to block
Browse files Browse the repository at this point in the history
  • Loading branch information
martinjankowiak committed Jun 21, 2022
1 parent a726080 commit a488923
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 12 deletions.
5 changes: 4 additions & 1 deletion numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ class block(Messenger):
:param callable hide_fn: function which when given a dictionary containing
site-level metadata returns whether it should be blocked.
:param list hide: list of site names to hide.
:param list expose_types: list of site types to expose, e.g. `['param']`.
**Example:**
Expand All @@ -259,11 +260,13 @@ class block(Messenger):
>>> assert 'b' in trace_block_a
"""

def __init__(self, fn=None, hide_fn=None, hide=None):
def __init__(self, fn=None, hide_fn=None, hide=None, expose_types=None):
if hide_fn is not None:
self.hide_fn = hide_fn
elif hide is not None:
self.hide_fn = lambda msg: msg.get("name") in hide
elif expose_types is not None:
self.hide_fn = lambda msg: msg.get("type") not in expose_types
else:
self.hide_fn = lambda msg: True
super(block, self).__init__(fn)
Expand Down
24 changes: 13 additions & 11 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -862,36 +862,38 @@ def model2():
svi.run(random.PRNGKey(0), 10)


def test_autosldais(N=64, D=3, num_steps=45000, num_samples=2000):
def test_autosldais(
N=64, subsample_size=48, num_surrogate=32, D=3, num_steps=45000, num_samples=2000
):
def _model(X, Y):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
with numpyro.plate("N", N, subsample_size=2 * N // 3):
with numpyro.plate("N", N, subsample_size=subsample_size):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)

def _surrogate_model(X, Y):
def _surrogate_model(X_surr, Y_surr):
theta = numpyro.sample(
"theta", dist.Normal(jnp.zeros(D), jnp.ones(D)).to_event(1)
)
omegas = numpyro.param(
"omegas", 2.0 * jnp.ones(N // 2), constraint=dist.constraints.positive
"omegas",
2.0 * jnp.ones(num_surrogate),
constraint=dist.constraints.positive,
)

with numpyro.plate("N", N // 2), numpyro.handlers.scale(scale=omegas):
X_batch = numpyro.subsample(X, event_dim=1)
Y_batch = numpyro.subsample(Y, event_dim=0)
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_batch.T), obs=Y_batch)
with numpyro.plate("N", num_surrogate), numpyro.handlers.scale(scale=omegas):
numpyro.sample("obs", dist.Bernoulli(logits=theta @ X_surr.T), obs=Y_surr)

X = RandomState(0).randn(N, D)
X[:, 2] = X[:, 0] + X[:, 1]
logits = X[:, 0] - 0.5 * X[:, 1]
Y = dist.Bernoulli(logits=logits).sample(random.PRNGKey(0))

model = partial(_model, X, Y)
surrogate_model = partial(_surrogate_model, X[::2], Y[::2])
surrogate_model = partial(_surrogate_model, X[:num_surrogate], Y[:num_surrogate])

def _get_optim():
scheduler = piecewise_constant_schedule(
Expand All @@ -917,15 +919,15 @@ def _get_optim():
dais_elbo = -dais_elbo.item()

def create_plates():
return numpyro.plate("N", N, subsample_size=2 * N // 3)
return numpyro.plate("N", N, subsample_size=subsample_size)

mf_guide = AutoNormal(model, create_plates=create_plates)
mf_svi_result = SVI(model, mf_guide, _get_optim(), Trace_ELBO()).run(
random.PRNGKey(0), num_steps
)

mf_elbo = Trace_ELBO(num_particles=num_samples).loss(
random.PRNGKey(0), mf_svi_result.params, model, mf_guide
random.PRNGKey(1), mf_svi_result.params, model, mf_guide
)
mf_elbo = -mf_elbo.item()

Expand Down

0 comments on commit a488923

Please sign in to comment.