From 16cab677a7ed78bb0e86fe405d35712e4ff312a1 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Wed, 11 Sep 2024 15:35:25 -0700 Subject: [PATCH] Pin ml_dtypes==0.4.0 --- .github/workflows/test.yaml | 4 ++-- pyproject.toml | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 9d30d65..fd3f7b5 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -15,7 +15,7 @@ permissions: # TODO(jakevdp): add testing on macOS-11 and windows-2019 when grain supports them, # or alternatively run a subset of tests without grain. jobs: - built-latest: + build-latest: name: Latest packages (${{ matrix.os }} Python ${{ matrix.python-version }}) runs-on: ${{ matrix.os }} strategy: @@ -34,7 +34,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install -U jax flax grain optax orbax tensorflow tensorflow_datasets pytest pytest-xdist + pip install -U jax flax grain ml_dtypes optax orbax tensorflow tensorflow_datasets pytest pytest-xdist - name: Run tests run: | pytest -n auto jax_ml_stack diff --git a/pyproject.toml b/pyproject.toml index 099af66..0da8b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "jax==0.4.31", "flax==0.8.5", "grain==0.2.0", + "ml_dtypes==0.4.0", "optax==0.2.3", "orbax==0.1.9", ]