
HMCECS in numpyro #724

Closed
OlaRonning opened this issue Sep 7, 2020 · 18 comments
Labels
enhancement New feature or request

Comments

@OlaRonning
Member

OlaRonning commented Sep 7, 2020

Hi,

@LysSanzMoreta and I are working on a project where we use HMCECS, and so made a model-specific implementation in NumPyro.
This implementation assumes the model has the form:

def model(obs, feat, *args, **kwargs):
	pass

and then implements HMC-in-Gibbs as a loop, following the algorithm:

  1. update subsample;
  2. update parameters (by HMC) evaluated on the subsample.

We want to add HMCECS to NumPyro; however, we need feedback on a good design for handling the subsampling. The current version makes a strong assumption about the parameters of the model (i.e. that the first and second params are obs and feats), which we would like to avoid. Pyro has plates that support subsampling, with user-defined subsample schemes. Are there plans for something similar in NumPyro, and if so, what would it take to implement?

@fehiepsi
Member

fehiepsi commented Sep 7, 2020

Hi @OlaRonning, I think the current plate statement supports subsampling (we have a small test for it, also see this thread for a discussion). Could you be a little more explicit on "user-defined subsample schemes"? I think that we can implement some specific handlers for your purpose.
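For reference, a minimal sketch of what plate subsampling looks like today (the names and sizes are illustrative):

import numpyro
import numpyro.distributions as dist

def model(data):
    theta = numpyro.sample("theta", dist.Normal(0., 1.))
    with numpyro.plate("N", data.shape[0], subsample_size=100) as idx:
        numpyro.sample("obs", dist.Normal(theta, 1.), obs=data[idx])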

The current version has a strong assumption on the parameters of the model (i.e. first and second params are obs and feats), which we would like to avoid

Yeah, I agree. I think we can avoid it, by either using a handler or modifying your HMCECS implementation a bit. I am looking forward to seeing your PR. :)

@fehiepsi added the enhancement (New feature or request) label on Sep 7, 2020
@LysSanzMoreta
Contributor

I hope I can help. By "user-defined subsample schemes" we mean that we would like to imitate NUTS's automatic setup of the subsample size parameter (as NUTS does with the step size, for example). The subsample size would then change / be corrected according to the likelihood.

@fehiepsi
Member

fehiepsi commented Sep 7, 2020

Unfortunately, I think that a dynamic subsample size can only work with some masking mechanism: take a subsample of max_subsample_size and mask out some log probabilities when the current subsample size is less than that max value. Could you provide some pseudo code so we can see if any additional mechanism is needed? I still believe that if something works with obs, feats, then it will also work without them, as long as users provide enough information in the model or in the inference method. Btw, what does feats stand for?

@OlaRonning
Member Author

Hi @fehiepsi,
Thanks for linking the subsampling discussion!

The scheme is to block-update a subsample of indices by Gibbs sampling. In pseudo-Python it would be something like:

n, _ = feats.shape             # number of observations (feats holds the features)
m = n // 3                     # subsample size
g = 4                          # number of blocks in the subsample
u = randint(0, n).sample(m)    # initial subsample indices

for _ in range(num_samples):
    chosen_block = randint(0, g)  # update one block at a time to induce correlation between u and u_new
    u_new = u.copy()
    u_new[chosen_block] = randint(0, n).sample(m // g)  # resample the entire block (u is partitioned into g blocks)
    accept_prob = likelihood(model(obs[u_new], feats[u_new, :])) / likelihood(model(obs[u], feats[u, :]))
    if bernoulli(accept_prob):
        u = u_new

    hmc_update(model(obs[u], feats[u, :]))  # evaluate the potential and its gradient on the subsample

The likelihood is estimated by a precomputed MAP estimate of the model's params, corrected by the current params evaluated on the subsample chosen by u. As @LysSanzMoreta mentioned, we can update the subsample size on the fly by changing it so that the variance of the (log-)likelihood estimate stays close to one; however, it is not entirely clear how to change the subsample_size in a plate.
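For illustration, a sketch of the adaptation rule we have in mind (names are hypothetical; this is not our actual implementation):

def adapt_subsample_size(m, log_lik_var, n, factor=1.5):
    # grow the subsample while the variance of the log-likelihood estimate exceeds 1
    if log_lik_var > 1.0:
        return min(int(m * factor), n)
    return m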

@fehiepsi
Member

fehiepsi commented Sep 8, 2020

If you want to change subsample_size, I think you can do

def model(..., subsample_size=None):
    with numpyro.plate(..., subsample_size=subsample_size):
        ...

(later, we can write a handler to do that job because subsample_size is only specific to your HMC algorithm, rather than to the model - one solution is to expose create_plates in your HMC constructor like in Pyro autoguide or forecaster).

In addition, I think that you want something like

def model(full_data, full_feats, subsample_size=None):
    with numpyro.plate(..., subsample_size=subsample_size) as idx:
        data = full_data[idx]
        feats = full_feats[idx]
        ...

Not only that, IIUC you want to substitute idx by u_new too. That will be possible after #556. If that is what you need, please let me know. I will address that issue this weekend. :)

@OlaRonning
Member Author

Great that there is already a way to make subsample_size variable.

The second model has the structure I was thinking of. We have to be able to substitute idx to make the hmc_update work with a plate, so if you have time to work on #556, that would be much appreciated!

@OlaRonning
Member Author

Hi @fehiepsi,

Awesome, you merged #734! Would you recommend doing a substitute on the trace, as here, to update the subsample indices? Or is there a way to augment the model so that the fn for the plate is a custom function?

@fehiepsi
Member

Hi @OlaRonning, I think you can replace your model by subsampled_model = substitute(model, data={"plate_name": subsample}) and compute log_likelihood for subsampled_model. You can also augment the model to something like

def augmented_model(*args, **kwargs):
    subsample = kwargs.pop("_subsample")
    return substitute(model, data={"plate_name": subsample})(*args, **kwargs)

I think there are some requirements in the model to make this work. For example, the model code might look something like

def model(x, y):
    with numpyro.plate('p', num_data, dim=-2) as p:
        x = x[p]  # or x = numpyro.subsample(x, event_dim=1)
        y = y[p]  # or y = numpyro.subsample(y, event_dim=1)
        numpyro.sample("obs", dist.Normal(x, 1), obs=y)

Do you have any other blockers? I will try to address them if possible.

@OlaRonning
Member Author

I see, thanks for the examples. With your PR merged we have the tooling to make it work with plate.

We have a version, in an internal repo, that uses slicing of the input parameters to do the subsampling. @LysSanzMoreta is currently porting it to a NumPyro fork. We will update it to subsample using a plate, test it internally, and make a subsequent PR once it's working.

@fehiepsi
Member

That is great! I am looking forward to your PRs. 👍

@OlaRonning
Member Author

I've made two handlers to introduce the estimated-likelihood computation into the HMCECS prototype @fehiepsi wrote. The first is estimator, which intercepts likelihood computations:

class estimator(Messenger):
    def __init__(self, fn, estimators, plate_sizes):
        self.estimators = estimators
        self.plate_sizes = plate_sizes
        super(estimator, self).__init__(fn)

    def process_message(self, msg):
        if msg['type'] == 'sample' and msg['is_observed'] and msg['cond_indep_stack']:  # <--- this conditional
            log_prob = msg['fn'].log_prob
            msg['scale'] = 1.
            # replace the site's log_prob with the corresponding likelihood estimator
            msg['fn'].log_prob = lambda *args, **kwargs: self.estimators[msg['name']](
                *args,
                name=msg['name'],
                z=_extract_params(msg['fn']),
                log_prob=log_prob,
                sizes=self.plate_sizes[msg['cond_indep_stack'][0].name],
                **kwargs,
            )

and the second is subsample_size, for altering the subsample size:

class subsample_size(Messenger):
    def __init__(self, fn, plate_sizes, rng_key=None):
        super(subsample_size, self).__init__(fn)
        self.plate_sizes = plate_sizes
        self.rng_key = rng_key

    def process_message(self, msg):
        if msg['type'] == 'plate' and msg['args'] and msg['args'][0] > msg['args'][1]:
            if msg['name'] in self.plate_sizes:
                # override (size, subsample_size) and draw fresh subsample indices
                msg['args'] = self.plate_sizes[msg['name']]
                size, subsample_size = msg['args']
                msg['value'] = (_subsample_fn(size, subsample_size, self.rng_key)
                                if subsample_size < size else jnp.arange(size))
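For context, the two handlers compose like other Messengers; usage would look something like this (a sketch, with estimators and plate_sizes built elsewhere):

estimated_model = estimator(subsample_size(model, plate_sizes, rng_key), estimators, plate_sizes)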

Regarding estimator: is the conditional correct to capture only likelihood computations within a plate with subsampling?

I have a working toy example here with a difference estimator and a Taylor proxy. I'll clean it up, check it against some more extensive examples, and then make a WIP PR today or tomorrow.

@fehiepsi
Member

fehiepsi commented Jan 8, 2021

That's awesome! FYI the naive version is implemented through the HMCGibbs interface. I guess you can just add some new arguments there to let users switch between different versions and specify something like num_blocks. Could you let me know what the blockers are with the current HMCGibbs API? (Currently it works for conjugate models, models with discrete variables, and naive subsampling, but does not work for MixedHMC #826, where multiple Gibbs steps are needed in one MCMC sample step.)
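For illustration, a minimal sketch of what naive subsampling might look like through the HMCGibbs interface (n, m, and the plate name "N" are assumptions here, as is treating the plate site as a Gibbs site):

from jax import random
from numpyro.infer import HMC, HMCGibbs, MCMC

def gibbs_fn(rng_key, gibbs_sites, hmc_sites):
    u = random.permutation(rng_key, n)[:m]  # fresh subsample indices for plate "N"
    return {"N": u}

kernel = HMCGibbs(HMC(model), gibbs_fn=gibbs_fn, gibbs_sites=["N"])
mcmc = MCMC(kernel, num_warmup=500, num_samples=1000)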

is the conditional correct to capture only likelihood computations within a plate with subsampling

I think you also need to check if the plate names in msg['cond_indep_stack'] belong to plate_sites.

subsample_size for altering the subsample size

Does it change the size of the subsample indices? I wonder if that will work with JAX...

estimator that intercepts likelihood computations

Great idea! I was concerned about how to do this in a clean way, and this seems to be a solution. There is still one small issue: this changes the log_prob attribute of a distribution instance in place. How about moving this to the message level like other message handlers? Something like msg['fn'] = foo(msg['fn'], ...), where the output of foo is a class with a log_prob method modified from the old msg['fn'].
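A minimal sketch of that suggestion (foo and the wrapper class below are hypothetical, not an existing NumPyro API):

class _EstimatedDist:
    # wrap a distribution and replace log_prob with an estimator
    def __init__(self, base_dist, estimate_fn):
        self.base_dist = base_dist
        self.estimate_fn = estimate_fn

    def log_prob(self, value):
        # the estimator receives the original log_prob and the observed value
        return self.estimate_fn(self.base_dist.log_prob, value)

    def __getattr__(self, name):
        return getattr(self.base_dist, name)  # delegate everything else

# in process_message:  msg['fn'] = _EstimatedDist(msg['fn'], estimate_fn)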

@OlaRonning
Member Author

The HMCGibbs interface looks great! I'll sync the aleatory fork and introduce likelihood estimators there.

Could you let me know what the blockers are with the current HMCGibbs API?

I'll update this thread if I run into something; it seems more than sufficient from a quick inspection.

is the conditional correct to capture only likelihood computations within a plate with subsampling

I think you also need to check if plate names in msg['cond_indep_stack'] belong to plate_sites.

Yes, of course, thanks!

subsample_size for altering the subsample size

Does it change the size of subsample indices? I wonder if it will work with JAX...

The Taylor expansion is computed w.r.t. the reference params (found using e.g. MAP or MLE) on the entire dataset, so the handler changes the subsample_size to equal the full size of the subsampling plate. This is only done once during init, which is not JIT-compiled. I've not checked the handler under JIT.

How about moving this at message level likes other message handlers?

I'm not sure what you mean; however, I think the wrapper msg['fn'] = foo(msg['fn'], ...) you suggest is a better solution!

@fehiepsi
Member

fehiepsi commented Jan 9, 2021

moving this at message level likes other message handlers?

I meant to modify only the message, rather than modifying some of its values in place.

This is only done once during init

Makes sense to me. You can freely do many things under init as long as no global state is changed (there is _PYRO_STACK, which is global, and we use primitives to change it, but we always free it before exiting a jitted function; that is the reason jit(fn(handlers)) works but handlers(jit(fn)) does not in many cases). Your code and handlers look safe to me. I guess the main reason for having subsample_size is this assertion? If so, we can relax it to a warning, so that we can substitute subsample indices of any size.
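For example, the safe ordering looks like this (a sketch; the model is assumed to take no arguments):

import jax
from numpyro import handlers

def run(rng_key):
    # the handlers are entered and exited inside the jitted function,
    # so _PYRO_STACK is restored before jit returns
    return handlers.trace(handlers.seed(model, rng_key)).get_trace()

jitted_run = jax.jit(run)  # jit(fn(handlers)) works; handlers(jit(fn)) may not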

@OlaRonning
Member Author

OlaRonning commented Jan 12, 2021

Sorry for the delay; I found a problem with the Taylor expansion, which I'm trying to resolve. I will make a separate PR with the block update.

@fehiepsi
Member

@OlaRonning I'm just curious about your plan to add more stuff to HMCECS. FYI we are going to release 0.5 with the latest jax/jaxlib release (in 1 or 2 weeks, to close the current PRs). Do you want to include something else in that release?

@OlaRonning
Member Author

Hi @fehiepsi.
I'm still working out how to include the arithmetic when computing the Jacobian and Hessian of the likelihood for the Taylor expansion when using an estimator handler. Logistic regression is a simple example:

def model(data, obs):
  theta = sample('theta', Normal())
  with plate('N', data.shape[0], subsample_size=10) as idx:
    sample('obs', Bern(logits=data[idx] @ theta), obs=obs[idx])  # the derivative should include the jnp computation (data @ theta)
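For concreteness, a hedged sketch of the per-example derivatives the Taylor proxy needs for this model, computed with JAX (theta_ref, x_i, and y_i are assumptions: a reference point, e.g. a MAP estimate, and one data point):

import jax
import jax.numpy as jnp

def log_lik(theta, x_i, y_i):
    logits = x_i @ theta
    return y_i * logits - jnp.logaddexp(0., logits)  # Bernoulli(logits) log-prob

grad_ref = jax.grad(log_lik)(theta_ref, x_i, y_i)     # Jacobian term at theta_ref
hess_ref = jax.hessian(log_lik)(theta_ref, x_i, y_i)  # Hessian term at theta_ref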

I expect to solve this fairly soon, but I don't want to block your release schedule, so I'll wait to open more PRs until after release 0.5 is out.

@fehiepsi changed the title from "HMCESC in Numpyro" to "HMCECS in numpyro" on Jan 16, 2021
@OlaRonning
Member Author

OlaRonning commented Feb 11, 2021

Closing this issue with #905 merged. Thanks, @fehiepsi and @LysSanzMoreta for all the help!
