Summary:
X-link: facebook/Ax#1191
Pull Request resolved: #1439

This diff acts as a follow-up to the recent model fitting refactor. The previous update focused on the high-level logic used to determine which fitting routines to use for which MLLs. This diff refactors the internal machinery used to evaluate forward-backward passes (producing losses and gradients, respectively) during optimization. The solution we have opted for is to abstract away the evaluation process by relying on closures. In most cases, these closures are automatically constructed by composing simpler, multiply-dispatched base functions.

Reviewed By: Balandat

Differential Revision: D39101211

fbshipit-source-id: c2058a387fd74058073cfe73c9404d2df2f9b55a
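As a rough illustration of the new workflow (a sketch, not part of the commit): the closure returned by `get_loss_closure_with_grads` runs a fused forward-backward pass for a given MLL. `SingleTaskGP` and `ExactMarginalLogLikelihood` are standard BoTorch/GPyTorch classes, and the exact call signature below is assumed from this diff's exports rather than stated in it.

import torch
from botorch.models import SingleTaskGP
from botorch.optim.closures import get_loss_closure_with_grads
from gpytorch.mlls import ExactMarginalLogLikelihood

train_X = torch.rand(10, 2, dtype=torch.float64)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Dispatch picks a loss closure appropriate to this MLL type; a single call
# evaluates the loss and populates the parameters' `grad` fields.
params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
closure = get_loss_closure_with_grads(mll, parameters=params)
loss, grads = closure()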
1 parent 7613cd2 · commit 6239291
Showing 35 changed files with 3,927 additions and 1,648 deletions.
@@ -0,0 +1,22 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from botorch.optim.closures.core import (
    ForwardBackwardClosure,
    NdarrayOptimizationClosure,
)
from botorch.optim.closures.model_closures import (
    get_loss_closure,
    get_loss_closure_with_grads,
)


__all__ = [
    "ForwardBackwardClosure",
    "get_loss_closure",
    "get_loss_closure_with_grads",
    "NdarrayOptimizationClosure",
]
@@ -0,0 +1,184 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""Core methods for building closures in torch and interfacing with numpy."""

from __future__ import annotations

from functools import partial
from typing import Any, Callable, Dict, Optional, Sequence, Tuple

import torch
from botorch.optim.utils import (
    _handle_numerical_errors,
    get_tensors_as_ndarray_1d,
    set_tensors_from_ndarray_1d,
)
from botorch.optim.utils.numpy_utils import as_ndarray
from botorch.utils.context_managers import zero_grad_ctx
from numpy import float64 as np_float64, full as np_full, ndarray, zeros as np_zeros
from torch import Tensor
class ForwardBackwardClosure:
    r"""Wrapper for fused forward and backward closures."""

    def __init__(
        self,
        forward: Callable[[], Tensor],
        parameters: Dict[str, Tensor],
        backward: Callable[[Tensor], None] = Tensor.backward,
        reducer: Optional[Callable[[Tensor], Tensor]] = torch.sum,
        callback: Optional[Callable[[Tensor, Sequence[Optional[Tensor]]], None]] = None,
        context_manager: Callable = None,  # pyre-ignore [9]
    ) -> None:
        r"""Initializes a ForwardBackwardClosure instance.

        Args:
            forward: Callable that returns a tensor.
            parameters: A dictionary of tensors whose `grad` fields are to be returned.
            backward: Callable that takes the (reduced) output of `forward` and sets the
                `grad` attributes of tensors in `parameters`.
            reducer: Optional callable used to reduce the output of the forward pass.
            callback: Optional callable that takes the reduced output of `forward` and
                the gradients of `parameters` as positional arguments.
            context_manager: A ContextManager used to wrap each forward-backward call.
                When passed as `None`, `context_manager` defaults to a `zero_grad_ctx`
                that zeroes the gradients of `parameters` upon entry.
        """
        if context_manager is None:
            context_manager = partial(zero_grad_ctx, parameters)

        self.forward = forward
        self.backward = backward
        self.parameters = parameters
        self.reducer = reducer
        self.callback = callback
        self.context_manager = context_manager

    def __call__(self, **kwargs: Any) -> Tuple[Tensor, Tuple[Optional[Tensor], ...]]:
        with self.context_manager():
            values = self.forward(**kwargs)
            value = values if self.reducer is None else self.reducer(values)
            self.backward(value)

            grads = tuple(param.grad for param in self.parameters.values())
            if self.callback:
                self.callback(value, grads)

            return value, grads
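
# Illustrative usage (a sketch, not from the diff): a ForwardBackwardClosure
# pairs a forward pass with `Tensor.backward` so that one call yields both the
# loss and the parameter gradients. Given an `mll` whose model is in training
# mode:
#
#     params = {n: p for n, p in mll.named_parameters() if p.requires_grad}
#     closure = ForwardBackwardClosure(
#         forward=lambda: -mll(mll.model(*mll.model.train_inputs),
#                              mll.model.train_targets),
#         parameters=params,
#     )
#     loss, grads = closure()  # grads are ordered as params.values()
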
class NdarrayOptimizationClosure:
    r"""Adds stateful behavior and a numpy.ndarray-typed API to a closure with an
    expected return type Tuple[Tensor, Union[Tensor, Sequence[Optional[Tensor]]]]."""

    def __init__(
        self,
        closure: Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]],
        parameters: Dict[str, Tensor],
        as_array: Callable[[Tensor], ndarray] = None,  # pyre-ignore [9]
        as_tensor: Callable[[ndarray], Tensor] = torch.as_tensor,
        get_state: Callable[[], ndarray] = None,  # pyre-ignore [9]
        set_state: Callable[[ndarray], None] = None,  # pyre-ignore [9]
        fill_value: float = 0.0,
        persistent: bool = True,
    ) -> None:
        r"""Initializes a NdarrayOptimizationClosure instance.

        Args:
            closure: A ForwardBackwardClosure instance.
            parameters: A dictionary of tensors representing the closure's state.
                Expected to correspond with the first `len(parameters)` optional
                gradient tensors returned by `closure`.
            as_array: Callable used to convert tensors to ndarrays.
            as_tensor: Callable used to convert ndarrays to tensors.
            get_state: Callable that returns the closure's state as an ndarray. When
                passed as `None`, defaults to calling `get_tensors_as_ndarray_1d`
                on `closure.parameters` while passing `as_array` (if given by the user).
            set_state: Callable that takes a 1-dimensional ndarray and sets the
                closure's state. When passed as `None`, `set_state` defaults to
                calling `set_tensors_from_ndarray_1d` with `closure.parameters` and
                a given ndarray while passing `as_tensor`.
            fill_value: Fill value for parameters whose gradients are None. In most
                cases, `fill_value` should either be zero or NaN.
            persistent: Boolean specifying whether an ndarray should be retained
                as a persistent buffer for gradients.
        """
        if get_state is None:
            # Note: Numpy supports copying data between ndarrays with different dtypes.
            # Hence, our default behavior need not coerce the ndarray representations of
            # tensors in `parameters` to float64 when copying over data.
            _as_array = as_ndarray if as_array is None else as_array
            get_state = partial(
                get_tensors_as_ndarray_1d, parameters, as_array=_as_array
            )

        if as_array is None:  # per the note, do this after resolving `get_state`
            as_array = partial(as_ndarray, dtype=np_float64)

        if set_state is None:
            set_state = partial(
                set_tensors_from_ndarray_1d, parameters, as_tensor=as_tensor
            )

        self.closure = closure
        self.parameters = parameters

        self.as_array = as_array
        self.as_tensor = as_tensor
        self._get_state = get_state
        self._set_state = set_state

        self.fill_value = fill_value
        self.persistent = persistent
        self._gradient_ndarray: Optional[ndarray] = None

    def __call__(
        self, state: Optional[ndarray] = None, **kwargs: Any
    ) -> Tuple[ndarray, ndarray]:
        if state is not None:
            self.state = state

        try:
            value_tensor, grad_tensors = self.closure(**kwargs)
            value = self.as_array(value_tensor)
            grads = self._get_gradient_ndarray(fill_value=self.fill_value)
            index = 0
            for param, grad in zip(self.parameters.values(), grad_tensors):
                size = param.numel()
                if grad is not None:
                    grads[index : index + size] = self.as_array(grad.view(-1))
                index += size
        except RuntimeError as e:
            value, grads = _handle_numerical_errors(error=e, x=self.state)

        return value, grads

    @property
    def state(self) -> ndarray:
        return self._get_state()

    @state.setter
    def state(self, state: ndarray) -> None:
        self._set_state(state)

    def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
        if self.persistent and self._gradient_ndarray is not None:
            if fill_value is not None:
                self._gradient_ndarray.fill(fill_value)
            return self._gradient_ndarray

        size = sum(param.numel() for param in self.parameters.values())
        array = (
            np_zeros(size)
            if fill_value is None or fill_value == 0.0
            else np_full(size, fill_value)
        )
        if self.persistent:
            self._gradient_ndarray = array

        return array
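
Because NdarrayOptimizationClosure returns `(value, grads)` as ndarrays and exposes the flattened parameters via its `state` property, it plugs directly into ndarray-based optimizers. Below is a minimal sketch with SciPy's L-BFGS-B, reusing the `mll` and `params` from the earlier sketches; the wiring is illustrative, not the library's own fitting routine.

from scipy.optimize import minimize

from botorch.optim.closures import (
    ForwardBackwardClosure,
    NdarrayOptimizationClosure,
)

mll.train()  # exact GP marginal log likelihood requires train mode

# Fused forward-backward closure over the negative MLL (the quantity to minimize).
fb_closure = ForwardBackwardClosure(
    forward=lambda: -mll(mll.model(*mll.model.train_inputs), mll.model.train_targets),
    parameters=params,
)

# numpy adapter: each call returns (loss, grad) as ndarrays, and `state`
# gets/sets the flattened parameter vector.
np_closure = NdarrayOptimizationClosure(fb_closure, params)
result = minimize(np_closure, np_closure.state, jac=True, method="L-BFGS-B")
np_closure.state = result.x  # write the solution back to the parameter tensors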