
Allow propagate mutable states in SVI #990

Closed
fehiepsi opened this issue Apr 2, 2021 · 0 comments · Fixed by #1016
Labels: enhancement (New feature or request)


fehiepsi (Member) commented Apr 2, 2021

While reviewing #989, I noticed situations where we need to maintain extra variables, such as the moving-average statistics in BatchNorm. It would be nice to introduce an additional primitive for this. Alternatively, we could add an option to the deterministic primitive so that SVI knows to store the value.
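To make the BatchNorm case concrete: its running statistics are updated on every forward pass by an exponential moving average, so in a pure-functional setting the new values have to be returned and carried into the next step. The following is a minimal plain-Python sketch, assuming scalar inputs and a hypothetical momentum of 0.9 (real BatchNorm layers track per-feature statistics, and the helper name update_batch_stats is made up for illustration).

```python
# Sketch of BatchNorm-style running statistics as a pure function.
# Assumptions: scalar data, momentum=0.9; `update_batch_stats` is hypothetical.
from statistics import fmean, pvariance


def update_batch_stats(batch_stats, batch, momentum=0.9):
    """Return new running mean/var; the input dict is not mutated."""
    mean = fmean(batch)
    var = pvariance(batch, mu=mean)
    return {
        "mean": momentum * batch_stats["mean"] + (1 - momentum) * mean,
        "var": momentum * batch_stats["var"] + (1 - momentum) * var,
    }


stats = {"mean": 0.0, "var": 1.0}
for batch in ([1.0, 3.0], [2.0, 4.0]):
    # Nothing stores `stats` for us: the caller must carry it to the next step.
    stats = update_batch_stats(stats, batch)
print(stats)
```

The values in stats are neither parameters to optimize nor latent variables, which is why neither param nor sample fits them naturally.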

@martinjankowiak also requested something like this last year:

it’ll just be easier to pipe through the stats i’m interested in and store them in numpyro.deterministic

Note that this is not an issue for a stateful framework like PyTorch.

Here are the solutions used by other JAX neural-network libraries:

  • flax introduces a mutable argument that specifies which collections in variables should be updated when computing the output. Here variables contains not only params but also other mutable state such as 'batch_stats':
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  • haiku introduces transform_with_state:

This function is equivalent to transform(), however it allows you to maintain and update internal state (e.g. ExponentialMovingAverage in BatchNorm) via get_state() and set_state():

counter, state = f.apply(params, state, None)
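Both APIs share the same shape: state goes in as an argument and the updated state comes back as an extra return value, separate from the trainable params. A toy plain-Python sketch of that pattern follows; init and apply here are hypothetical names and do not reproduce the actual flax or haiku signatures.

```python
# Toy "state in, state out" layer; init/apply are hypothetical names,
# not the actual flax or haiku API.


def init():
    """Return trainable params and non-trainable state separately."""
    params = {"scale": 2.0}
    state = {"count": 0, "running_mean": 0.0}
    return params, state


def apply(params, state, x, momentum=0.9):
    """Pure function: returns the output and a new state dict."""
    y = params["scale"] * x
    new_state = {
        "count": state["count"] + 1,
        "running_mean": momentum * state["running_mean"] + (1 - momentum) * x,
    }
    return y, new_state


params, state = init()
for x in [1.0, 2.0, 3.0]:
    # The caller threads `state` through every call, like flax's
    # `updated_state` or haiku's returned `state`.
    y, state = apply(params, state, x)
print(y, state["count"])  # prints 6.0 3
```

params would be updated by an optimizer, while state is updated only by the forward pass; SVI currently has a slot for the former but not the latter.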

Those frameworks also provide ways to handle RNG state in stochastic layers such as Dropout. That case is easy to deal with in NumPyro by using the numpyro.prng_key() primitive. For statistics like the ones above, however, we don't have a solution yet.

@fehiepsi fehiepsi added the enhancement New feature or request label Apr 2, 2021
@fehiepsi fehiepsi added this to the 0.7 milestone Apr 10, 2021
@fehiepsi fehiepsi self-assigned this Apr 18, 2021