diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 5e835ae2a..dc32876dd 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -33,7 +33,6 @@ jobs: run: | python -m pip install --upgrade pip pip install . - pip install -r requirements-jax.txt less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {} - name: Run the tests with pytest run: | @@ -62,7 +61,6 @@ jobs: run: | python -m pip install --upgrade pip pip install . - pip install -r requirements-jax.txt less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {} - name: Run the benchmarks with pytest-benchmark run: | diff --git a/README.md b/README.md index d9a9c6c52..3eaf558c7 100644 --- a/README.md +++ b/README.md @@ -26,23 +26,18 @@ BlackJAX should appeal to those who: ### Installation -BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX -installation depends on your CUDA version BlackJAX does not list JAX as a -dependency. If you simply want to use JAX on CPU, install it with: - -```python -pip install jax jaxlib -``` - -Follow [these instructions](https://github.com/google/jax#installation) to -install JAX with the relevant hardware acceleration support. - -Then install BlackJAX +You can install BlackJAX using `pip`: ```bash pip install blackjax ``` +BlackJAX is written in pure Python but depends on XLA via JAX. By default, the +version of JAX that will be installed along with BlackJAX will make your code +run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow +[these instructions](https://github.com/google/jax#installation) to install JAX +with the relevant hardware acceleration support. + ### Example Let us look at a simple self-contained example sampling with NUTS: diff --git a/requirements-jax.txt b/requirements-jax.txt deleted file mode 100644 index a78935198..000000000 --- a/requirements-jax.txt +++ /dev/null @@ -1,3 +0,0 @@ -jax -jaxlib -jaxopt diff --git a/setup.py b/setup.py index db6cf1ea2..d9cf5ee6e 100644 --- a/setup.py +++ b/setup.py @@ -44,7 +44,12 @@ def get_version(rel_path): description="Flexible and fast inference in Python", long_description=long_description, packages=setuptools.find_packages(), - install_requires=["fastprogress>=0.2.0"], + install_requires=[ + "fastprogress>=0.2.0", + "jax>=0.3.13", + "jaxlib>=0.3.13", + "jaxopt>=0.4.2", + ], long_description_content_type="text/markdown", keywords="probabilistic machine learning bayesian statistics sampling algorithms", license="Apache License 2.0",