Skip to content

Commit

Permalink
Add JAX and JAXOpt to library requirements
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Jun 27, 2022
1 parent e9f8bbe commit 863744c
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 18 deletions.
2 changes: 0 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
pip install -r requirements-jax.txt
less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {}
- name: Run the tests with pytest
run: |
Expand Down Expand Up @@ -62,7 +61,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install .
pip install -r requirements-jax.txt
less requirements.txt | grep 'pytest\|chex' | xargs -i -t pip install {}
- name: Run the benchmarks with pytest-benchmark
run: |
Expand Down
19 changes: 7 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,23 +26,18 @@ BlackJAX should appeal to those who:

### Installation

BlackJAX is written in pure Python but depends on XLA via JAX. Since the JAX
installation depends on your CUDA version BlackJAX does not list JAX as a
dependency. If you simply want to use JAX on CPU, install it with:

```python
pip install jax jaxlib
```

Follow [these instructions](https://github.com/google/jax#installation) to
install JAX with the relevant hardware acceleration support.

Then install BlackJAX
You can install BlackJAX using `pip`:

```bash
pip install blackjax
```

BlackJAX is written in pure Python but depends on XLA via JAX. By default, the
version of JAX that will be installed along with BlackJAX will make your code
run on CPU only. **If you want to use BlackJAX on GPU/TPU** we recommend you follow
[these instructions](https://github.com/google/jax#installation) to install JAX
with the relevant hardware acceleration support.

### Example

Let us look at a simple self-contained example sampling with NUTS:
Expand Down
3 changes: 0 additions & 3 deletions requirements-jax.txt

This file was deleted.

7 changes: 6 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def get_version(rel_path):
description="Flexible and fast inference in Python",
long_description=long_description,
packages=setuptools.find_packages(),
install_requires=["fastprogress>=0.2.0"],
install_requires=[
"fastprogress>=0.2.0",
"jax>=0.3.13",
"jaxlib>=0.3.13",
"jaxopt>=0.4.2",
],
long_description_content_type="text/markdown",
keywords="probabilistic machine learning bayesian statistics sampling algorithms",
license="Apache License 2.0",
Expand Down

0 comments on commit 863744c

Please sign in to comment.