Skip to content

Commit

Permalink
attempt keep TypeVar
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Jan 2, 2025
1 parent cd364d8 commit de835c6
Showing 1 changed file with 13 additions and 9 deletions.
22 changes: 13 additions & 9 deletions numpyro/optim.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from collections import namedtuple
from collections.abc import Callable
from typing import Any, TypeVar
from typing import Any, Generic, TypeVar

import jax
from jax import jacfwd, lax, value_and_grad
Expand All @@ -32,7 +32,7 @@
"SM3",
]

_Params = Any
_Params = TypeVar("_Params")
_OptState = TypeVar("_OptState")
_IterOptState = tuple[ArrayLike, _OptState]

Expand All @@ -50,7 +50,7 @@ def _wrapper(x):
return value_and_grad(f, has_aux=True)(x)


class _NumPyroOptim(object):
class _NumPyroOptim(Generic[_Params, _OptState]):
def __init__(self, optim_fn: Callable, *args, **kwargs) -> None:
self.init_fn: Callable[[_Params], _IterOptState]
self.update_fn: Callable[[ArrayLike, _Params, _OptState], _OptState]
Expand Down Expand Up @@ -229,7 +229,11 @@ def __init__(self, *args, **kwargs) -> None:
# and pass `unravel_fn` around.
# When arbitrary pytree is supported in JAX, we can just simply use
# identity functions for `init_fn` and `get_params`.
_MinimizeState = namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])
class _MinimizeState(namedtuple("_MinimizeState", ["flat_params", "unravel_fn"])):
flat_params: ArrayLike
unravel_fn: Callable[[ArrayLike], _Params]


register_pytree_node(
_MinimizeState,
lambda state: ((state.flat_params,), (state.unravel_fn,)),
Expand All @@ -239,9 +243,9 @@ def __init__(self, *args, **kwargs) -> None:

def _minimize_wrapper() -> (
tuple[
Callable[[Any], _MinimizeState],
Callable[[Any, Any, Any], Any],
Callable[[Any], _Params],
Callable[[_Params], _MinimizeState],
Callable[[Any, Any, _MinimizeState], _MinimizeState],
Callable[[_MinimizeState], _Params],
]
):
def init_fn(params: _Params) -> _MinimizeState:
Expand All @@ -254,7 +258,7 @@ def update_fn(
# we don't use update_fn in Minimize, so let it do nothing
return opt_state

def get_params(opt_state: _MinimizeState) -> _Params:
def get_params(opt_state: _MinimizeState) -> _Params: # type: ignore[type-var]
flat_params, unravel_fn = opt_state
return unravel_fn(flat_params)

Expand Down Expand Up @@ -301,7 +305,7 @@ class Minimize(_NumPyroOptim):
>>> assert_allclose(quantiles["b"], 3., atol=1e-3)
"""

def __init__(self, method="BFGS", **kwargs):
def __init__(self, method="BFGS", **kwargs) -> None:
super().__init__(_minimize_wrapper)
self._method = method
self._kwargs = kwargs
Expand Down

0 comments on commit de835c6

Please sign in to comment.