From e92b989142bee3e2dc774cca183c06de714e8828 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Thu, 31 Oct 2024 14:55:51 -0400 Subject: [PATCH] ci: install GPU JAX in GPU CI Signed-off-by: Jinzhe Zeng --- .github/workflows/test_cuda.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 6bf4c8552f..996a1bcff0 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -47,7 +47,7 @@ jobs: && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3 if: false # skip as we use nvidia image - run: python -m pip install -U uv - - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" + - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]" - run: | export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') @@ -61,6 +61,8 @@ jobs: env: NUM_WORKERS: 0 CUDA_VISIBLE_DEVICES: 0 + # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + XLA_PYTHON_CLIENT_PREALLOCATE: false - name: Download libtorch run: | wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip