jax-ai-stack
  packages:
    jax==0.4.38
    flax==0.10.2
    ml_dtypes==0.4.0
    optax==0.2.4
    orbax-checkpoint==0.11.0
    orbax-export==0.0.6
jax-ai-stack[tfds]
  packages:
    tensorflow==2.18.0
    tensorflow_datasets==4.9.7
jax-ai-stack[grain]
  packages:
    grain==0.2.3
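As a minimal sketch (not part of jax-ai-stack itself), the Python script below compares the installed versions of the pinned packages against the versions listed above, for example after running `pip install "jax-ai-stack[tfds,grain]"`. The names `PINS`, `installed_version`, and `check_pins` are hypothetical helpers introduced only for illustration. It assumes a recent Python 3 whose `importlib.metadata` matches distribution names regardless of `-`/`_` spelling (e.g. `tensorflow_datasets`).

```python
"""Hypothetical helper: compare installed versions against the jax-ai-stack pins."""
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

# Pins copied from the listing above, grouped by meta-package / extra.
PINS: Dict[str, Dict[str, str]] = {
    "jax-ai-stack": {
        "jax": "0.4.38",
        "flax": "0.10.2",
        "ml_dtypes": "0.4.0",
        "optax": "0.2.4",
        "orbax-checkpoint": "0.11.0",
        "orbax-export": "0.0.6",
    },
    "jax-ai-stack[tfds]": {
        "tensorflow": "2.18.0",
        "tensorflow_datasets": "4.9.7",
    },
    "jax-ai-stack[grain]": {
        "grain": "0.2.3",
    },
}


def installed_version(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def check_pins(
    groups: Dict[str, Dict[str, str]],
    get_version: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Map each mismatched package to (pinned, installed); installed is None if missing."""
    mismatches: Dict[str, Tuple[str, Optional[str]]] = {}
    for pins in groups.values():
        for dist, pinned in pins.items():
            found = get_version(dist)
            if found != pinned:
                mismatches[dist] = (pinned, found)
    return mismatches


if __name__ == "__main__":
    # Print one line per package whose installed version differs from its pin.
    for dist, (pinned, found) in sorted(check_pins(PINS).items()):
        print(f"{dist}: pinned {pinned}, installed {found or 'not installed'}")
```

Passing `get_version` as a parameter keeps the comparison logic testable without depending on what happens to be installed; an empty result means every pinned package matches exactly.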