Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/master' into dais-vae1
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed Oct 22, 2023
2 parents d211a73 + 2416eb9 commit 84d8a7e
Show file tree
Hide file tree
Showing 10 changed files with 26 additions and 23 deletions.
5 changes: 2 additions & 3 deletions examples/stein_bnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
from sklearn.model_selection import train_test_split

import jax
from jax import random
import jax.numpy as jnp

Expand Down Expand Up @@ -195,9 +196,7 @@ def main(args):


if __name__ == "__main__":
from jax.config import config

config.update("jax_debug_nans", True)
jax.config.update("jax_debug_nans", True)

parser = argparse.ArgumentParser()
parser.add_argument("--subsample-size", type=int, default=100)
Expand Down
2 changes: 2 additions & 0 deletions numpyro/contrib/control_flow/cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ def _subs_wrapper(subs_map, site):
if isinstance(subs_map, dict) and site["name"] in subs_map:
return subs_map[site["name"]]
elif callable(subs_map):
if site["type"] == "deterministic":
return subs_map(site)
rng_key = site["kwargs"].get("rng_key")
subs_map = (
handlers.seed(subs_map, rng_seed=rng_key)
Expand Down
3 changes: 1 addition & 2 deletions numpyro/contrib/nested_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
Model,
NestedSamplerResults,
Prior,
PriorModelGen,
TerminationCondition,
plot_cornerplot,
plot_diagnostics,
Expand Down Expand Up @@ -243,7 +242,7 @@ def run(self, rng_key, *args, **kwargs):
loglik_fn = local_dict["loglik_fn"]

# use NestedSampler with identity prior chain
def prior_model() -> PriorModelGen:
def prior_model():
params = []
for name in param_names:
shape = prototype_trace[name]["fn"].shape()
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/batch_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def _promote_batch_shape_masked(d: MaskedDistribution):
def _promote_batch_shape_independent(d: Independent):
new_self = copy.copy(d)
new_base_dist = promote_batch_shape(d.base_dist)
new_self._batch_shape = new_base_dist.batch_shape[: d.event_dim]
new_self._batch_shape = new_base_dist.batch_shape[: -d.event_dim]
new_self.base_dist = new_base_dist
return new_self

Expand Down
6 changes: 3 additions & 3 deletions numpyro/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -794,9 +794,9 @@ def __init__(self, fn=None, data=None, substitute_fn=None):
super(substitute, self).__init__(fn)

def process_message(self, msg):
if (msg["type"] not in ("sample", "param", "mutable", "plate")) or msg.get(
"_control_flow_done", False
):
if (
msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic")
) or msg.get("_control_flow_done", False):
if msg["type"] == "control_flow":
if self.data is not None:
msg["kwargs"]["substitute_stack"].append(("substitute", self.data))
Expand Down
10 changes: 5 additions & 5 deletions numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __call__(self, *args, **kwargs):
raise NotImplementedError

@abstractmethod
def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
"""
Generate samples from the approximate posterior over the latent
sites in the model.
Expand Down Expand Up @@ -444,7 +444,7 @@ def _constrain(self, latent_samples):
else:
return self._postprocess_fn(latent_samples)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
locs = {k: params["{}_{}_loc".format(k, self.prefix)] for k in self._init_locs}
scales = {k: params["{}_{}_scale".format(k, self.prefix)] for k in locs}
with handlers.seed(rng_seed=rng_key):
Expand Down Expand Up @@ -776,7 +776,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.TransformedDistribution(base_dist, transform)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, rng_key), params
)(sample_shape=sample_shape)
Expand Down Expand Up @@ -965,7 +965,7 @@ def scan_body(carry, eps_beta):

return z

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
def _single_sample(_rng_key):
latent_sample = handlers.substitute(
handlers.seed(self._sample_latent, _rng_key), params
Expand Down Expand Up @@ -1988,7 +1988,7 @@ def get_posterior(self, params):
transform = self.get_transform(params)
return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)

def sample_posterior(self, rng_key, params, sample_shape=()):
def sample_posterior(self, rng_key, params, *, sample_shape=()):
latent_sample = self.get_posterior(params).sample(rng_key, sample_shape)
return self._unpack_and_constrain(latent_sample, params)

Expand Down
12 changes: 7 additions & 5 deletions numpyro/ops/provenance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
from jax.api_util import flatten_fun, shaped_abstractify
import jax.core as core
from jax.experimental.pjit import pjit_p
import jax.util as util

try:
import jax.extend.linear_util as lu
except ImportError:
import jax.linear_util as lu

from jax.interpreters.partial_eval import trace_to_jaxpr_dynamic
from jax.interpreters.pxla import xla_pmap_p
import jax.numpy as jnp
Expand Down Expand Up @@ -40,7 +42,7 @@ def eval_provenance(fn, **kwargs):
args, in_tree = jax.tree_util.tree_flatten(((), kwargs))
wrapped_fun, out_tree = flatten_fun(lu.wrap_init(fn), in_tree)
# Abstract eval to get output pytree
avals = core.safe_map(shaped_abstractify, args)
avals = util.safe_map(shaped_abstractify, args)
# XXX: we split out the process of abstract evaluation and provenance tracking
# for simplicity. In principle, they can be merged so that we only need to walk
# through the equations once.
Expand Down Expand Up @@ -81,14 +83,14 @@ def write(v, p):
return
env[v] = read(v) | p

core.safe_map(write, jaxpr.invars, provenance_inputs)
util.safe_map(write, jaxpr.invars, provenance_inputs)
for eqn in jaxpr.eqns:
provenance_inputs = core.safe_map(read, eqn.invars)
provenance_inputs = util.safe_map(read, eqn.invars)
rule = track_deps_rules.get(eqn.primitive, _default_track_deps_rules)
provenance_outputs = rule(eqn, provenance_inputs)
core.safe_map(write, eqn.outvars, provenance_outputs)
util.safe_map(write, eqn.outvars, provenance_outputs)

return core.safe_map(read, jaxpr.outvars)
return util.safe_map(read, jaxpr.outvars)


track_deps_rules = {}
Expand Down
2 changes: 1 addition & 1 deletion test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import os

from jax.config import config
from jax import config

from numpyro.util import set_rng_seed

Expand Down
6 changes: 3 additions & 3 deletions test/contrib/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ def transition_fn(c, val):


def test_scan_plate_mask():
def model(y=None, T=10):
def model(y=None, T=12):
def transition(carry, y_curr):
x_prev, t = carry
with numpyro.plate("N", 10, dim=-1):
Expand All @@ -237,7 +237,7 @@ def transition(carry, y_curr):
return (x, y)

with numpyro.handlers.seed(rng_seed=0):
model_density, model_trace = log_density(model, (None, 10), {}, {})
model_density, model_trace = log_density(model, (None, 12), {}, {})
assert model_density
assert model_trace["x"]["fn"].batch_shape == (10,)
assert model_trace["x"]["fn"].batch_shape == (12, 10)
assert model_trace["x"]["fn"].event_shape == (3,)
1 change: 1 addition & 0 deletions test/infer/test_autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,6 +564,7 @@ def model(y=None):
random.PRNGKey(0), params, sample_shape=(1000,)
)

posterior_samples.pop("z")
predictive = Predictive(model, posterior_samples, params=params)
predictive_samples = predictive(random.PRNGKey(0), y_test)

Expand Down

0 comments on commit 84d8a7e

Please sign in to comment.