diff --git a/numpyro/optim.py b/numpyro/optim.py
index eab3f594a..8d1717b16 100644
--- a/numpyro/optim.py
+++ b/numpyro/optim.py
@@ -9,7 +9,7 @@
 
 from collections import namedtuple
 from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import Any
 
 import jax
 from jax import jacfwd, lax, value_and_grad
@@ -18,6 +18,7 @@
 import jax.numpy as jnp
 from jax.scipy.optimize import minimize
 from jax.tree_util import register_pytree_node
+from jax.typing import ArrayLike
 
 __all__ = [
     "Adam",
@@ -31,12 +32,12 @@
     "SM3",
 ]
 
-_Params = TypeVar("_Params")
-_OptState = TypeVar("_OptState")
-_IterOptState = tuple[int, _OptState]
+_Params = Any
+_OptState = Any
+_IterOptState = tuple[ArrayLike, _OptState]
 
 
-def _value_and_grad(f, x, forward_mode_differentiation=False):
+def _value_and_grad(f, x, forward_mode_differentiation=False) -> tuple:
     if forward_mode_differentiation:
 
         def _wrapper(x):
@@ -51,6 +52,9 @@ def _wrapper(x):
 
 class _NumPyroOptim(object):
     def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
+        self.init_fn: Callable[[_Params], _IterOptState]
+        self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
+        self.get_params_fn: Callable[[_OptState], _Params]
         self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)
 
     def init(self, params: _Params) -> _IterOptState:
@@ -80,7 +84,7 @@ def eval_and_update(
         fn: Callable[[Any], tuple],
         state: _IterOptState,
         forward_mode_differentiation: bool = False,
-    ):
+    ) -> tuple[tuple[Any, Any], _IterOptState]:
         """
         Performs an optimization step for the objective function `fn`.
         For most optimizers, the update is performed based on the gradient
@@ -96,7 +100,7 @@
         :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
         """
-        params = self.get_params(state)
+        params: _Params = self.get_params(state)
         (out, aux), grads = _value_and_grad(
             fn, x=params, forward_mode_differentiation=forward_mode_differentiation
         )
@@ -107,7 +111,7 @@ def eval_and_stable_update(
         fn: Callable[[Any], tuple],
         state: _IterOptState,
         forward_mode_differentiation: bool = False,
-    ):
+    ) -> tuple[tuple[Any, Any], _IterOptState]:
         """
         Like :meth:`eval_and_update` but when the value of the objective function
         or the gradients are not finite, we will not update the input `state`
@@ -118,7 +122,7 @@
         :param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
         :return: a pair of the output of objective function and the new optimizer state.
""" - params = self.get_params(state) + params: _Params = self.get_params(state) (out, aux), grads = _value_and_grad( fn, x=params, forward_mode_differentiation=forward_mode_differentiation ) @@ -141,7 +145,7 @@ def get_params(self, state: _IterOptState) -> _Params: return self.get_params_fn(opt_state) -def _add_doc(fn): +def _add_doc(fn) -> Callable[[Any], Any]: def _wrapped(cls): cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format( fn.__name__ @@ -153,7 +157,7 @@ def _wrapped(cls): @_add_doc(optimizers.adam) class Adam(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(Adam, self).__init__(optimizers.adam, *args, **kwargs) @@ -170,11 +174,11 @@ class ClippedAdam(_NumPyroOptim): https://arxiv.org/abs/1412.6980 """ - def __init__(self, *args, clip_norm=10.0, **kwargs): + def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None: self.clip_norm = clip_norm super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs) - def update(self, g, state): + def update(self, g: _Params, state: _IterOptState) -> _IterOptState: i, opt_state = state # clip norm g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g) @@ -184,25 +188,25 @@ def update(self, g, state): @_add_doc(optimizers.adagrad) class Adagrad(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs) @_add_doc(optimizers.momentum) class Momentum(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs) @_add_doc(optimizers.rmsprop) class RMSProp(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs) @_add_doc(optimizers.rmsprop_momentum) class RMSPropMomentum(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(RMSPropMomentum, self).__init__( optimizers.rmsprop_momentum, *args, **kwargs ) @@ -210,13 +214,13 @@ def __init__(self, *args, **kwargs): @_add_doc(optimizers.sgd) class SGD(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(SGD, self).__init__(optimizers.sgd, *args, **kwargs) @_add_doc(optimizers.sm3) class SM3(_NumPyroOptim): - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: super(SM3, self).__init__(optimizers.sm3, *args, **kwargs) @@ -225,7 +229,11 @@ def __init__(self, *args, **kwargs): # and pass `unravel_fn` around. # When arbitrary pytree is supported in JAX, we can just simply use # identity functions for `init_fn` and `get_params`. 
-_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])
+class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])):
+    flat_params: ArrayLike
+    unravel_fn: Callable[[ArrayLike], _Params]
+
+
 register_pytree_node(
     _MinimizeState,
     lambda state: ((state.flat_params,), (state.unravel_fn,)),
@@ -233,16 +241,24 @@ def __init__(self, *args, **kwargs):
 )
 
 
-def _minimize_wrapper():
-    def init_fn(params):
+def _minimize_wrapper() -> (
+    tuple[
+        Callable[[_Params], _MinimizeState],
+        Callable[[Any, Any, _MinimizeState], _MinimizeState],
+        Callable[[_MinimizeState], _Params],
+    ]
+):
+    def init_fn(params: _Params) -> _MinimizeState:
         flat_params, unravel_fn = ravel_pytree(params)
         return _MinimizeState(flat_params, unravel_fn)
 
-    def update_fn(i, grad_tree, opt_state):
+    def update_fn(
+        i: ArrayLike, grad_tree: ArrayLike, opt_state: _MinimizeState
+    ) -> _MinimizeState:
         # we don't use update_fn in Minimize, so let it do nothing
         return opt_state
 
-    def get_params(opt_state):
+    def get_params(opt_state: _MinimizeState) -> _Params:
         flat_params, unravel_fn = opt_state
         return unravel_fn(flat_params)
 
@@ -289,7 +305,7 @@ class Minimize(_NumPyroOptim):
     >>> assert_allclose(quantiles["b"], 3., atol=1e-3)
     """
 
-    def __init__(self, method="BFGS", **kwargs):
+    def __init__(self, method="BFGS", **kwargs) -> None:
         super().__init__(_minimize_wrapper)
         self._method = method
         self._kwargs = kwargs
@@ -298,8 +314,8 @@ def eval_and_update(
         self,
         fn: Callable[[Any], tuple],
         state: _IterOptState,
-        forward_mode_differentiation=False,
-    ):
+        forward_mode_differentiation: bool = False,
+    ) -> tuple[tuple[Any, None], _IterOptState]:
         i, (flat_params, unravel_fn) = state
 
         def loss_fn(x):
@@ -333,17 +349,19 @@ def optax_to_numpyro(transformation) -> _NumPyroOptim:
     """
     import optax
 
-    def init_fn(params):
+    def init_fn(params: _Params) -> tuple[_Params, Any]:
         opt_state = transformation.init(params)
         return params, opt_state
 
-    def update_fn(step, grads, state):
+    def update_fn(
+        step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any]
+    ) -> tuple[_Params, Any]:
         params, opt_state = state
         updates, opt_state = transformation.update(grads, opt_state, params)
         updated_params = optax.apply_updates(params, updates)
         return updated_params, opt_state
 
-    def get_params_fn(state):
+    def get_params_fn(state: tuple[_Params, Any]) -> _Params:
         params, _ = state
         return params
 
diff --git a/pyproject.toml b/pyproject.toml
index 660eb5e0e..28947c4ad 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -116,6 +116,7 @@ module = [
     "numpyro.contrib.stochastic_support.*",
     "numpyro.diagnostics.*",
     "numpyro.handlers.*",
+    "numpyro.optim.*",
     "numpyro.primitives.*",
     "numpyro.patch.*",
     "numpyro.util.*",
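For context, here is a minimal sketch of how the `_NumPyroOptim` interface annotated above is exercised. The quadratic objective, the parameter name `"x"`, and the step size are illustrative only, not part of the change; `Adam`, `init`, `eval_and_update`, and `get_params` are the real methods shown in the diff.

```python
import jax.numpy as jnp

from numpyro.optim import Adam


# The objective passed to eval_and_update must return a pair, matching the
# Callable[[Any], tuple] annotation: (scalar loss, auxiliary output).
def objective(params):
    return jnp.sum((params["x"] - 3.0) ** 2), None


optim = Adam(step_size=0.1)
# init returns an _IterOptState: (step count, wrapped optimizer state).
state = optim.init({"x": jnp.zeros(2)})
for _ in range(200):
    (loss, _), state = optim.eval_and_update(objective, state)
params = optim.get_params(state)  # pytree with the same structure as the input
```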
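The `optax_to_numpyro` bridge annotated at the end of the `optim.py` diff wraps any optax `GradientTransformation` in the same interface; its state is the `tuple[_Params, Any]` of parameters and optax state seen in the new annotations. A hedged sketch (the particular chain and learning rate are arbitrary):

```python
import optax

from numpyro.optim import optax_to_numpyro

# Wraps an optax transformation as a _NumPyroOptim with the typed
# (params, opt_state) state shown in the diff.
optim = optax_to_numpyro(optax.chain(optax.clip(10.0), optax.adam(1e-3)))
```

The `pyproject.toml` hunk adds `numpyro.optim.*` to what appears to be the mypy overrides module list, so the new annotations in this module are picked up by the project's type checking.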