Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate the model when cannot find valid initial params. #733

Merged
merged 4 commits into from
Sep 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ def __init__(self, batch_shape=(), event_shape=(), validate_args=None):
is_valid = jnp.all(constraint(getattr(self, param)))
if not_jax_tracer(is_valid):
if not is_valid:
raise ValueError("The parameter {} has invalid values".format(param))
raise ValueError("{} distribution got invalid {} parameter.".format(
self.__class__.__name__, param))
super(Distribution, self).__init__()

@property
Expand Down
15 changes: 15 additions & 0 deletions numpyro/infer/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,21 @@ def initialize_model(rng_key, model,

if not_jax_tracer(is_valid):
if device_get(~jnp.all(is_valid)):
with numpyro.validation_enabled(), trace() as tr:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A more informative warning / error message is definitely needed. I am thinking that can we simply run initialize_model with validation_enabled (I wouldn't expect that to add any material overhead)? Is the resulting warning message not informative enough?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is simpler (I also don't worry about the overhead) but there are two issues with that:

  • validation is only useful for the first try (under jax loop, we can't prompt the warning/error for the later tries)
  • displaying the warning message for the first try might not be useful for users when we can find a valid one in a later try

What do you think?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for explaining, @fehiepsi, both of your points make a lot of sense.

# validate parameters
substituted_model(*model_args, **model_kwargs)
# validate values
for site in tr.values():
if site['type'] == 'sample':
with warnings.catch_warnings(record=True) as ws:
site['fn']._validate_sample(site['value'])
if len(ws) > 0:
for w in ws:
# at site information to the warning message
w.message.args = ("Site {}: {}".format(site["name"], w.message.args[0]),) \
+ w.message.args[1:]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What does a sample warning message look like?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reviewing, @neerajprad! I'm a bit busy today but will run this again to display the warning message tomorrow. Here I just want to add site information to the warning message, because not many users know how to use warnings to turn a warning to an error to debug.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@neerajprad the following model

import numpyro
import numpy as np
def model():
    x = numpyro.sample("x", numpyro.distributions.Normal())
    numpyro.sample("obs", numpyro.distributions.Normal(x), obs=float('nan'))
mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), 10, 10)
mcmc.run(np.array([0, 0], dtype='uint32'))

gives the warning

UserWarning: Site obs: Out-of-support values provided to log prob method. The value argument should be within the support.

warnings.showwarning(w.message, w.category, w.filename, w.lineno,
file=w.file, line=w.line)
raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)

Expand Down