Add types optim module #1942

Merged: 5 commits, Jan 3, 2025

Changes from all commits
78 changes: 48 additions & 30 deletions numpyro/optim.py
@@ -9,7 +9,7 @@

from collections import namedtuple
from collections.abc import Callable
-from typing import Any, TypeVar
+from typing import Any

import jax
from jax import jacfwd, lax, value_and_grad
@@ -18,6 +18,7 @@
import jax.numpy as jnp
from jax.scipy.optimize import minimize
from jax.tree_util import register_pytree_node
+from jax.typing import ArrayLike

__all__ = [
"Adam",
@@ -31,12 +32,12 @@
"SM3",
]

-_Params = TypeVar("_Params")
-_OptState = TypeVar("_OptState")
-_IterOptState = tuple[int, _OptState]
+_Params = Any
Contributor Author:

I had to specify either a bound or a concrete type, and `Any` is currently the "best" solution for pytrees; see jax-ml/jax#3340 (comment).

Member:

Why use `Any` for `_Params` but not for `_OptState`?

Contributor Author:

If I leave `_Params = TypeVar("_Params")`, I get a MyPy error:

numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound

The line in question is:

https://github.com/juanitorduz/numpyro/blob/cd364d84308f253f0d0a251bdb0f36daaa38cfd4/numpyro/optim.py#L103-L107

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        ...
        params: _Params = self.get_params(state)  # <- HERE!
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )

The problem is that TypeVars are fine as long as they appear in a function's input and output types. In our case, however, the variable `params` is created inside the function body, which MyPy does not accept: it requires either a bound (restricting the set of possible types) or a concrete type.

This is one of the references that helped me clarify the issue: https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified
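
For readers following along, a minimal standalone sketch of the limitation (illustrative code, not from NumPyro): MyPy only binds a TypeVar that appears in the enclosing function's signature, so a TypeVar used solely to annotate a local variable is rejected.

    from typing import TypeVar

    T = TypeVar("T")

    def identity(x: T) -> T:
        # Fine: T is bound by appearing in the input and output types.
        return x

    def local_annotation() -> None:
        # Rejected: T occurs only inside the body, so MyPy reports
        # error: Type variable "T" is unbound  [valid-type]
        value: T = 0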

Member:

How about just using `params = self.get_params(state)`?

Contributor Author:

Unfortunately this does not work (it was my first try) :(

numpyro/optim.py:103: error: Need type annotation for "params"  [var-annotated]
numpyro/optim.py:125: error: Type variable "numpyro.optim._Params" is unbound  [valid-type]

The reason is the logic above: we cannot use a TypeVar purely internally. :/

Contributor Author (@juanitorduz, Jan 2, 2025):

I have made another attempt in de835c6, which I think serves our purpose. I had to skip one type check:

    def get_params(opt_state: _MinimizeState) -> _Params:  # type: ignore[type-var]  # <- HERE
        flat_params, unravel_fn = opt_state
        return unravel_fn(flat_params)

Better?

Member:

I guess we can use Any for both to avoid confusion.

Contributor Author:

Sure! Changed in e24a7a2 :)

+_OptState = Any
+_IterOptState = tuple[ArrayLike, _OptState]
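
(A note on why the first element became `ArrayLike` rather than `int` — my reading, not stated in the diff: the step counter is carried through `jit`-compiled updates as a JAX scalar, so a traced array must be admissible.)

    import jax.numpy as jnp

    # Hypothetical illustration of an _IterOptState: a JAX scalar step
    # counter paired with whatever state the wrapped optimizer keeps.
    iter_state = (jnp.array(0), {"momentum": jnp.zeros(3)})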


-def _value_and_grad(f, x, forward_mode_differentiation=False):
+def _value_and_grad(f, x, forward_mode_differentiation=False) -> tuple:
if forward_mode_differentiation:

def _wrapper(x):
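
(The rest of this branch is collapsed above. As a rough sketch of the technique — my reconstruction, not necessarily the file's exact body — `jacfwd` with `has_aux=True` can emulate `value_and_grad` when forward-mode differentiation is requested.)

    import jax.numpy as jnp
    from jax import jacfwd

    def forward_value_and_grad(f, x):
        # f returns (out, aux); route both through jacfwd's aux channel so
        # that only the scalar objective `out` is differentiated.
        def _wrapper(x):
            out, aux = f(x)
            return out, (out, aux)

        grads, (out, aux) = jacfwd(_wrapper, has_aux=True)(x)
        return (out, aux), grads

    # Behaves like value_and_grad(f, has_aux=True)(x) for this f:
    f = lambda x: (jnp.sum(x**2), None)
    (out, aux), grads = forward_value_and_grad(f, jnp.ones(3))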
@@ -51,6 +52,9 @@ def _wrapper(x):

class _NumPyroOptim(object):
def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
+self.init_fn: Callable[[_Params], _IterOptState]
+self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
+self.get_params_fn: Callable[[_OptState], _Params]
self.init_fn, self.update_fn, self.get_params_fn = optim_fn(*args, **kwargs)

def init(self, params: _Params) -> _IterOptState:
@@ -80,7 +84,7 @@ def eval_and_update(
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
-):
+) -> tuple[tuple[Any, Any], _IterOptState]:
"""
Performs an optimization step for the objective function `fn`.
For most optimizers, the update is performed based on the gradient
@@ -96,7 +100,7 @@ def eval_and_update(
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
-params = self.get_params(state)
+params: _Params = self.get_params(state)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
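
(A hypothetical usage sketch of the annotated method — names and step size are illustrative: `fn` must return a `(value, aux)` pair because the gradient is taken with `has_aux=True`.)

    import jax.numpy as jnp
    from numpyro.optim import Adam

    def loss_fn(params):
        # Return (objective, auxiliary output); aux may simply be None.
        return jnp.sum(params["w"] ** 2), None

    opt = Adam(step_size=0.1)
    state = opt.init({"w": jnp.ones(3)})
    (loss, _), state = opt.eval_and_update(loss_fn, state)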
@@ -107,7 +111,7 @@ def eval_and_stable_update(
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
-):
+) -> tuple[tuple[Any, Any], _IterOptState]:
"""
Like :meth:`eval_and_update` but when the value of the objective function
or the gradients are not finite, we will not update the input `state`
@@ -118,7 +122,7 @@ def eval_and_stable_update(
:param forward_mode_differentiation: boolean flag indicating whether to use forward mode differentiation.
:return: a pair of the output of objective function and the new optimizer state.
"""
-params = self.get_params(state)
+params: _Params = self.get_params(state)
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
@@ -141,7 +145,7 @@ def get_params(self, state: _IterOptState) -> _Params:
return self.get_params_fn(opt_state)


-def _add_doc(fn):
+def _add_doc(fn) -> Callable[[Any], Any]:
def _wrapped(cls):
cls.__doc__ = "Wrapper class for the JAX optimizer: :func:`~jax.example_libraries.optimizers.{}`".format(
fn.__name__
@@ -153,7 +157,7 @@ def _wrapped(cls):

@_add_doc(optimizers.adam)
class Adam(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(Adam, self).__init__(optimizers.adam, *args, **kwargs)


@@ -170,11 +174,11 @@ class ClippedAdam(_NumPyroOptim):
https://arxiv.org/abs/1412.6980
"""

-def __init__(self, *args, clip_norm=10.0, **kwargs):
+def __init__(self, *args, clip_norm: float = 10.0, **kwargs) -> None:
self.clip_norm = clip_norm
super(ClippedAdam, self).__init__(optimizers.adam, *args, **kwargs)

-def update(self, g, state):
+def update(self, g: _Params, state: _IterOptState) -> _IterOptState:
i, opt_state = state
# clip norm
g = jax.tree.map(lambda g_: jnp.clip(g_, -self.clip_norm, self.clip_norm), g)
@@ -184,39 +188,39 @@ def update(self, g, state):

@_add_doc(optimizers.adagrad)
class Adagrad(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(Adagrad, self).__init__(optimizers.adagrad, *args, **kwargs)


@_add_doc(optimizers.momentum)
class Momentum(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(Momentum, self).__init__(optimizers.momentum, *args, **kwargs)


@_add_doc(optimizers.rmsprop)
class RMSProp(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(RMSProp, self).__init__(optimizers.rmsprop, *args, **kwargs)


@_add_doc(optimizers.rmsprop_momentum)
class RMSPropMomentum(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(RMSPropMomentum, self).__init__(
optimizers.rmsprop_momentum, *args, **kwargs
)


@_add_doc(optimizers.sgd)
class SGD(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(SGD, self).__init__(optimizers.sgd, *args, **kwargs)


@_add_doc(optimizers.sm3)
class SM3(_NumPyroOptim):
-def __init__(self, *args, **kwargs):
+def __init__(self, *args, **kwargs) -> None:
super(SM3, self).__init__(optimizers.sm3, *args, **kwargs)


@@ -225,24 +229,36 @@ def __init__(self, *args, **kwargs):
# and pass `unravel_fn` around.
# When arbitrary pytree is supported in JAX, we can just simply use
# identity functions for `init_fn` and `get_params`.
-_MinimizeState = namedtuple("MinimizeState", ["flat_params", "unravel_fn"])
+class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])):
+flat_params: ArrayLike
+unravel_fn: Callable[[ArrayLike], _Params]


register_pytree_node(
_MinimizeState,
lambda state: ((state.flat_params,), (state.unravel_fn,)),
lambda data, xs: _MinimizeState(xs[0], data[0]),
)
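
(For context, a small sketch of the machinery this state wraps — the pytree below is hypothetical: `ravel_pytree` flattens an arbitrary params pytree into a single 1-D array plus its inverse function, and registering `_MinimizeState` as a pytree node lets JAX transformations carry the pair around.)

    import jax.numpy as jnp
    from jax.flatten_util import ravel_pytree

    params = {"loc": jnp.zeros(3), "scale": jnp.ones(())}  # hypothetical pytree
    flat_params, unravel_fn = ravel_pytree(params)
    assert flat_params.ndim == 1
    restored = unravel_fn(flat_params)  # original pytree structure back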


-def _minimize_wrapper():
-def init_fn(params):
+def _minimize_wrapper() -> (
+tuple[
+Callable[[_Params], _MinimizeState],
+Callable[[Any, Any, _MinimizeState], _MinimizeState],
+Callable[[_MinimizeState], _Params],
+]
+):
+def init_fn(params: _Params) -> _MinimizeState:
flat_params, unravel_fn = ravel_pytree(params)
return _MinimizeState(flat_params, unravel_fn)

-def update_fn(i, grad_tree, opt_state):
+def update_fn(
+i: ArrayLike, grad_tree: ArrayLike, opt_state: _MinimizeState
+) -> _MinimizeState:
# we don't use update_fn in Minimize, so let it do nothing
return opt_state

-def get_params(opt_state):
+def get_params(opt_state: _MinimizeState) -> _Params:
flat_params, unravel_fn = opt_state
return unravel_fn(flat_params)

@@ -289,7 +305,7 @@ class Minimize(_NumPyroOptim):
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
"""

-def __init__(self, method="BFGS", **kwargs):
+def __init__(self, method="BFGS", **kwargs) -> None:
super().__init__(_minimize_wrapper)
self._method = method
self._kwargs = kwargs
@@ -298,8 +314,8 @@ def eval_and_update(
self,
fn: Callable[[Any], tuple],
state: _IterOptState,
-forward_mode_differentiation=False,
-):
+forward_mode_differentiation: bool = False,
+) -> tuple[tuple[Any, None], _IterOptState]:
i, (flat_params, unravel_fn) = state

def loss_fn(x):
@@ -333,17 +349,19 @@ def optax_to_numpyro(transformation) -> _NumPyroOptim:
"""
import optax

-def init_fn(params):
+def init_fn(params: _Params) -> tuple[_Params, Any]:
opt_state = transformation.init(params)
return params, opt_state

-def update_fn(step, grads, state):
+def update_fn(
+step: ArrayLike, grads: ArrayLike, state: tuple[_Params, Any]
+) -> tuple[_Params, Any]:
params, opt_state = state
updates, opt_state = transformation.update(grads, opt_state, params)
updated_params = optax.apply_updates(params, updates)
return updated_params, opt_state

-def get_params_fn(state):
+def get_params_fn(state: tuple[_Params, Any]) -> _Params:
params, _ = state
return params

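(A hypothetical usage sketch of `optax_to_numpyro` — assumes `optax` is installed; the learning rate is arbitrary. The wrapper makes an optax gradient transformation usable wherever the native optimizers above are expected.)

    import jax.numpy as jnp
    import optax
    from numpyro.optim import optax_to_numpyro

    opt = optax_to_numpyro(optax.adam(learning_rate=1e-3))
    state = opt.init({"w": jnp.zeros(2)})  # (step, (params, optax state))
    params = opt.get_params(state)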
1 change: 1 addition & 0 deletions pyproject.toml
@@ -116,6 +116,7 @@ module = [
"numpyro.contrib.stochastic_support.*",
"numpyro.diagnostics.*",
"numpyro.handlers.*",
"numpyro.optim.*",
"numpyro.primitives.*",
"numpyro.patch.*",
"numpyro.util.*",