From cc777e8148acee14a24d31799fd36b3dd5d567bd Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Tue, 3 Dec 2019 20:51:12 -0800 Subject: [PATCH] Minor changes to infer.util for 0.2.2 (#487) * Minor changes to infer.util for 0.2.2 * fix test * fix test; address comment * fix invocation --- examples/stochastic_volatility.py | 2 +- numpyro/contrib/autoguide/__init__.py | 5 +- numpyro/infer/mcmc.py | 7 +-- numpyro/infer/util.py | 67 +++++++++++++++------------ test/test_infer_util.py | 20 ++++---- test/test_mcmc.py | 2 +- 6 files changed, 59 insertions(+), 44 deletions(-) diff --git a/examples/stochastic_volatility.py b/examples/stochastic_volatility.py index 1a21f0914..f130aac5d 100644 --- a/examples/stochastic_volatility.py +++ b/examples/stochastic_volatility.py @@ -82,7 +82,7 @@ def main(args): _, fetch = load_dataset(SP500, shuffle=False) dates, returns = fetch() init_rng_key, sample_rng_key = random.split(random.PRNGKey(args.rng_seed)) - init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, returns) + init_params, potential_fn, constrain_fn = initialize_model(init_rng_key, model, model_args=(returns,)) init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS') hmc_state = init_kernel(init_params, args.num_warmup, rng_key=sample_rng_key) hmc_states = fori_collect(args.num_warmup, args.num_warmup + args.num_samples, sample_kernel, hmc_state, diff --git a/numpyro/contrib/autoguide/__init__.py b/numpyro/contrib/autoguide/__init__.py index 8ac8e82d2..97c6402ae 100644 --- a/numpyro/contrib/autoguide/__init__.py +++ b/numpyro/contrib/autoguide/__init__.py @@ -136,9 +136,10 @@ def __init__(self, model, prefix="auto", init_strategy=init_to_uniform()): def _setup_prototype(self, *args, **kwargs): super(AutoContinuous, self)._setup_prototype(*args, **kwargs) rng_key = numpyro.sample("_{}_rng_key_init".format(self.prefix), dist.PRNGIdentity()) - init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model, *args, + init_params, _ = handlers.block(find_valid_initial_params)(rng_key, self.model, init_strategy=self.init_strategy, - **kwargs) + model_args=args, + model_kwargs=kwargs) self._inv_transforms = {} self._has_transformed_dist = False unconstrained_sites = {} diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 02549a43d..2a9572c45 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -151,7 +151,7 @@ def hmc(potential_fn=None, potential_fn_gen=None, kinetic_fn=None, algo='NUTS'): ... return numpyro.sample('y', dist.Bernoulli(logits=(coefs * data + intercept).sum(-1)), obs=labels) >>> >>> init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(0), - ... model, data, labels) + ... model, model_args=(data, labels,)) >>> init_kernel, sample_kernel = hmc(potential_fn, algo='NUTS') >>> hmc_state = init_kernel(init_params, ... trajectory_length=10, @@ -495,10 +495,11 @@ def init(self, rng_key, num_warmup, init_params=None, model_args=(), model_kwarg ' `potential_fn`.') # Find valid initial params if self._model and not init_params: - init_params, is_valid = find_valid_initial_params(rng_key, self._model, *model_args, + init_params, is_valid = find_valid_initial_params(rng_key, self._model, init_strategy=self._init_strategy, param_as_improper=True, - **model_kwargs) + model_args=model_args, + model_kwargs=model_kwargs) if not_jax_tracer(is_valid): if device_get(~np.all(is_valid)): raise RuntimeError("Cannot find valid initial parameters. " diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index c549cdbcd..51129e819 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -32,11 +32,12 @@ def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=False): """ - Computes log of joint density for the model given latent values ``params``. + (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given + latent values ``params``. :param model: Python callable containing NumPyro primitives. :param tuple model_args: args provided to the model. - :param dict model_kwargs`: kwargs provided to the model. + :param dict model_kwargs: kwargs provided to the model. :param dict params: dictionary of current parameter values keyed by site name. :param bool skip_dist_transforms: whether to compute log probability of a site @@ -76,8 +77,9 @@ def log_density(model, model_args, model_kwargs, params, skip_dist_transforms=Fa def transform_fn(transforms, params, invert=False): """ - Callable that applies a transformation from the `transforms` dict to values in the - `params` dict and returns the transformed values keyed on the same names. + (EXPERIMENTAL INTERFACE) Callable that applies a transformation from the `transforms` + dict to values in the `params` dict and returns the transformed values keyed on + the same names. :param transforms: Dictionary of transforms keyed by names. Names in `transforms` and `params` should align. @@ -93,17 +95,18 @@ def transform_fn(transforms, params, invert=False): def constrain_fn(model, transforms, model_args, model_kwargs, params): """ - Gets value at each latent site in `model` given unconstrained parameters `params`. - The `transforms` is used to transform these unconstrained parameters to base values - of the corresponding priors in `model`. If a prior is a transformed distribution, - the corresponding base value lies in the support of base distribution. Otherwise, - the base value lies in the support of the distribution. + (EXPERIMENTAL INTERFACE) Gets value at each latent site in `model` given + unconstrained parameters `params`. The `transforms` is used to transform these + unconstrained parameters to base values of the corresponding priors in `model`. + If a prior is a transformed distribution, the corresponding base value lies in + the support of base distribution. Otherwise, the base value lies in the support + of the distribution. :param model: a callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs: kwargs provided to the model. :param dict transforms: dictionary of transforms keyed by names. Names in `transforms` and `params` should align. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. :param dict params: dictionary of unconstrained values keyed by site names. :return: `dict` of transformed params. @@ -116,16 +119,16 @@ def constrain_fn(model, transforms, model_args, model_kwargs, params): def potential_energy(model, inv_transforms, model_args, model_kwargs, params): """ - Computes potential energy of a model given unconstrained params. + (EXPERIMENTAL INTERFACE) Computes potential energy of a model given unconstrained params. The `inv_transforms` is used to transform these unconstrained parameters to base values of the corresponding priors in `model`. If a prior is a transformed distribution, the corresponding base value lies in the support of base distribution. Otherwise, the base value lies in the support of the distribution. :param model: a callable containing NumPyro primitives. - :param tuple model_args: args provided to the model. - :param dict model_kwargs`: kwargs provided to the model. :param dict inv_transforms: dictionary of transforms keyed by names. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. :param dict params: unconstrained parameters of `model`. :return: potential energy given unconstrained parameters. """ @@ -268,8 +271,11 @@ def init_to_value(values): return partial(_init_to_value, values=values) -def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to_uniform(), - param_as_improper=False, **model_kwargs): +def find_valid_initial_params(rng_key, model, + init_strategy=init_to_uniform(), + param_as_improper=False, + model_args=(), + model_kwargs=None): """ (EXPERIMENTAL INTERFACE) Given a model with Pyro primitives, returns an initial valid unconstrained value for all the parameters. This function also returns an @@ -281,11 +287,11 @@ def find_valid_initial_params(rng_key, model, *model_args, init_strategy=init_to sample from the prior. The returned `init_params` will have the batch shape ``rng_key.shape[:-1]``. :param model: Python callable containing Pyro primitives. - :param `*model_args`: args provided to the model. :param callable init_strategy: a per-site initialization function. :param bool param_as_improper: a flag to decide whether to consider sites with `param` statement as sites with improper priors. - :param `**model_kwargs`: kwargs provided to the model. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. :return: tuple of (`init_params`, `is_valid`). """ init_strategy = jax.partial(init_strategy, skip_param=not param_as_improper) @@ -416,8 +422,11 @@ def constrain_fun(*args, **kwargs): return potential_fn, constrain_fun -def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform(), - dynamic_args=False, **model_kwargs): +def initialize_model(rng_key, model, + init_strategy=init_to_uniform(), + dynamic_args=False, + model_args=(), + model_kwargs=None): """ (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn` and :func:`~numpyro.infer.util.find_valid_initial_params` under the hood @@ -427,30 +436,33 @@ def initialize_model(rng_key, model, *model_args, init_strategy=init_to_uniform( sample from the prior. The returned `init_params` will have the batch shape ``rng_key.shape[:-1]``. :param model: Python callable containing Pyro primitives. - :param `*model_args`: args provided to the model. :param callable init_strategy: a per-site initialization function. See :ref:`init_strategy` section for available functions. :param bool dynamic_args: if `True`, the `potential_fn` and `constraints_fn` are themselves dependent on model arguments. When provided a `*model_args, **model_kwargs`, they return `potential_fn` and `constraints_fn` callables, respectively. - :param `**model_kwargs`: kwargs provided to the model. + :param tuple model_args: args provided to the model. + :param dict model_kwargs: kwargs provided to the model. :return: tuple of (`init_params`, `potential_fn`, `constrain_fn`), `init_params` are values from the prior used to initiate MCMC, `constrain_fn` is a callable that uses inverse transforms to convert unconstrained HMC samples to constrained values that lie within the site's support. """ + if model_kwargs is None: + model_kwargs = {} potential_fun, constrain_fun = get_potential_fn(rng_key if rng_key.ndim == 1 else rng_key[0], model, dynamic_args=dynamic_args, model_args=model_args, model_kwargs=model_kwargs) - init_params, is_valid = find_valid_initial_params(rng_key, model, *model_args, + init_params, is_valid = find_valid_initial_params(rng_key, model, init_strategy=init_strategy, param_as_improper=True, - **model_kwargs) + model_args=model_args, + model_kwargs=model_kwargs) if not_jax_tracer(is_valid): if device_get(~np.all(is_valid)): @@ -559,11 +571,8 @@ def get_samples(self, rng_key, *args, **kwargs): def log_likelihood(model, posterior_samples, *args, **kwargs): """ - Returns log likelihood at observation nodes of model, given samples of all latent variables. - - .. warning:: - The interface for the `log_likelihood` function is experimental, and - might change in the future. + (EXPERIMENTAL INTERFACE) Returns log likelihood at observation nodes of model, + given samples of all latent variables. :param model: Python callable containing Pyro primitives. :param dict posterior_samples: dictionary of samples from the posterior. diff --git a/test/test_infer_util.py b/test/test_infer_util.py index 490f1c50f..0965ff05f 100644 --- a/test/test_infer_util.py +++ b/test/test_infer_util.py @@ -192,11 +192,13 @@ def model(data): ]) rng_keys = random.split(random.PRNGKey(1), 2) - init_params, _, _ = initialize_model(rng_keys, model, count_data, - init_strategy=init_strategy) + init_params, _, _ = initialize_model(rng_keys, model, + init_strategy=init_strategy, + model_args=(count_data,)) for i in range(2): - init_params_i, _, _ = initialize_model(rng_keys[i], model, count_data, - init_strategy=init_strategy) + init_params_i, _, _ = initialize_model(rng_keys[i], model, + init_strategy=init_strategy, + model_args=(count_data,)) for name, p in init_params.items(): # XXX: the result is equal if we disable fast-math-mode assert_allclose(p[i], init_params_i[name], atol=1e-6) @@ -219,11 +221,13 @@ def model(data): data = dist.Categorical(true_probs).sample(random.PRNGKey(1), (2000,)) rng_keys = random.split(random.PRNGKey(1), 2) - init_params, _, _ = initialize_model(rng_keys, model, data, - init_strategy=init_strategy) + init_params, _, _ = initialize_model(rng_keys, model, + init_strategy=init_strategy, + model_args=(data,)) for i in range(2): - init_params_i, _, _ = initialize_model(rng_keys[i], model, data, - init_strategy=init_strategy) + init_params_i, _, _ = initialize_model(rng_keys[i], model, + init_strategy=init_strategy, + model_args=(data,)) for name, p in init_params.items(): # XXX: the result is equal if we disable fast-math-mode assert_allclose(p[i], init_params_i[name], atol=1e-6) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 2a43edb9e..cc192599a 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -459,7 +459,7 @@ def model(data): true_probs = np.array([0.9, 0.1]) data = dist.Bernoulli(true_probs).sample(random.PRNGKey(1), (1000, 2)) - init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, data) + init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, model_args=(data,)) init_kernel, sample_kernel = hmc(potential_fn, algo=algo) hmc_state = init_kernel(init_params, trajectory_length=1.,