While reviewing #989, I noticed that there are situations where we need to maintain extra variables, such as the moving-average statistics in BatchNorm. So it would be nice to introduce an additional primitive for this. Alternatively, we could add an option to the `deterministic` primitive so that SVI knows about such sites and stores their values. @martinjankowiak also requested something like this last year:

> it’ll just be easier to pipe through the stats i’m interested in and store them in numpyro.deterministic
Note that this is not an issue for a stateful framework like PyTorch.
Here are the solutions used by other JAX neural-network libraries:
Flax introduces a `mutable` argument that specifies which collections in `variables` should be updated when computing the output. Here `variables` contains not only `params` but also other mutable state such as `'batch_stats'`:
```python
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
```
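To make the Flax pattern concrete, here is a minimal pure-JAX sketch of what an `apply` call with `mutable=['batch_stats']` does for a BatchNorm layer. This is not Flax's implementation: `batchnorm_apply`, its `momentum`/`eps` values, and the exact layout of `variables` are hypothetical choices for illustration. The key point is that the function stays pure and hands the updated statistics back to the caller, who is responsible for storing them.

```python
import jax
import jax.numpy as jnp


def batchnorm_apply(variables, x, mutable=(), momentum=0.9, eps=1e-5):
    """Hypothetical Flax-style apply for a BatchNorm layer with learnable scale/bias.

    `variables` holds both 'params' and a mutable 'batch_stats' collection.
    If 'batch_stats' is listed in `mutable`, normalize with the current batch
    statistics and also return the updated moving averages (training mode);
    otherwise normalize with the stored running averages (evaluation mode).
    """
    stats = variables["batch_stats"]
    if "batch_stats" in mutable:
        mean = jnp.mean(x, axis=0)
        var = jnp.var(x, axis=0)
        new_stats = {
            "mean": momentum * stats["mean"] + (1 - momentum) * mean,
            "var": momentum * stats["var"] + (1 - momentum) * var,
        }
    else:
        mean, var = stats["mean"], stats["var"]
        new_stats = None
    y = (x - mean) / jnp.sqrt(var + eps)
    y = y * variables["params"]["scale"] + variables["params"]["bias"]
    if new_stats is None:
        return y
    return y, {"batch_stats": new_stats}


# Usage: the caller threads the updated state back into `variables`.
features = 3
variables = {
    "params": {"scale": jnp.ones(features), "bias": jnp.zeros(features)},
    "batch_stats": {"mean": jnp.zeros(features), "var": jnp.ones(features)},
}
x = jax.random.normal(jax.random.PRNGKey(0), (16, features))
y, updated_state = batchnorm_apply(variables, x, mutable=["batch_stats"])
variables = {**variables, **updated_state}  # store the new moving averages
```

This is exactly the bookkeeping that an inference loop such as SVI would have to do on the user's behalf if a model contained such a layer.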
Haiku provides `hk.transform_with_state`. From its documentation: this function is equivalent to `transform()`, but it allows you to maintain and update internal state (e.g. `ExponentialMovingAverage` in `BatchNorm`) via `get_state()` and `set_state()`:
```python
counter, state = f.apply(params, state, None)
```
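The Haiku-style signature `apply(params, state, rng, ...) -> (output, new_state)` also shows how a training step can differentiate only with respect to `params` while still carrying `state` forward. Below is a small pure-JAX sketch of that idea, assuming a hypothetical `apply` whose state is a single exponential moving average (standing in for BatchNorm's statistics). The updated state is returned as auxiliary output via `jax.value_and_grad(..., has_aux=True)`, so it is never differentiated.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical step size for this sketch


def apply(params, state, rng, x):
    """Hypothetical Haiku-style apply: returns (output, new_state).

    `state` holds an exponential moving average of the layer's output mean.
    `rng` is unused; it only mirrors the (params, state, rng, ...) order.
    """
    h = x @ params["w"]
    decay = 0.99
    new_state = {"ema_mean": decay * state["ema_mean"] + (1 - decay) * jnp.mean(h, axis=0)}
    return h, new_state


def loss_fn(params, state, x, target):
    out, new_state = apply(params, state, None, x)
    loss = jnp.mean((out - target) ** 2)
    # The updated state rides along as auxiliary output (no gradient taken).
    return loss, new_state


@jax.jit
def train_step(params, state, x, target):
    (loss, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
        params, state, x, target
    )
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    return new_params, new_state, loss


# Usage: both params and state are threaded explicitly through the loop.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
target = x @ jnp.ones((3, 2))
params = {"w": jnp.zeros((3, 2))}
state = {"ema_mean": jnp.zeros(2)}
losses = []
for _ in range(100):
    params, state, loss = train_step(params, state, x, target)
    losses.append(loss)
```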
Both frameworks also provide ways to handle RNG state in stochastic layers such as Dropout. In NumPyro that part is easy: we can use the `numpyro.prng_key()` primitive. For statistics like the ones above, however, we don't have a solution yet.