Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add types optim module #1942

Merged
merged 5 commits into from
Jan 3, 2025
Merged

Add types optim module #1942

merged 5 commits into from
Jan 3, 2025

Conversation

juanitorduz
Copy link
Contributor

Add types optim module

@@ -31,12 +32,12 @@
"SM3",
]

_Params = TypeVar("_Params")
_Params = Any
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why use Any for _Params but not for OptState

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I leave _Params = TypeVar("_Params") I get a MyPy error;

numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound

This line is

https://github.com/juanitorduz/numpyro/blob/cd364d84308f253f0d0a251bdb0f36daaa38cfd4/numpyro/optim.py#L103-L107

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        ...
        params: _Params = self.get_params(state).   # <- HERE!
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )

The problem is that TypeVars are fine as long as they are part of the input and output. However, in our case, the variable params is created inside the functions and this is something MyPy does not like. It requires either a bound (a subset of possible types) or a concrete type.

This is one of the references that helped me clarify this issue https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about just using `params = self.get_params(state)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately this does not work (it was my first try) :(

numpyro/optim.py:103: error: Need type annotation for "params"  [var-annotated]
numpyro/optim.py:125: error: Type variable "numpyro.optim._Params" is unbound  [valid-type]

The reason is because of the logic above. We can not use TypeVar internally :/

Copy link
Contributor Author

@juanitorduz juanitorduz Jan 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have made another attempt which I think serves our purpose in de835c6 I had to skip one type check in

    def get_params(opt_state: _MinimizeState) -> _Params:  # type: ignore[type-var]. # <- HERE
        flat_params, unravel_fn = opt_state
        return unravel_fn(flat_params)

Better?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess we can use Any for both to avoid confusion.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure! Changed in e24a7a2 :)

Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM pending the above comment. Thanks, @juanitorduz!

@@ -31,12 +32,12 @@
"SM3",
]

_Params = TypeVar("_Params")
_Params = Any
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had to specify a bound or a specific type and this is currently the "best" solution for pytrees, see jax-ml/jax#3340 (comment)

@@ -31,12 +32,12 @@
"SM3",
]

_Params = TypeVar("_Params")
_Params = Any
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I leave _Params = TypeVar("_Params") I get a MyPy error;

numpyro/optim.py:103: error: Type variable "numpyro.optim._Params" is unbound

This line is

https://github.com/juanitorduz/numpyro/blob/cd364d84308f253f0d0a251bdb0f36daaa38cfd4/numpyro/optim.py#L103-L107

    def eval_and_update(
        self,
        fn: Callable[[Any], tuple],
        state: _IterOptState,
        forward_mode_differentiation: bool = False,
    ) -> tuple[tuple[Any, Any], _IterOptState]:
        ...
        params: _Params = self.get_params(state).   # <- HERE!
        (out, aux), grads = _value_and_grad(
            fn, x=params, forward_mode_differentiation=forward_mode_differentiation
        )

The problem is that TypeVars are fine as long as they are part of the input and output. However, in our case, the variable params is created inside the functions and this is something MyPy does not like. It requires either a bound (a subset of possible types) or a concrete type.

This is one of the references that helped me clarify this issue https://stackoverflow.com/questions/68603585/mypy-why-does-typevar-not-work-without-bound-specified

@juanitorduz
Copy link
Contributor Author

Ups! I had replied the comment but forgotten to hit the "reply button" 🙈

@fehiepsi fehiepsi merged commit 6ae76ea into pyro-ppl:master Jan 3, 2025
10 checks passed
@juanitorduz juanitorduz deleted the types-optim branch January 3, 2025 13:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants