Validate the model when cannot find valid initial params. #733

Merged
merged 4 commits into from
Sep 23, 2020

Member

@fehiepsi fehiepsi commented Sep 10, 2020

Resolves #731. As explained there, it can be tricky to find bugs when inference cannot find valid initial parameters. With this PR, when that happens, we can see at which site things go wrong.

@rim30 could you run your code with this branch to see what causes the problem?

TODO

  • incorporate init_strategy, because some distributions, such as Improper, do not have a sample method.

@fehiepsi fehiepsi added the WIP label Sep 11, 2020

rim30 commented Sep 18, 2020

Hi @fehiepsi ,

Sorry for the massive delay, I was a bit busy until today. So I installed numpyro from your branch (https://github.com/fehiepsi/numpyro.git@validate) and got this error:

[screenshot: numpyro_error]

The time series that I am using is: [1146., 488., 753., 583., 553., 832., 807., 875.,
945., 795., 862., 1322., 890., 911., 990., 791.,
910., 957., 838., 956., 920., 945., 1192., 1921.,
987., 907., 762., 859., 843., 804., 785., 942.,
822., 727.], which, as you can clearly see, contains no negative values.

Quick note: I used this simpler version of SGT, which I modelled after the one from the NumPyro website:

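# imports assumed from the NumPyro SGT example this model is based on
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.contrib.control_flow import scan


def simple_sgt(y, seasonality, future=0):
    # heuristically, standard deviation of Cauchy prior depends on
    # the max value of data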
    cauchy_sd = jnp.max(y) / 150

    # NB: priors' parameters are taken from
    # https://github.com/cbergmeir/Rlgt/blob/master/Rlgt/R/rlgtcontrol.R
    nu = numpyro.sample("nu", dist.Uniform(2, 20))
    powx = numpyro.sample("powx", dist.Uniform(0, 1))
    sigma = numpyro.sample("sigma", dist.HalfCauchy(cauchy_sd))
    offset_sigma = numpyro.sample(
        "offset_sigma", dist.TruncatedCauchy(low=1e-10, loc=1e-10, scale=cauchy_sd)
    )

    coef_trend = numpyro.sample("coef_trend", dist.Cauchy(0, cauchy_sd))
    pow_trend_beta = numpyro.sample("pow_trend_beta", dist.Beta(1, 1))
    # pow_trend takes values from -0.5 to 1
    pow_trend = 1.5 * pow_trend_beta - 0.5
    #pow_season = numpyro.sample("pow_season", dist.Beta(1, 1))

    level_sm = numpyro.sample("level_sm", dist.Beta(1, 2))
    s_sm = numpyro.sample("s_sm", dist.Uniform(0, 1))
    init_s = numpyro.sample("init_s", dist.Cauchy(0, y[:seasonality] * 0.3))

    def transition_fn(carry, t):
        #level, s, moving_sum = carry
        level, s = carry
        #season = s[0] * level ** pow_season
        #exp_val = level + coef_trend * level ** pow_trend + season
        exp_val = (level + coef_trend * level ** pow_trend) * s[0]
        exp_val = jnp.clip(exp_val, a_min=0)
        # use expected value when forecasting
        y_t = jnp.where(t >= N, exp_val, y[t])

        #moving_sum = (
        #    moving_sum + y[t] - jnp.where(t >= seasonality, y[t - seasonality], 0.0)
        #)
        #level_p = jnp.where(t >= seasonality, moving_sum / seasonality, y_t - season)
        #level = level_sm * level_p + (1 - level_sm) * level
        level = level_sm * y_t / s[0] + (1 - level_sm) * level
        level = jnp.clip(level, a_min=0)

        #new_s = (s_sm * (y_t - level) / season + (1 - s_sm)) * s[0]
        new_s = (s_sm * y_t / level + (1 - s_sm)) * s[0]
        # repeat s when forecasting
        new_s = jnp.where(t >= N, s[0], new_s)
        s = jnp.concatenate([s[1:], new_s[None]], axis=0)

        omega = sigma * exp_val ** powx + offset_sigma
        y_ = numpyro.sample("y", dist.StudentT(nu, exp_val, omega))

        #return (level, s, moving_sum), y_
        return (level, s), y_

    N = y.shape[0]
    level_init = y[0]
    s_init = jnp.concatenate([init_s[1:], init_s[:1]], axis=0)
    #moving_sum = level_init
    with numpyro.handlers.condition(data={"y": y[1:]}):
        _, ys = scan(
            transition_fn, (level_init, s_init), jnp.arange(1, N + future)
        )
    if future > 0:
        numpyro.deterministic("y_forecast", ys[-future:])

Member Author

fehiepsi commented Sep 18, 2020

@rim30 How about replacing exp_val = jnp.clip(exp_val, a_min=0) with

exp_val = jnp.clip(exp_val, a_min=1e-30, a_max=1e38)

? Our validation code can hardly detect numerical issues like this...
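
One plausible reading of this suggestion (an assumption, not something spelled out in the thread): with a_min=0, exp_val can sit exactly at 0, where exp_val ** powx has a non-finite derivative for powx in (0, 1). Gradients can then become inf or nan even though every value is inside its support, which is why support validation would not flag it. A minimal JAX sketch, where scale_term is a hypothetical stand-in for the exp_val ** powx factor of omega:

```python
import jax
import jax.numpy as jnp


def scale_term(exp_val, powx=0.5):
    # Hypothetical stand-in for the `exp_val ** powx` factor in the model above.
    return exp_val ** powx


# Differentiate with respect to exp_val only.
grad_fn = jax.grad(scale_term)

# Analytically the derivative is powx * exp_val ** (powx - 1),
# which is unbounded as exp_val approaches 0.
grad_at_zero = grad_fn(0.0)
# With a tiny positive floor the derivative is huge but still representable.
grad_at_floor = grad_fn(1e-30)

print(bool(jnp.isfinite(grad_at_zero)), bool(jnp.isfinite(grad_at_floor)))
```

The lower bound 1e-30 keeps the value strictly positive; the upper bound 1e38 presumably guards against overflow in float32, though the thread does not say so.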

@@ -427,6 +427,21 @@ def initialize_model(rng_key, model,

    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            with numpyro.validation_enabled(), trace() as tr:
Member

A more informative warning / error message is definitely needed. I am wondering whether we can simply run initialize_model with validation_enabled (I wouldn't expect that to add any material overhead). Is the resulting warning message not informative enough?

Member Author

Yes, it is simpler (I am also not worried about the overhead), but there are two issues with that:

  • validation is only useful for the first try (inside a JAX loop, we can't emit the warning/error for the later tries)
  • displaying the warning message for the first try might not be useful for users when we can find a valid one in a later try
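
The first point follows from how JAX loops work: the Python body of a loop is traced rather than executed on every iteration, so a Python-side warnings.warn call cannot react to the values seen in later tries. A small sketch of that behaviour; body and python_calls are illustrative names and not NumPyro code:

```python
import warnings

import jax
import jax.numpy as jnp

python_calls = []  # records each time the Python body actually runs


def body(i, x):
    python_calls.append(i)
    # This runs while JAX traces the body, not once per loop iteration,
    # and `x` is an abstract tracer here, so its runtime value is not visible.
    warnings.warn("checking a candidate")
    return x + 1.0


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = jax.lax.fori_loop(0, 100, body, jnp.float32(0.0))

# The loop ran 100 iterations, but the Python body ran far fewer times.
print(float(result), len(python_calls), len(caught))
```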

What do you think?

Member

Thanks for explaining, @fehiepsi. Both of your points make a lot of sense.

for w in ws:
    # add site information to the warning message
    w.message.args = ("Site {}: {}".format(site["name"], w.message.args[0]),) \
        + w.message.args[1:]
Member

What does a sample warning message look like?

Member Author

Thanks for reviewing, @neerajprad! I'm a bit busy today but will run this again tomorrow to display the warning message. Here I just want to add site information to the warning message, because not many users know how to use the warnings module to turn a warning into an error for debugging.
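
For readers unfamiliar with that trick, here is a minimal sketch using only the standard warnings module. The log_prob_check function is a hypothetical stand-in for a validation routine that merely warns; nothing here is NumPyro API.

```python
import warnings


def log_prob_check(value):
    # Hypothetical stand-in for a validation routine that only warns.
    if value != value:  # NaN is the only float that is not equal to itself
        warnings.warn("Out-of-support values provided to log prob method.", UserWarning)
    return value


raised = None
with warnings.catch_warnings():
    # Escalate UserWarning to an exception so a traceback points at the culprit.
    warnings.simplefilter("error", UserWarning)
    try:
        log_prob_check(float("nan"))
    except UserWarning as err:
        raised = err

print(repr(raised))
```

Outside of a sketch like this, one would typically drop the try/except and let the exception propagate so that the full traceback is shown.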

Member Author

@neerajprad the following model

import numpy as np
import numpyro


def model():
    x = numpyro.sample("x", numpyro.distributions.Normal())
    # the NaN observation is what triggers the out-of-support warning
    numpyro.sample("obs", numpyro.distributions.Normal(x), obs=float('nan'))


mcmc = numpyro.infer.MCMC(numpyro.infer.NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(np.array([0, 0], dtype='uint32'))

gives the warning

UserWarning: Site obs: Out-of-support values provided to log prob method. The value argument should be within the support.
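
The "Site obs:" prefix in that message comes from rewriting the recorded warning before it is shown again. A self-contained sketch of the pattern with the standard warnings module; validate_value and the sites dict are made up for illustration and only mimic the shape of a trace:

```python
import math
import warnings


def validate_value(value):
    # Hypothetical validator: warns instead of raising, as a support check might.
    if math.isnan(value):
        warnings.warn("Out-of-support values provided to log prob method.")


# Made-up stand-in for trace sites: site name -> value.
sites = {"x": 0.3, "obs": float("nan")}

messages = []
for name, value in sites.items():
    # Record warnings per site so each one can be attributed to its site.
    with warnings.catch_warnings(record=True) as ws:
        warnings.simplefilter("always")
        validate_value(value)
    for w in ws:
        # Prepend the site name to the first argument of the recorded warning.
        w.message.args = ("Site {}: {}".format(name, w.message.args[0]),) + w.message.args[1:]
        messages.append(str(w.message))

print(messages)
```

Only the site with the NaN value produces a message, so the prefix tells the user directly which sample statement to inspect.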

@neerajprad
Member

LGTM. Do you think this handles most of the forum questions you have been getting, or were they due to other numerical issues outside of distributions?

@fehiepsi
Member Author

This helps detect some of the issues from the forum: one where the data did not belong to the support, and one with a wrong parameter.

@neerajprad neerajprad merged commit 7f61de0 into pyro-ppl:master Sep 23, 2020
Successfully merging this pull request may close these issues.

Add a better error message for how to debug "Cannot find valid initial parameters" issue
3 participants