Use Generics #97

Open
thomasahle opened this issue Jul 3, 2024 · 1 comment

Comments

@thomasahle

Haliax is an amazing library, but sometimes it still falls back to using comments for types like this:

    proj_q: hnn.Linear  # [Embed] -> [Head, Key]
    proj_k: hnn.Linear  # [Embed] -> [Head, Key]
    proj_v: hnn.Linear  # [Embed] -> [Head, Key]
    proj_answer: hnn.Linear  # output projection from [Head, Key] -> [Embed]

Using Python's Generics it would be possible to "formalize" this as:

    proj_q: hnn.Linear[Embed, tuple[Head, Key]]
    proj_k: hnn.Linear[Embed, tuple[Head, Key]]
    proj_v: hnn.Linear[Embed, tuple[Head, Key]]
    proj_answer: hnn.Linear[tuple[Head, Key], Embed]

And the mypy type checker could be used to catch accidental assignments of the wrong type to these variables.
Is this something that has already been considered?
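To make the proposal concrete, here is a minimal sketch of how such annotations could work. Everything in it is an assumption for illustration: `Linear` is a hypothetical stand-in rather than the real `hnn.Linear`, and `Embed`, `Head`, and `Key` are modelled as empty marker classes so they can appear as type arguments (Haliax axes are runtime objects, so a real design would need some bridge between the two).

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args, get_type_hints

In = TypeVar("In")
Out = TypeVar("Out")


# Hypothetical marker classes standing in for named axes.
class Embed: ...
class Head: ...
class Key: ...


class Linear(Generic[In, Out]):
    """Hypothetical stand-in for hnn.Linear, parameterized by input/output axes."""


@dataclass
class Attention:
    proj_q: Linear[Embed, tuple[Head, Key]]
    proj_answer: Linear[tuple[Head, Key], Embed]


def make_q() -> Linear[Embed, tuple[Head, Key]]:
    return Linear()


def make_answer() -> Linear[tuple[Head, Key], Embed]:
    return Linear()


attn = Attention(proj_q=make_q(), proj_answer=make_answer())

# A static checker such as mypy is expected to reject the swapped version,
# because the type arguments do not match the field annotations:
# Attention(proj_q=make_answer(), proj_answer=make_q())

# The annotations are also available at runtime, e.g. for documentation tools.
hints = get_type_hints(Attention)
q_in, q_out = get_args(hints["proj_q"])
```

Note that nothing is enforced at runtime here; the parameters only carry information for the type checker and for anything that inspects the annotations.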

@dlwh
Member

dlwh commented Jul 3, 2024

Thanks! Yeah, I've thought about it but didn't want to take it on... Have you seen https://docs.kidger.site/jaxtyping/ ? I'd be very open to some kind of interop, or to implementing what you describe. I think where it gets tricky is dealing with variadic arrays, e.g. what's the type signature of Linear.__call__, given that we want it to "replace the input with the output" and work for an arbitrary number of dimensions?

If we just wanted to add it as purely documentary (i.e. no checking), I'd be open to exploring that as a first step.
