[Feature Request] Minibatch training when mll is an _ApproximateMarginalLogLikelihood #1438
Comments
This is a great suggestion, and IIRC @j-wilson has played around with this a bit before. Not sure what state that is in and whether it makes sense for him to push out a draft of this, or whether it's better to just start fresh with a PR on your end (seems reasonably straightforward all in all). @j-wilson any thoughts here?
Yeah that makes sense to me. I think if the user is using variational approximate GP models we can assume that they'd be able to manually specify the full batch training optimizer if needed. Another option would be to parse this somehow from the kwargs, but I don't think we need to worry about this for now.
Here's a rough draft of what this might look like: main...jacobrgardner:botorch:stochastic_fitting. The high level fitting works great (works on a piece of code I've been testing as well as our homebrew model fitting). Still a few TODOs even before code review:
@j-wilson @Balandat just let me know if you all don't have something further along already than this and I can open it as a PR to track the TODOs there. (Edit: Oops, some automated thing must have run black on the files before committing, sorry about the irrelevant parts of the linked diff)
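For reference, here is a minimal sketch of what minibatch training of an approximate GP looks like using plain GPyTorch and PyTorch pieces. This is not the code in the linked branch; the function name, signature, and defaults below are illustrative assumptions only.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
from gpytorch.mlls import VariationalELBO


def fit_approximate_gp_minibatch(model, likelihood, train_X, train_Y,
                                 batch_size=256, epochs=100, lr=1e-2):
    """Minibatch-train an ApproximateGP `model` on (train_X, train_Y). Illustrative only."""
    mll = VariationalELBO(likelihood, model, num_data=train_Y.shape[0])
    loader = DataLoader(TensorDataset(train_X, train_Y),
                        batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(mll.parameters(), lr=lr)

    model.train()
    likelihood.train()
    for _ in range(epochs):  # in this setting, `maxiter` would count epochs
        for x_batch, y_batch in loader:
            optimizer.zero_grad()
            loss = -mll(model(x_batch), y_batch)  # negative ELBO on the minibatch
            loss.backward()
            optimizer.step()
    model.eval()
    likelihood.eval()
    return mll
```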
IIRC, @mshvartsman et al. use `ApproximateGP` with full-batch training. cc'ing in case they have any input on this.
@jacobrgardner Hi Jake. Fully on board with you here. As Max mentioned, I've put together a draft for this as well. At a glance, it looks pretty similar to your implementation. The main difference seems to be that I actually just rewrote Aside from that, I have Would something like this work for your use cases? Regarding
If we end up with cases where
@j-wilson Ah, yeah looks like we can just use `TensorDataset` there. In terms of the rest, how do you envision the user specifying to use minibatch training? Would the idea be to do something like I guess I'm personally fine with essentially any of the proposed interfaces here.
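For concreteness, a minimal sketch of the `TensorDataset` route, assuming `mll.model` follows the usual GPyTorch convention of exposing `train_inputs` as a tuple of tensors and `train_targets` as a tensor; the batch size here is an arbitrary placeholder.

```python
from torch.utils.data import DataLoader, TensorDataset

# Assumes `mll` is an _ApproximateMarginalLogLikelihood whose model holds the training data.
model = mll.model
dataset = TensorDataset(*model.train_inputs, model.train_targets)
loader = DataLoader(dataset, batch_size=256, shuffle=True)

for *x_batch, y_batch in loader:  # x_batch: one entry per training-input tensor
    loss = -mll(model(*x_batch), y_batch)
```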
@jacobrgardner Good questions. I hadn't actually considered a solution like My thought had been to make a A typical call pattern might then look something like:
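As a purely hypothetical illustration of such a call pattern (the optimizer name and keyword arguments are placeholders, assuming `fit_gpytorch_mll` accepts an optimizer override):

```python
from botorch.fit import fit_gpytorch_mll

# `mll` is assumed to be the _ApproximateMarginalLogLikelihood to fit, and
# `fit_gpytorch_torch_stochastic` is the minibatch optimizer proposed in this
# issue, not an existing function. Keyword arguments are illustrative.
fit_gpytorch_mll(
    mll,
    optimizer=fit_gpytorch_torch_stochastic,
    optimizer_kwargs={"batch_size": 256, "maxiter": 100},  # maxiter counted in epochs
)
```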
@j-wilson Okay, so if you all think a rewrite of Then, both
@jacobrgardner Up for discussion. A naive implementation would probably see
Not sure if that's a safe assumption :). As @saitcakmak said, in AEPsych we pretty much exclusively do full-batch small-data So my vote would be to either [a] retain the full batch with SAA default and warn if the data is too large, or [b] have a sensible user-adjustable cutoff to switch between the fitting strategies (similarly to how GPyTorch switches between Cholesky and CG for solves and logdets). I think I'd prefer [b] over [a]; we'd just need to tune the cutoff.
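A minimal sketch of option [b], where `FULL_BATCH_CUTOFF`, `full_batch_fit`, and `minibatch_fit` are hypothetical names standing in for the two fitting strategies:

```python
FULL_BATCH_CUTOFF = 4_096  # hypothetical default; would need tuning plus a user override


def fit_approximate_mll(mll, full_batch_fit, minibatch_fit, cutoff=FULL_BATCH_CUTOFF, **kwargs):
    # Switch on the number of training points, analogous to how GPyTorch picks
    # between Cholesky and CG based on problem size. The two callables stand in
    # for the full-batch (e.g. L-BFGS) and stochastic (e.g. Adam + DataLoader) paths.
    num_data = mll.model.train_targets.shape[-1]
    if num_data <= cutoff:
        return full_batch_fit(mll, **kwargs)
    return minibatch_fit(mll, **kwargs)
```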
Hi folks. I've put together a PR (#1439) that implements the above. This ended up being a larger change than I had originally anticipated, but hopefully people will agree that these are the "right" changes (or at least trending in that direction). The best course of action in terms of balancing the specific functionality requested here with the overall design seemed to be to introduce a loss closure abstraction. This allows us to abstract away things like DataLoaders, while also enabling the user to specify custom routines for evaluating their loss functions. I haven't tested this yet, but I'm hopeful that we'll be able to use e.g.
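As a rough illustration of the loss-closure idea (not the actual API in the PR), the fitting routine only ever calls a zero-argument closure, so the `DataLoader` stays hidden behind it:

```python
def minibatch_loss_closure(mll, loader):
    """Return a zero-argument closure yielding the current minibatch loss. Illustrative only."""
    batches = iter(loader)

    def closure():
        nonlocal batches
        try:
            *x, y = next(batches)
        except StopIteration:  # wrap around at the end of an epoch
            batches = iter(loader)
            *x, y = next(batches)
        return -mll(mll.model(*x), y)  # negative ELBO for approximate MLLs

    return closure
```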
Sam has been having some good success using torchdynamo/torchinductor, would be interesting to see what this does here. |
🚀 Feature Request
Now that `fit_gpytorch_mll` exists using multiple dispatch, it seems like it'd be fairly straightforward to support minibatch training by registering a `fit_gpytorch_torch_stochastic` or similar as the optimizer for `_ApproximateMarginalLogLikelihood` mlls.

Motivation
Is your feature request related to a problem? Please describe.
As far as I can tell browsing the code, running `fit_gpytorch_mll` on an `ApproximateGPyTorchModel` would just use full batch training. As a result, we have (e.g., for latent space optimization tasks) typically been brewing our own GPyTorch models + training code still, despite the existence of `ApproximateGPyTorchModel`. We're planning on submitting a PR with a latent space bayesopt tutorial, but I'd like it to be more BoTorch-y than it currently is -- right now the actual model handling is entirely outside of BoTorch for this reason.

Pitch
Describe the solution you'd like
- Add a `fit_gpytorch_torch_stochastic` in `botorch.optim.fit` that does minibatch training with a user-specified batch size. For now, I was thinking this can just make a standard `DataLoader` over the train targets and inputs -- handling the case where `train_targets` is actually a tuple might require more thought if we wanted to support that out of the gate. `maxiter` in the stopping criterion would refer to a number of epochs of training.
- Register `fit_gpytorch_torch_stochastic` as the default optimizer via a `_fit_approximategp_stochastic` in `botorch.fit` to the dispatcher for `(_ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel)` (see the registration sketch at the end of this issue).
- Full batch training would still be possible by specifying `fit_gpytorch_torch` manually as the optimizer or (equivalently, with negligible overhead) specifying the batch size to be the full N. One solution might be to just call the fallback fit if a minibatch size / optimizer isn't specified by the user? On the other hand, in the long run, it probably makes sense to assume by default that the user wants to do some kind of stochastic optimization if they're going to the trouble of using a variational approximate GP specifically rather than just e.g. an inducing point kernel on an `ExactGP`.

Are you willing to open a pull request? (See CONTRIBUTING)
Yeah
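A rough sketch of the dispatch registration described in the pitch. The `dispatcher` import and the signature expected of registered fitters are assumptions about `botorch.fit` internals, and `fit_gpytorch_torch_stochastic` is the routine proposed above rather than an existing function:

```python
from botorch.fit import dispatcher  # assumption: name/location of the multiple-dispatch registry
from botorch.models.approximate_gp import ApproximateGPyTorchModel
from gpytorch.likelihoods import Likelihood
from gpytorch.mlls import _ApproximateMarginalLogLikelihood


@dispatcher.register(_ApproximateMarginalLogLikelihood, Likelihood, ApproximateGPyTorchModel)
def _fit_approximategp_stochastic(mll, _, __, *, batch_size=256, **kwargs):
    # Delegate to the proposed minibatch optimizer; the batch_size default is illustrative.
    return fit_gpytorch_torch_stochastic(mll, batch_size=batch_size, **kwargs)
```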