HMCECS in numpyro #724
Hi @OlaRonning, I think the current
Yeah, I agree. I think we can avoid it by either using a handler or modifying your HMCECS implementation a bit. I am looking forward to seeing your PR. :)
I hope I can help. By "user-defined subsample schemes" we mean being able to imitate what NUTS does for the automatic setup of the "subsample size" parameter (the same way it handles step size, for example). The subsample size would then change / be corrected according to the likelihood.
Unfortunately, I think that dynamic subsample size can only work with some masking mechanism: take a subsample of max_subsample_size and mask out some log probabilities if the current subsample size is less than that max value. Could you provide some pseudocode so we can see if any additional mechanism is needed? I still believe that if something works with
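(For concreteness, a minimal sketch of that masking idea, assuming a fixed `max_size`; the names `log_prob_fn`, `idx_padded`, and `cur_size` are illustrative, not an existing NumPyro API.)

```python
import jax.numpy as jnp

# Illustrative only: pad the subsample to max_size, then zero out the
# log probabilities of the slots beyond the current subsample size.
def masked_subsample_log_prob(log_prob_fn, idx_padded, cur_size, max_size):
    lps = log_prob_fn(idx_padded)            # per-datum log probs, shape (max_size,)
    mask = jnp.arange(max_size) < cur_size   # True for the "real" entries
    return jnp.where(mask, lps, 0.0).sum()   # masked-out slots contribute zero
```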
Hi @fehiepsi, the scheme is to block-update a subsample of indices by Gibbs sampling. In pseudo-Python it would be something like:

```python
n, _ = feats.shape              # number of data points (feats holds the features)
m = n // 3                      # subsample size
g = 4                           # number of blocks in the subsample
u = randint(0, n).sample(m)     # initial subsample indices

for _ in range(num_samples):
    chosen_block = randint(0, m // g)                   # pick a block to update, inducing correlation between u and u_new
    u_new = u
    u_new[chosen_block] = randint(0, n).sample(m // g)  # resample the whole block (u is partitioned into g blocks)
    accept_prob = likelihood(model(obs[u_new], feats[u_new, :])) / likelihood(model(obs[u], feats[u, :]))
    if bernoulli(accept_prob):
        u = u_new
    hmc_update(model(obs[u], feats[u, :]))              # evaluate grad_potential and potential on the subsample
```
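(As a concrete version of the block-update step alone, here is a runnable JAX sketch; `block_update` is my name for it, not part of the proposal, and it assumes `u` is a `jnp` integer array whose length is divisible by `g`.)

```python
import jax
from jax import random

# Sketch: replace one randomly chosen block of the subsample u with fresh indices.
def block_update(rng_key, u, n, g):
    m = u.shape[0]                                    # subsample size
    key_block, key_idx = random.split(rng_key)
    block = random.randint(key_block, (), 0, g)       # which of the g blocks to refresh
    fresh = random.randint(key_idx, (m // g,), 0, n)  # new indices drawn from the full data
    start = block * (m // g)
    return jax.lax.dynamic_update_slice(u, fresh, (start,))  # JAX arrays are immutable
```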
If you want to change subsample_size, I think you can do

```python
def model(..., subsample_size=None):
    with numpyro.plate(..., subsample_size=subsample_size):
        ...
```

(later, we can write a handler to do that job because …). In addition, I think that you want something like

```python
def model(full_data, full_feats, subsample_size=None):
    with numpyro.plate(..., subsample_size=subsample_size) as idx:
        data = full_data[idx]
        feats = full_feats[idx]
        ...
```

Not only that, IIUC you want to substitute
Great that there is already a way to make `subsample_size` configurable. The second model has the structure I was thinking of. We have to be able to substitute
Hi @OlaRonning, I think you can replace your model by

```python
def augmented_model(*args, **kwargs):
    subsample = kwargs.pop("_subsample")
    return substitute(model, data={"plate_name": subsample})(*args, **kwargs)
```

I think there are some requirements in the model to make this work. For example, the model code might look something like

```python
def model(x, y):
    with numpyro.plate('p', num_data, dim=-2) as p:
        x = x[p]  # or x = numpyro.subsample(x, event_dim=1)
        y = y[p]  # or y = numpyro.subsample(y, event_dim=1)
        numpyro.sample("obs", dist.Normal(x, 1), obs=y)
```

Do you have any other blockers? I will try to address them if possible.
I see, thanks for the examples. With your PR merged we have the tooling to make it work with … We have a version, in an internal repo, that uses slicing of the input parameters to do the subsampling. @LysSanzMoreta is currently porting it to a Numpyro fork. We will update it to do the subsampling with a plate internally, test it, and make a subsequent PR once it's working.
That is great! I am looking forward to your PRs. 👍
I've made two handlers to introduce the estimated likelihood computation into the HMCECS prototype @fehiepsi wrote. The first is the `estimator` handler:

```python
class estimator(Messenger):
    def __init__(self, fn, estimators, plate_sizes):
        self.estimators = estimators
        self.plate_sizes = plate_sizes
        super(estimator, self).__init__(fn)

    def process_message(self, msg):
        if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']:  # <--- this conditional
            log_prob = msg['fn'].log_prob
            msg['scale'] = 1.
            msg['fn'].log_prob = lambda *args, **kwargs: \
                self.estimators[msg['name']](*args, name=msg['name'], z=_extract_params(msg['fn']),
                                             log_prob=log_prob,
                                             sizes=self.plate_sizes[msg['cond_indep_stack'][0].name],
                                             **kwargs)
```

and the second is the `subsample_size` handler:

```python
class subsample_size(Messenger):
    def __init__(self, fn, plate_sizes, rng_key=None):
        super(subsample_size, self).__init__(fn)
        self.plate_sizes = plate_sizes
        self.rng_key = rng_key

    def process_message(self, msg):
        if msg['type'] == 'plate' and msg['args'] and msg['args'][0] > msg['args'][1]:
            if msg['name'] in self.plate_sizes:
                msg['args'] = self.plate_sizes[msg['name']]
                msg['value'] = (_subsample_fn(*msg['args'], self.rng_key)
                                if msg['args'][1] < msg['args'][0]
                                else jnp.arange(msg['args'][0]))
```

Regarding …, I have a working toy example here with a difference estimator and Taylor proxy. I'll clean it up, check it with some more extensive examples, and then make a WIP PR today or tomorrow.
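(A hypothetical composition of the two handlers, assuming a plate named `'N'` with full size `n` subsampled to `m`, an observed site `'obs'`, and a user-supplied callable `diff_estimator`; none of these names come from NumPyro itself.)

```python
plate_sizes = {"N": (n, m)}  # full size and subsample size for plate 'N'
wrapped = estimator(
    subsample_size(model, plate_sizes, rng_key=rng_key),
    estimators={"obs": diff_estimator},  # estimates the observed site's log_prob
    plate_sizes=plate_sizes,
)
```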
That's awesome! FYI the naive version is implemented through the HMCGibbs interface. I guess you can just add some new arguments there to let users switch between different versions and to let users specify something like
I think you also need to check if plate names in
Does it change the size of subsample indices? I wonder if it will work with JAX...
Great idea! I was concerned about how to do this in a clean way, and this seems to be a solution. There is still one small issue: this will change the attribute
The HMCGibbs interface looks great! I'll sync the aleatory fork and introduce likelihood estimators there.
I'll update this thread if I run into something; it seems more than sufficient from a quick inspection.
Yes, of course, thanks!
The Taylor expansion is computed w.r.t. reference params (found using e.g. MAP or MLE) on the entire dataset, so it is used to change the subsample_size to be equal to the
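(For reference, a minimal sketch of what a second-order Taylor expansion around reference parameters `theta_hat` computes, assuming `theta_hat` is a 1-D parameter vector and `log_lik_sum` is a user-supplied function returning the full-data log likelihood; this is not NumPyro's actual implementation.)

```python
import jax

def make_taylor_proxy(log_lik_sum, theta_hat):
    # Precompute value, gradient, and Hessian once, on the full dataset.
    l0 = log_lik_sum(theta_hat)
    g0 = jax.grad(log_lik_sum)(theta_hat)
    h0 = jax.hessian(log_lik_sum)(theta_hat)

    def proxy(theta):
        d = theta - theta_hat
        return l0 + g0 @ d + 0.5 * d @ (h0 @ d)
    return proxy
```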
Not sure what you mean; however, I think the wrapper
I meant to only modify the message, rather than doing in-place modification of some of its values.
Makes sense to me. You can freely do many things under
Sorry for the delay; I found a problem with the Taylor expansion, which I'm trying to resolve. I will make a separate PR with the block update.
@OlaRonning I'm just curious about your plan to add more stuff to HMCECS. FYI we are going to release 0.5 with the latest jax/jaxlib release (in 1 or 2 weeks, to close out current PRs). Do you want to have something else in that release?
Hi @fehiepsi.

```python
def model(data, obs):
    theta = sample('theta', Normal())
    with plate('N', data.shape[0], subsample_size=10) as idx:
        sample('obs', Bern(logit=data @ theta), obs=obs)  # the derivative should include the jnp computation
```

While I expect to solve this fairly soon, I don't want to block your release schedule, so I'll wait to open more PRs until release 0.5 is out.
Closing this issue with #905 merged. Thanks @fehiepsi and @LysSanzMoreta for all the help!
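(For readers landing here later, a minimal usage sketch of the HMCECS interface that #905 introduced, with toy data and arbitrary hyperparameters; check the released docs for the exact signature.)

```python
import jax.numpy as jnp
from jax import random
import numpyro
import numpyro.distributions as dist
from numpyro.infer import HMCECS, MCMC, NUTS

def model(data):
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0))
    with numpyro.plate("N", data.shape[0], subsample_size=100):
        batch = numpyro.subsample(data, event_dim=0)  # take the current subsample of the data
        numpyro.sample("obs", dist.Normal(theta, 1.0), obs=batch)

data = random.normal(random.PRNGKey(1), (10_000,)) + 1.0
ref_params = {"theta": jnp.array(1.0)}  # e.g. from a MAP/MLE fit
kernel = HMCECS(NUTS(model), num_blocks=10, proxy=HMCECS.taylor_proxy(ref_params))
mcmc = MCMC(kernel, num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), data)
```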
Hi,
@LysSanzMoreta and I are working on a project where we use HMCECS, and so we made a model-specific implementation in Numpyro.
This implementation assumes the model to have the form:

[model specification given as an image in the original issue]

and then implements HMC-in-Gibbs as a loop, following the algorithm:

[algorithm given as an image in the original issue]
We want to add HMCECS to Numpyro; however, we need feedback on a good design for handling the subsampling. The current version makes a strong assumption about the model's parameters (i.e., that the first and second params are obs and feats), which we would like to avoid. Pyro has plates that support subsampling with user-defined subsample schemes. Are there plans for something similar in Numpyro, and if so, what would it take to implement?