JAX scientific+ML ecosystem:
I've written a lot of numerical JAX and PyTorch software, now used in diverse applications across science (simulation of black holes, soil moisture, ...) and ML (large language models, large protein models, ...).
Some of the libraries I would highlight:
- Equinox: elegant neural networks.
- Diffrax: numerical ODE/SDE solvers.
- jaxtyping: shape/dtype annotations for arrays. (Also supports PyTorch etc., despite the name!)
- Lineax: linear/least-squares solvers.
- Optimistix: root finding, least squares, etc.
- sympy2jax: optimise your symbolic expressions via gradient descent!
Me:
I currently do ML for protein engineering (lead optimization) at Cradle Bio. I also hold an honorary lectureship at Imperial College London. I previously worked at Google X, and did my PhD at the University of Oxford.
My interests include neural ODEs, numerical methods, protein language models, and more broadly scientific computing and scientific machine learning. These days I am especially interested in applying ML to unsolved problems in biology! I am also known for having strong opinions on the importance of good software development :)
Other links:
- Twitter:
- Bluesky:
- Google Scholar: here
- Personal website: kidger.site
- Neural ODE/SDE textbook: arXiv/2202.02435