-
Notifications
You must be signed in to change notification settings - Fork 246
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add types optim module #1942
Add types optim module #1942
Conversation
@@ -31,12 +32,12 @@ | |||
"SM3", | |||
] | |||
|
|||
_Params = TypeVar("_Params") | |||
_Params = Any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why use Any for _Params but not for OptState
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I leave _Params = TypeVar("_Params")
I get a MyPy error;
numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound
This line is
def eval_and_update(
self,
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
) -> tuple[tuple[Any, Any], _IterOptState]:
...
params: _Params = self.get_params(state). # <- HERE!
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
The problem is that TypeVar
s are fine as long as they are part of the input and output. However, in our case, the variable params is created inside the functions and this is something MyPy does not like. It requires either a bound (a subset of possible types) or a concrete type.
This is one of the references that helped me clarify this issue https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
How about just using `params = self.get_params(state)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately this does not work (it was my first try) :(
numpyro/optim.py:103: error: Need type annotation for "params" [var-annotated]
numpyro/optim.py:125: error: Type variable "numpyro.optim._Params" is unbound [valid-type]
The reason is because of the logic above. We can not use TypeVar internally :/
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have made another attempt which I think serves our purpose in de835c6 I had to skip one type check in
def get_params(opt_state: _MinimizeState) -> _Params: # type: ignore[type-var]. # <- HERE
flat_params, unravel_fn = opt_state
return unravel_fn(flat_params)
Better?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess we can use Any
for both to avoid confusion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure! Changed in e24a7a2 :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM pending the above comment. Thanks, @juanitorduz!
@@ -31,12 +32,12 @@ | |||
"SM3", | |||
] | |||
|
|||
_Params = TypeVar("_Params") | |||
_Params = Any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I had to specify a bound or a specific type and this is currently the "best" solution for pytrees, see jax-ml/jax#3340 (comment)
@@ -31,12 +32,12 @@ | |||
"SM3", | |||
] | |||
|
|||
_Params = TypeVar("_Params") | |||
_Params = Any |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If I leave _Params = TypeVar("_Params")
I get a MyPy error;
numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound
This line is
def eval_and_update(
self,
fn: Callable[[Any], tuple],
state: _IterOptState,
forward_mode_differentiation: bool = False,
) -> tuple[tuple[Any, Any], _IterOptState]:
...
params: _Params = self.get_params(state). # <- HERE!
(out, aux), grads = _value_and_grad(
fn, x=params, forward_mode_differentiation=forward_mode_differentiation
)
The problem is that TypeVar
s are fine as long as they are part of the input and output. However, in our case, the variable params is created inside the functions and this is something MyPy does not like. It requires either a bound (a subset of possible types) or a concrete type.
This is one of the references that helped me clarify this issue https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified
Ups! I had replied the comment but forgotten to hit the "reply button" 🙈 |
Add types optim module