Botorch closures #1439
Conversation
This pull request was exported from Phabricator. Differential Revision: D39101211
botorch/fit.py (outdated diff)

    data_loader: Convenience keyword for passing in a DataLoader instance or dict of
        keyword arguments passed to `get_data_loader` to obtain one. May only be
It feels a bit weird that `data_loader` can either be an instance or a set of kwargs passed to a factory function.
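The dual-acceptance pattern being critiqued can be sketched as follows. This is a minimal illustration, not the shipped code: `resolve_data_loader` is a hypothetical helper name, and `get_data_loader` here is a simplified stand-in for the factory the docstring refers to.

```python
# Hypothetical sketch: `data_loader` may be a ready-made DataLoader instance
# or a dict of kwargs forwarded to a factory function.
import torch
from torch.utils.data import DataLoader, TensorDataset


def get_data_loader(dataset, **kwargs):
    # simplified stand-in for the factory mentioned in the docstring
    return DataLoader(dataset, **kwargs)


def resolve_data_loader(data_loader, dataset):
    # dict -> treated as factory kwargs; anything else -> assumed instance
    if isinstance(data_loader, dict):
        return get_data_loader(dataset, **data_loader)
    return data_loader


dataset = TensorDataset(torch.rand(8, 2), torch.rand(8, 1))
loader_a = resolve_data_loader({"batch_size": 4}, dataset)    # kwargs form
loader_b = resolve_data_loader(DataLoader(dataset), dataset)  # instance form
```

Both call sites produce a usable `DataLoader`, which is what makes the union type convenient but, as noted, a bit awkward to document and type.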
botorch/optim/closures.py (outdated diff)

    dispatcher = Dispatcher("get_loss_closure")
    NoneType = type(None)
    TLossClosure = Callable[[], Tensor]
Is this the right type if you're forwarding kwargs to the MLL?
Not really, no. If we end up allowing kwargs to be passed to loss closures (debatable, since `torch.jit.script` does not support this), we'll make `TLossClosure` a `typing.Protocol`.
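The `typing.Protocol` variant floated here could look roughly like the sketch below. This is a hypothetical redefinition for illustration, not the type that shipped (the shipped alias is `Callable[[], Tensor]`).

```python
# Hedged sketch: a Protocol-based TLossClosure that structurally permits
# keyword arguments, unlike the plain Callable[[], Tensor] alias.
from typing import Protocol, runtime_checkable

import torch
from torch import Tensor


@runtime_checkable
class TLossClosure(Protocol):
    def __call__(self, **kwargs) -> Tensor: ...


def my_closure(**kwargs) -> Tensor:
    # toy closure; a real one would evaluate an MLL
    return torch.tensor(0.0)
```

Note that `runtime_checkable` only verifies the presence of `__call__`, not its signature, so the Protocol mainly helps static checkers.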
botorch/optim/fit.py (outdated diff)

    if isinstance(torch_optimizer, (type, partial)):
        torch_optimizer = torch_optimizer(params=list(param_dict.values()))
I am not the biggest fan of the pattern of allowing this to be either an instance or a factory function. Not sure if there is an elegant alternative though.
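For concreteness, the instance-vs-factory normalization shown in the diff behaves like this sketch (`resolve_optimizer` is a hypothetical wrapper name added here for illustration; the real code does the check inline).

```python
# Hedged sketch of the pattern under discussion: a class or functools.partial
# is treated as a factory and called with the parameters; anything else is
# assumed to already be an Optimizer instance.
from functools import partial

import torch

param_dict = {"w": torch.nn.Parameter(torch.zeros(2))}


def resolve_optimizer(torch_optimizer):
    if isinstance(torch_optimizer, (type, partial)):
        return torch_optimizer(params=list(param_dict.values()))
    return torch_optimizer


opt_from_factory = resolve_optimizer(partial(torch.optim.Adam, lr=0.1))
opt_from_instance = resolve_optimizer(
    torch.optim.SGD(list(param_dict.values()), lr=0.1)
)
```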
Would factory-only be better or worse? Note that instances can still be passed as `lambda **_: instance`, but the typing is now more explicit.
My $0.02: being able to pass an instance is nicer than the factory, if only because then I don't necessarily need all the parameters I want to optimize to live somewhere on `mll`. There's even a reasonable argument to be made for accepting lists of optimizers and stepping all of them, to support e.g. NGD or different optimizers / learning rates for the GP vs. a DNN.
> being able to pass an instance is nicer than the factory, if only because then I don't necessarily need all the parameters I want to optimize to live somewhere on `mll`.

So, here's what we can do. We can split up `fit_gpytorch_torch` into a pair of methods: i) a generic torch-based gradient descent method and ii) a wrapper that provides mll-based model fitting conveniences. This seems to address your first point.

> There's even a reasonable argument to be made for accepting lists of optimizers and stepping all of them, to support e.g. NGD or different optimizers / learning rates for the GP vs. a DNN.

Interesting idea. I like that this eliminates the need for `step_limit`. If you wanted to go further, you could pass in a list of `OptimizationStep` objects with slots for `closure`, `optimizer`, etc. It's unclear to me whether we'd run into any scheduling issues here.
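The `OptimizationStep` idea sketched here might look like the following. Everything in this block is hypothetical (the class, `run`, and the toy closure); it only illustrates stepping a list of (closure, optimizer) pairs, e.g. one optimizer for GP hyperparameters and another for DNN weights.

```python
# Hedged sketch of the proposed OptimizationStep container: each step bundles
# a loss closure with the optimizer that should consume its gradients.
from dataclasses import dataclass
from typing import Callable, List

import torch
from torch import Tensor


@dataclass
class OptimizationStep:
    closure: Callable[[], Tensor]
    optimizer: torch.optim.Optimizer


def run(steps: List[OptimizationStep], num_iters: int) -> Tensor:
    for _ in range(num_iters):
        for step in steps:          # step every optimizer each iteration
            step.optimizer.zero_grad()
            loss = step.closure()
            loss.backward()
            step.optimizer.step()
    return loss


w = torch.nn.Parameter(torch.tensor([1.0]))
step = OptimizationStep(
    closure=lambda: (w ** 2).sum(),
    optimizer=torch.optim.SGD([w], lr=0.1),
)
final_loss = run([step], num_iters=10)  # loss shrinks toward 0
```

The open question raised above (scheduling) would show up in `run`: interleaving steps per iteration, as here, is one policy; running each step's optimizer to convergence in turn is another.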
botorch/optim/fit.py
Outdated
torch_scheduler.step() | ||
|
||
loss = None if loss is None else loss.detach().cpu().item() | ||
return mll, {"fopt": loss, "wall_time": monotonic() - start_time} |
Should we include `iterations` here for backwards compatibility? I guess this is sufficiently deep in the stack that we don't need to worry too much about this.
Callbacks Are All You Need. Wilson et al., 2022.
Jokes aside: Is it just a backward compatibility thing or are you thinking that some conveniences would be useful here?
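The callback angle could recover `iterations` (and other bookkeeping) without baking it into the returned info dict. A minimal sketch, with an entirely hypothetical loop and callback signature:

```python
# Hedged sketch: track iteration count via a user-supplied callback rather
# than a hard-coded field in the optimizer's return value.
def fit_loop(closure_step, num_iters, callback=None):
    loss = None
    for i in range(num_iters):
        loss = closure_step()
        if callback is not None:
            callback(iteration=i, loss=loss)
    return loss


counts = {"iterations": 0}


def count_iterations(iteration, loss):
    # user-defined bookkeeping; could also log losses, wall time, etc.
    counts["iterations"] = iteration + 1


fit_loop(lambda: 0.0, num_iters=5, callback=count_iterations)
```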
botorch/optim/numpy_converter.py (outdated diff)

    torch_closure: Optional[TLossClosure] = None,
    parameter_setter: Callable[[np.ndarray], Any] = set_params_with_array,
) -> Callable[[], Tuple[np.ndarray, np.ndarray]]:
    if torch_closure is None:
add a docstring here?
This looks like a pretty reasonable solution to me! I actually like the ability to roll my own loss closures and optimizers here, and think it's worth the extra engineering. @nataliemaus has code that sometimes wants to train a GP end to end with a VAE included in the ELBO, and sometimes only wants to train the GP, so this should let us switch even more code into using BoTorch routines.
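Rolling your own closure here amounts to bridging a torch loss closure to a scipy-style callable returning `(loss, flat_gradient)` as float64 numpy arrays, in the spirit of the `numpy_converter` snippet above. This sketch is illustrative only: `as_numpy_closure` is a hypothetical name, not the function in the diff.

```python
# Hedged sketch: wrap a torch loss closure so it returns numpy
# (loss, flattened_gradient) pairs, as scipy.optimize expects.
import numpy as np
import torch


def as_numpy_closure(torch_closure, parameters):
    def closure():
        for p in parameters:
            p.grad = None  # clear stale gradients before backward
        loss = torch_closure()
        loss.backward()
        grad = torch.cat([p.grad.reshape(-1) for p in parameters])
        return (
            loss.detach().cpu().numpy().astype(np.float64),
            grad.detach().cpu().numpy().astype(np.float64),
        )
    return closure


w = torch.nn.Parameter(torch.tensor([3.0]))
np_closure = as_numpy_closure(lambda: (w ** 2).sum(), [w])
loss_np, grad_np = np_closure()  # loss = 9.0, grad = [6.0]
```

Because the wrapper only touches the supplied `parameters`, a custom closure can optimize the GP alone or the GP plus a VAE by changing the parameter list, which is the use case described above.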
Codecov Report

    @@           Coverage Diff            @@
    ##             main      #1439   +/-  ##
    =========================================
      Coverage  100.00%    100.00%
    =========================================
      Files         134        143     +9
      Lines       12402      12755   +353
    =========================================
    + Hits        12402      12755   +353
Summary:
X-link: facebook/Ax#1191
Pull Request resolved: #1439

This diff acts as a follow-up to the recent model-fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization.

The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions.

Reviewed By: Balandat
Differential Revision: D39101211
fbshipit-source-id: f4ec341228f9f16a327307ff398c3fb8839a3de2
Summary: ## Motivation Removes everything deprecated in pytorch#1439, in version 0.8.0, except for `_get_extra_mll_args`. Removal was straightforward since the functionality was not used and relevant unit tests were clearly labeled and self-contained. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: pytorch#1995 Test Plan: Existing units; make sure codecov has not regressed from deleting tests. ## Related PRs pytorch#1439 Reviewed By: Balandat Differential Revision: D48738275 Pulled By: esantorella fbshipit-source-id: 38f39d185c0cc843f4be7ccd420c0430d3fa1fcd
…h_torch from BoTorch (pytorch#1995) Summary: ## Motivation Removes most of what was deprecated in pytorch#1439, in version 0.8.0, except for `_get_extra_mll_args` and `fit_gpytorch_model`. Removal was straightforward since the functionality was not used and relevant unit tests were clearly labeled and self-contained. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: pytorch#1995 Test Plan: Existing units; make sure codecov has not regressed from deleting tests. ## Related PRs pytorch#1439 Differential Revision: https://internalfb.com/D48738275 fbshipit-source-id: 699d54a6382bd18996624d45a0ba1a564d2e0390
…rch (pytorch#1995) Summary: ## Motivation Removes most of what was deprecated in pytorch#1439, in version 0.8.0, except for `_get_extra_mll_args` and `fit_gpytorch_model`. Removal was straightforward since the functionality was not used and relevant unit tests were clearly labeled and self-contained. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: pytorch#1995 Pull Request resolved: pytorch#2041 Test Plan: Existing units; make sure codecov has not regressed from deleting tests. ## Related PRs pytorch#1439 Differential Revision: https://internalfb.com/D50276821 fbshipit-source-id: c11ff61ad694e3e4db04ebf338c0745f22a61f79
Summary:
Changelog:
- `fit_gpytorch_torch` rewrite
- `fit_gpytorch_mll` dispatch for ApproximateGPs

Differential Revision: D39101211