Add types optim module #1942
Diff of numpyro/optim.py (changes from all commits):
```diff
@@ -9,7 +9,7 @@
 from collections import namedtuple
 from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import Any
 
 import jax
 from jax import jacfwd, lax, value_and_grad
```
```diff
@@ -18,6 +18,7 @@
 import jax.numpy as jnp
 from jax.scipy.optimize import minimize
 from jax.tree_util import register_pytree_node
+from jax.typing import ArrayLike
 
 __all__ = [
     "Adam",
```
```diff
@@ -31,12 +32,12 @@
     "SM3",
 ]
 
-_Params = TypeVar("_Params")
-_OptState = TypeVar("_OptState")
-_IterOptState = tuple[int, _OptState]
+_Params = Any
```
Review thread on `_Params = Any`:

**Reviewer:** Why use `Any` for `_Params` but not for `_OptState`?

**Author:** If I leave `_Params` as a `TypeVar`, mypy reports

```
numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound
```

That line is

```python
def eval_and_update(
    self,
    fn: Callable[[Any], tuple],
    state: _IterOptState,
    forward_mode_differentiation: bool = False,
) -> tuple[tuple[Any, Any], _IterOptState]:
    ...
    params: _Params = self.get_params(state)  # <- HERE!
    (out, aux), grads = _value_and_grad(
        fn, x=params, forward_mode_differentiation=forward_mode_differentiation
    )
```

The problem is that the type variable is unbound there. This is one of the references that helped me clarify the issue: https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified

**Reviewer:** How about just using `params = self.get_params(state)`?

**Author:** Unfortunately this does not work (it was my first try) :(

```
numpyro/optim.py:103: error: Need type annotation for "params"  [var-annotated]
numpyro/optim.py:125: error: Type variable "numpyro.optim._Params" is unbound  [valid-type]
```

The reason is the logic above: we cannot use the `TypeVar` internally :/

**Author:** I have made another attempt, which I think serves our purpose, in de835c6. I had to skip one type check in

```python
def get_params(opt_state: _MinimizeState) -> _Params:  # type: ignore[type-var]  # <- HERE
    flat_params, unravel_fn = opt_state
    return unravel_fn(flat_params)
```

Better?

**Reviewer:** I guess we can use …

**Author:** Sure! Changed in e24a7a2 :)
```diff
+_OptState = Any
+_IterOptState = tuple[ArrayLike, _OptState]
 
 
-def _value_and_grad(f, x, forward_mode_differentiation=False):
+def _value_and_grad(f, x, forward_mode_differentiation=False) -> tuple:
     if forward_mode_differentiation:
 
         def _wrapper(x):
```
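As a side note to the thread above, here is a minimal sketch of the mypy behaviour being discussed (the function and alias names are illustrative, not from the PR): an unbound module-level `TypeVar` cannot annotate a local variable, while an `Any` alias can.

```python
from typing import Any, TypeVar

_ParamsVar = TypeVar("_ParamsVar")  # unbound, module-level TypeVar (old approach)
_ParamsAny = Any                    # plain alias to Any (what the PR switches to)


def with_typevar(state: dict) -> None:
    # A TypeVar is only bound when it appears in the signature of a generic
    # function or class; annotating a local with it makes mypy report
    #   error: Type variable "_ParamsVar" is unbound  [valid-type]
    params: _ParamsVar = state["params"]  # noqa: F841


def with_any(state: dict) -> None:
    # The Any alias carries no binding requirement, so mypy accepts it.
    params: _ParamsAny = state["params"]  # noqa: F841
```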
```diff
@@ -51,6 +52,9 @@ def _wrapper(x):
 
 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
+        self.init_fn: Callable[[_Params], _IterOptState]
+        self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
+        self.get_params_fn: Callable[[_OptState], _Params]
         self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)
 
     def init(self, params: _Params) -> _IterOptState:
```
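The three attribute annotations added here describe the `(init_fn, update_fn, get_params_fn)` triple returned by the `jax.example_libraries.optimizers` factories. A short illustrative sketch of that contract (not part of the diff):

```python
import jax.numpy as jnp
from jax.example_libraries import optimizers

# The factory returns the (init_fn, update_fn, get_params_fn) triple that
# _NumPyroOptim unpacks in its __init__.
init_fn, update_fn, get_params_fn = optimizers.adam(step_size=1e-2)

opt_state = init_fn({"w": jnp.zeros(3)})                 # params -> optimizer state
opt_state = update_fn(0, {"w": jnp.ones(3)}, opt_state)  # (step, grads, state) -> state
print(get_params_fn(opt_state))                          # state -> params
```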
```diff
@@ -80,7 +84,7 @@ def eval_and_update(
         fn: Callable[[Any], tuple],
         state: _IterOptState,
         forward_mode_differentiation: bool = False,
-    ):
+    ) -> tuple[tuple[Any, Any], _IterOptState]:
         """
         Performs an optimization step for the objective function `fn`.
         For most optimizers, the update is performed based on the gradient
```
```diff
@@ -96,7 +100,7 @@ def eval_and_update(
         :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
         """
-        params = self.get_params(state)
+        params: _Params = self.get_params(state)
         (out, aux), grads = _value_and_grad(
             fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
```
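For orientation, a hypothetical usage sketch of the annotated `eval_and_update` signature; the loss function and values below are made up, and the objective returns an `(out, aux)` pair because the gradient is taken with `has_aux=True`:

```python
import jax.numpy as jnp
from numpyro.optim import Adam


def loss(params):
    # Objective must return a (value, aux) tuple.
    return (jnp.sum(params["w"] ** 2), None)


optim = Adam(step_size=1e-2)
state = optim.init({"w": jnp.ones(3)})  # _IterOptState = (step, opt_state)
(value, aux), state = optim.eval_and_update(loss, state)
print(value)
```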
```diff
@@ -107,7 +111,7 @@ def eval_and_stable_update(
         fn: Callable[[Any], tuple],
         state: _IterOptState,
         forward_mode_differentiation: bool = False,
-    ):
+    ) -> tuple[tuple[Any, Any], _IterOptState]:
         """
         Like :meth:`eval_and_update` but when the value of the objective function
         or the gradients are not finite, we will not update the input `state`
```
```diff
@@ -118,7 +122,7 @@ def eval_and_stable_update(
         :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
         """
-        params = self.get_params(state)
+        params: _Params = self.get_params(state)
         (out, aux), grads = _value_and_grad(
             fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
```
```diff
@@ -141,7 +145,7 @@ def get_params(self, state: _IterOptState) -> _Params:
         return self.get_params_fn(opt_state)
 
 
-def _add_doc(fn):
+def _add_doc(fn) -> Callable[[Any], Any]:
     def _wrapped(cls):
         cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format(
             fn.__name__
```
```diff
@@ -153,7 +157,7 @@ def _wrapped(cls):
 
 @_add_doc(optimizers.adam)
 class Adam(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(Adam, self).__init__(optimizers.adam, *args, **kwargs)
```
```diff
@@ -170,11 +174,11 @@ class ClippedAdam(_NumPyroOptim):
     https://arxiv.org/abs/1412.6980
     """
 
-    def __init__(self, *args, clip_norm=10.0, **kwargs):
+    def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None:
         self.clip_norm = clip_norm
         super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)
 
-    def update(self, g, state):
+    def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
         i, opt_state = state
         # clip norm
         g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
```
```diff
@@ -184,39 +188,39 @@ def update(self, g, state):
 
 @_add_doc(optimizers.adagrad)
 class Adagrad(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)
 
 
 @_add_doc(optimizers.momentum)
 class Momentum(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)
 
 
 @_add_doc(optimizers.rmsprop)
 class RMSProp(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)
 
 
 @_add_doc(optimizers.rmsprop_momentum)
 class RMSPropMomentum(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(RMSPropMomentum, self).__init__(
             optimizers.rmsprop_momentum, *args, **kwargs
         )
 
 
 @_add_doc(optimizers.sgd)
 class SGD(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)
 
 
 @_add_doc(optimizers.sm3)
 class SM3(_NumPyroOptim):
-    def __init__(self, *args, **kwargs):
+    def __init__(self, *args, **kwargs) -> None:
         super(SM3, self).__init__(optimizers.sm3, *args, **kwargs)
```
```diff
@@ -225,24 +229,36 @@ def __init__(self, *args, **kwargs):
 # and pass `unravel_fn` around.
 # When arbitrary pytree is supported in JAX, we can just simply use
 # identity functions for `init_fn` and `get_params`.
-_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])
+class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])):
+    flat_params: ArrayLike
+    unravel_fn: Callable[[ArrayLike], _Params]
 
 
 register_pytree_node(
     _MinimizeState,
     lambda state: ((state.flat_params,), (state.unravel_fn,)),
     lambda data, xs: _MinimizeState(xs[0], data[0]),
 )
 
 
-def _minimize_wrapper():
-    def init_fn(params):
+def _minimize_wrapper() -> (
+    tuple[
+        Callable[[_Params], _MinimizeState],
+        Callable[[Any, Any, _MinimizeState], _MinimizeState],
+        Callable[[_MinimizeState], _Params],
+    ]
+):
+    def init_fn(params: _Params) -> _MinimizeState:
         flat_params, unravel_fn = ravel_pytree(params)
         return _MinimizeState(flat_params, unravel_fn)
 
-    def update_fn(i, grad_tree, opt_state):
+    def update_fn(
+        i: ArrayLike, grad_tree: ArrayLike, opt_state: _MinimizeState
+    ) -> _MinimizeState:
         # we don't use update_fn in Minimize, so let it do nothing
         return opt_state
 
-    def get_params(opt_state):
+    def get_params(opt_state: _MinimizeState) -> _Params:
         flat_params, unravel_fn = opt_state
         return unravel_fn(flat_params)
```
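The comments at the top of this hunk describe the workaround that `_MinimizeState` encodes: parameters are flattened with `ravel_pytree` and the inverse function is carried along. A small illustrative round-trip (values made up):

```python
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

params = {"loc": jnp.zeros(2), "scale": jnp.ones(1)}

# ravel_pytree flattens any pytree of arrays into one 1-D array and returns
# the inverse function; _MinimizeState stores exactly this pair.
flat_params, unravel_fn = ravel_pytree(params)
assert flat_params.shape == (3,)

restored = unravel_fn(flat_params)  # back to {"loc": ..., "scale": ...}
```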
```diff
@@ -289,7 +305,7 @@ class Minimize(_NumPyroOptim):
     >>> assert_allclose(quantiles["b"], 3., atol=1e-3)
     """
 
-    def __init__(self, method="BFGS", **kwargs):
+    def __init__(self, method="BFGS", **kwargs) -> None:
         super().__init__(_minimize_wrapper)
         self._method = method
         self._kwargs = kwargs
```
```diff
@@ -298,8 +314,8 @@ def eval_and_update(
         self,
         fn: Callable[[Any], tuple],
         state: _IterOptState,
-        forward_mode_differentiation=False,
-    ):
+        forward_mode_differentiation: bool = False,
+    ) -> tuple[tuple[Any, None], _IterOptState]:
         i, (flat_params, unravel_fn) = state
 
         def loss_fn(x):
```
```diff
@@ -333,17 +349,19 @@ def optax_to_numpyro(transformation) -> _NumPyroOptim:
     """
     import optax
 
-    def init_fn(params):
+    def init_fn(params: _Params) -> tuple[_Params, Any]:
         opt_state = transformation.init(params)
         return params, opt_state
 
-    def update_fn(step, grads, state):
+    def update_fn(
+        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any]
+    ) -> tuple[_Params, Any]:
         params, opt_state = state
         updates, opt_state = transformation.update(grads, opt_state, params)
         updated_params = optax.apply_updates(params, updates)
         return updated_params, opt_state
 
-    def get_params_fn(state):
+    def get_params_fn(state: tuple[_Params, Any]) -> _Params:
         params, _ = state
         return params
```
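A brief, hypothetical usage sketch of the wrapper; the particular optax chain below is only an example:

```python
import optax
from numpyro.optim import optax_to_numpyro

# Wrap an optax transformation so it exposes the init / update / get_params
# interface of _NumPyroOptim.
optimizer = optax_to_numpyro(optax.chain(optax.clip(10.0), optax.adam(1e-3)))
```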
Additional review comment:

**Author:** I had to specify a bound or a specific type, and this is currently the "best" solution for pytrees; see jax-ml/jax#3340 (comment).
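For context on that last comment, here is a minimal sketch of the compromise it refers to: pytrees currently have no precise static type, so pytree-valued parameters end up annotated with `Any` (the `Params` alias and `clip_pytree` function below are illustrative, not from the PR):

```python
from typing import Any

import jax
import jax.numpy as jnp

Params = Any  # stand-in for a "pytree of arrays" type, as in `_Params = Any`


def clip_pytree(params: Params, bound: float) -> Params:
    # Works for any pytree of arrays: dicts, lists, namedtuples, ...
    return jax.tree.map(lambda p: jnp.clip(p, -bound, bound), params)


print(clip_pytree({"w": jnp.arange(3.0), "b": jnp.array(5.0)}, 1.0))
```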