diff --git a/.github/workflows/test_cuda.yml b/.github/workflows/test_cuda.yml index 6bf4c8552f..996a1bcff0 100644 --- a/.github/workflows/test_cuda.yml +++ b/.github/workflows/test_cuda.yml @@ -47,7 +47,7 @@ jobs: && sudo apt-get -y install cuda-12-3 libcudnn8=8.9.5.*-1+cuda12.3 if: false # skip as we use nvidia image - run: python -m pip install -U uv - - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" + - run: source/install/uv_with_retry.sh pip install --system "tensorflow~=2.18.0rc2" "torch~=2.5.0" "jax[cuda12]" - run: | export PYTORCH_ROOT=$(python -c 'import torch;print(torch.__path__[0])') export TENSORFLOW_ROOT=$(python -c 'import importlib,pathlib;print(pathlib.Path(importlib.util.find_spec("tensorflow").origin).parent)') @@ -61,6 +61,8 @@ jobs: env: NUM_WORKERS: 0 CUDA_VISIBLE_DEVICES: 0 + # See https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html + XLA_PYTHON_CLIENT_PREALLOCATE: false - name: Download libtorch run: | wget https://download.pytorch.org/libtorch/cu124/libtorch-cxx11-abi-shared-with-deps-2.5.0%2Bcu124.zip -O libtorch.zip