conda-based installation #189

Closed
ericmjl opened this issue Jan 3, 2019 · 71 comments · Fixed by #11888
Labels:
build
contributions welcome: The JAX team has not prioritized work on this. Community contributions are welcome.
enhancement: New feature or request
NVIDIA GPU: Issues specific to NVIDIA GPUs
P2 (eventual): This ought to be addressed, but has no schedule at the moment. (Assignee optional)

Comments

ericmjl commented Jan 3, 2019

Putting this here and tagging myself @ericmjl so that I can remember this exists.

To get jax into the hands of data scientists and machine learning researchers, conda installation would be very useful. I will take a stab at this on conda-forge, and record my progress here.

ericmjl changed the title from "Conda installation" to "conda-based installation" Jan 3, 2019
alexbw (Contributor) commented Jan 3, 2019 via email

ericmjl (Author) commented Jan 3, 2019

To start off, I tried getting jax onto my own personal channel on anaconda.org. If I can do this successfully, usually I am able to get it onto conda-forge with no problems.

Commands executed:

$ conda skeleton pypi jax
$ cd jax
$ conda build .

Everything builds correctly up until the point where the import tests run. jax imports jaxlib, and jaxlib needs to be on conda-forge and specified as a dependency of jax for the jax build to work properly.

Unfortunately, I don't see a tarball for jaxlib on PyPI. Perhaps that needs to go up first?

alexbw (Contributor) commented Jan 3, 2019 via email

mattjj added the enhancement label Jan 3, 2019
ericmjl (Author) commented Jan 3, 2019

@alexbw I think those are the wheel URLs, not the tarballs. Does jaxlib have tarballs, or do they have to be built from source? If it's the latter, I might have to rope in some help from friends who are maintaining conda-forge.

ericmjl (Author) commented Jan 3, 2019

@ocefpaf, one question for you - can we pull down Python wheels from a URL and use that pre-compiled wheel as part of a conda-forge-based recipe? Having had a night to sleep over this issue, it seems to me that building from source is going to be a painful thing to do on conda-forge, while having pre-compiled Python wheels installed into the correct location would be easier.

ocefpaf commented Jan 3, 2019

> @ocefpaf, one question for you - can we pull down Python wheels from a URL and use that pre-compiled wheel as part of a conda-forge-based recipe? Having had a night to sleep over this issue, it seems to me that building from source is going to be a painful thing to do on conda-forge, while having pre-compiled Python wheels installed into the correct location would be easier.

Even though we prefer building from source, we do "repacking" in cases like that.

Here is an example: https://github.com/conda-forge/flask-restplus-feedstock/blob/d41ecd6077ba51df75cb15a2b06e737bdc43f8d6/recipe/meta.yaml

ericmjl (Author) commented Jan 3, 2019

@ocefpaf thanks for the response! Another dumb question, hope you don't mind - the jaxlib wheels are for macOS and Linux only: https://pypi.org/project/jaxlib/#files

I plan to "repackage" only the Python 3 wheels. Is there a way for us to specify which repackaged wheel to be downloaded, based on OS? Or is this out of scope for conda-forge?

mattjj (Collaborator) commented Jan 3, 2019

Thanks for driving this, Eric!

Just a question: how do we know that building will be painful on conda-forge? Our build process and build script are pretty simple if we can install and run bazel and meet the compiler toolchain requirements of TensorFlow. Since there are already conda packages for TF, we should be able to follow that setup: the only thing JAX needs to compile from source is a sub-target inside TF. In other words, if we knew how to build TF on conda-forge, we'd already know how to build what JAX needs, as it's a subset of TF.

ericmjl (Author) commented Jan 3, 2019

@mattjj the bazel portion of the installation takes quite a while to run, which is what drives most of the "pain" of installation. If I am not mistaken, having to build jaxlib from scratch each time may be a drain on free community resources. @ocefpaf, do you have any input on this? For reference, it takes over 10 minutes on my home GPU tower to build from source.

ocefpaf commented Jan 3, 2019

> I plan to "repackage" only the Python 3 wheels. Is there a way for us to specify which repackaged wheel to be downloaded, based on OS? Or is this out of scope for conda-forge?

Yep. Just use the pre-processor selectors like in this example.

> @ocefpaf, do you have any input on this? For reference, it takes over 10 minutes on my home GPU tower to build from source.

At the moment we, conda-forge, cannot afford long builds (>1 hour) and we do not have GPU support yet. However, we are experimenting with Azure Pipelines to be able to do long builds, and I believe we may even get some GPUs. More on this soon...
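The per-OS question above is what conda-build's preprocessing selectors handle: a comment tag such as # [linux] or # [osx] at the end of a meta.yaml line keeps that line only on the matching platform, so each platform can point at its own wheel URL. The underlying decision is just "pick the wheel whose tags match this Python and platform". As a rough illustration (not how conda-build works internally), the Python sketch below parses wheel filenames using the PEP 427 naming convention and selects one by Python tag and platform prefix; the package name and filenames are hypothetical.

```python
from typing import Dict, List, Optional


def parse_wheel_filename(filename: str) -> Dict[str, str]:
    """Split a wheel filename into the fields of the PEP 427 convention:
    {name}-{version}(-{build})?-{python}-{abi}-{platform}.whl
    """
    if not filename.endswith(".whl"):
        raise ValueError(f"not a wheel filename: {filename}")
    parts = filename[: -len(".whl")].split("-")
    if len(parts) == 5:
        name, version, python, abi, platform = parts
        build = ""
    elif len(parts) == 6:
        name, version, build, python, abi, platform = parts
    else:
        raise ValueError(f"unexpected wheel filename: {filename}")
    return {"name": name, "version": version, "build": build,
            "python": python, "abi": abi, "platform": platform}


def select_wheel(filenames: List[str], python_tag: str,
                 platform_prefix: str) -> Optional[str]:
    """Return the first wheel built for python_tag on a matching platform."""
    for filename in filenames:
        tags = parse_wheel_filename(filename)
        if tags["python"] == python_tag and tags["platform"].startswith(platform_prefix):
            return filename
    return None


# Hypothetical filenames for illustration; not the real jaxlib wheel list.
WHEELS = [
    "examplelib-0.1.0-cp36-none-linux_x86_64.whl",
    "examplelib-0.1.0-cp37-none-linux_x86_64.whl",
    "examplelib-0.1.0-cp37-none-macosx_10_9_x86_64.whl",
]

if __name__ == "__main__":
    print(select_wheel(WHEELS, "cp37", "macosx"))
```

Compressed tag sets (for example cp37.cp38) are not handled; the sketch only compares tags as plain strings.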

ericmjl (Author) commented Jan 3, 2019

This is very helpful, thank you @ocefpaf!

ericmjl (Author) commented Jan 6, 2019

Wanted to ensure that there was a cross-reference. jaxlib conda-forge PR is here: conda-forge/staged-recipes#7529

I mimicked the TensorFlow build recipe. Each time there's an update to jaxlib, the recipe (specifically build.sh and meta.yaml) has to be updated.

No builds happen for Windows, as jaxlib is currently unavailable there. To encourage Py3k adoption, I also intentionally did not include the Python 2 wheels in the build.
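Because each jaxlib release means editing the recipe, it can help to script the routine part. A minimal sketch, assuming a meta.yaml that uses the common conda-forge layout (a Jinja {% set version = "..." %} line plus a single sha256: field); the recipe text, package name and URL are hypothetical, and any build.sh changes would still be manual.

```python
import hashlib
import re

# Hypothetical recipe in the common conda-forge meta.yaml layout.
RECIPE = """{% set version = "0.1.0" %}

package:
  name: examplelib
  version: {{ version }}

source:
  url: https://example.invalid/examplelib-{{ version }}.tar.gz
  sha256: """ + "0" * 64 + "\n"


def sha256_of(data: bytes) -> str:
    """Hex SHA-256 digest, the form used in a recipe's sha256: field."""
    return hashlib.sha256(data).hexdigest()


def bump_recipe(text: str, new_version: str, new_sha256: str) -> str:
    """Rewrite the Jinja version line and the sha256 field of a recipe."""
    # Function replacements avoid backreference clashes such as "\10".
    text, n_version = re.subn(
        r'(\{%\s*set\s+version\s*=\s*")[^"]*("\s*%\})',
        lambda m: m.group(1) + new_version + m.group(2),
        text,
    )
    text, n_sha = re.subn(
        r"(sha256:\s*)[0-9a-fA-F]{64}",
        lambda m: m.group(1) + new_sha256,
        text,
    )
    if n_version != 1 or n_sha != 1:
        raise ValueError("expected exactly one version line and one sha256 field")
    return text


if __name__ == "__main__":
    # In practice the digest would be computed from the downloaded archive.
    new_sha = sha256_of(b"contents of the new release archive")
    print(bump_recipe(RECIPE, "0.1.1", new_sha))
```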

ericmjl (Author) commented Jan 22, 2019

Looping back here about conda, guys. I tried submitting a PR for just jaxlib: conda-forge/staged-recipes#7529

However, it appears that there is an issue with the macOS build, which is only resolvable by building from source.

I think things will be cleaner if jaxlib is separated from jax. @mattjj, I remember this was on your roadmap before - am I remembering correctly, or am I mistaken about this?

mattjj (Collaborator) commented Jan 22, 2019

Yes, we had a plan to separate out jaxlib and call it xlapy, though it hasn't been a high priority compared to other work because there hasn't been a clear upside.

Is it possible to have separate conda packages for jax and jaxlib, without splitting the git repository? I ask mainly because we might not have time to dig into this for a while.

ericmjl (Author) commented Jan 22, 2019

@mattjj I can try something, e.g. downloading a zip or tar archive of the whole jax repository and then building from source.

Could you guys put up a tagged release on GitHub? That's generally better received by the conda-forge admins, and it'll give me a so-called "point release reference" that I can point the build recipe against, rather than always building against an ever-evolving master 😄.

ericmjl (Author) commented Jan 22, 2019

Hah, looks like I was one step too fast for the conda-forge admins.

I just updated the conda-forge PR: conda-forge/staged-recipes#7529

Looks like there's a fix for the error encountered earlier.

mattjj (Collaborator) commented Jan 22, 2019

> Could you guys put up a tagged release on GitHub?

Is it viable to instead just clone a specific git commit? We might be able to do tagged releases and stuff, but I'd like to minimize the number of potential blockers for you.

hawkinsp (Collaborator) commented Feb 8, 2019

Any updates? Anything we can help with?

It looks like other folks are eager for Conda packages too (#302)

Thanks!

ericmjl (Author) commented Feb 8, 2019

@hawkinsp I'm still working on it! 😄 But yes, I'm running into conda build issues at the moment. I stepped away from the conda-forge build for a little while (kwargs="work"), but you can track progress here: conda-forge/staged-recipes#7529

That PR, btw, is just a "copy the wheels over" PR to enable jax on CPU to be distributed by conda-forge. I think it will be more difficult to get jax + GPU over, unless there's a build process at Google you guys could use to release jax+jaxlib wheels targeted for various CUDAs? (If so, then I could re-attempt things.)

hawkinsp (Collaborator) commented Feb 8, 2019

Well, we build wheels for Python {2.7, 3.6, 3.7} x CUDA {9.0, 9.2, 10.0} already, which are the ones linked here:
https://github.com/google/jax#pip-installation

Our build script is open source: it is done by this script using a Docker container: https://github.com/google/jax/blob/master/build/build_jaxlib_wheels.sh

Ideally we could somehow build one wheel and distribute it for both Conda and Pypi, but if needs be we could also build separate Conda packages as well. We just need to know how to build one...
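The wheel set described above is a Cartesian product: 3 Python versions times 3 CUDA versions gives 9 GPU builds per platform, before counting CPU-only wheels. A small Python sketch that enumerates that matrix; the variant naming scheme is made up for illustration and is not JAX's actual wheel tag format.

```python
from itertools import product
from typing import List, Tuple

PYTHON_VERSIONS = ["2.7", "3.6", "3.7"]
CUDA_VERSIONS = ["9.0", "9.2", "10.0"]


def build_matrix(pythons: List[str], cudas: List[str]) -> List[Tuple[str, str]]:
    """Every (python, cuda) pair that needs its own GPU wheel."""
    return list(product(pythons, cudas))


def variant_name(python: str, cuda: str) -> str:
    # Hypothetical naming scheme, e.g. "py37-cuda100"; not JAX's real tags.
    return f"py{python.replace('.', '')}-cuda{cuda.replace('.', '')}"


if __name__ == "__main__":
    matrix = build_matrix(PYTHON_VERSIONS, CUDA_VERSIONS)
    print(len(matrix))
    for py, cu in matrix:
        print(variant_name(py, cu))
```

Every added CUDA or Python version multiplies the count, which is one reason a recipe that pulls a single hard-coded wheel covers only one cell of this grid.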

ericmjl (Author) commented Feb 8, 2019

> Well, we build wheels for Python {2.7, 3.6, 3.7} x CUDA {9.0, 9.2, 10.0} already, which are the ones linked here:
> https://github.com/google/jax#pip-installation

Oh!!! I'm sorry, I missed that. I was somehow focused on just the CPU versions. My bad.

> Ideally we could somehow build one wheel and distribute it for both Conda and Pypi, but if needs be we could also build separate Conda packages as well. We just need to know how to build one...

conda-forge build scripts are basically specified entirely by a YAML file, or as a YAML file + some other scripts. The latter is what I tried doing in that PR, specifically here to get jaxlib into conda-forge.

@ocefpaf, maybe you could provide some guidance to the jax team on where the docs for conda-forge recipes live? I've been doing packages for as long as conda-forge has been around, so it's kind of in my head now, but I know I've stumbled quite a few times because of the evolving infrastructure.

The part where jax's distribution doesn't match my mental model of packaging is that I need to build two things (jax and jaxlib) for jax to work (i.e. jax depends on jaxlib), but jaxlib lives in the same repository as jax. Usually dependencies are other packages maintained by other people, so I only have to worry about my own package. In other words, I've usually seen some separation between package X and the dependencies of X. Though maybe I'm not seasoned enough as a software person and have only encountered the simple cases ^_^.

hawkinsp (Collaborator) commented Feb 8, 2019

It's not out of the question we could split jaxlib (the mostly C++ part) into a separate repository from jax (the pure Python part). But it would be helpful to know what the constraints are before we start moving things around.

Wildcarde commented Jun 20, 2019

Just to add some input: conda (more specifically Anaconda) supplies its own build environment so that software can run in most places with relative ease. This includes its own glibc, allowing software built properly on top of conda to run anywhere conda can be installed. When this environment isn't honored, it can cause issues on systems that Anaconda otherwise works just fine on, CentOS 7 for example.
We are starting to get requests to install this software on our computational cluster. Sure, we could build it from source, but if it can be reliably installed from conda-forge, which already manages the build environment and can bring in the CUDA libraries on demand for GPU-based code, that is much easier to ingest and maintain.

ericmjl (Author) commented Jul 14, 2019

@alexbw @dougalm @mattjj @hawkinsp thanks to @ocefpaf, there is a conda package for jax and jaxlib on conda-forge now! He kindly worked on packaging what is currently distributed on PyPI while at the SciPy 2019 sprints.

It is currently Python 3.7 and CPU only, because the recipe simply pulls down the py37 version (hard-coded). To the best of my knowledge, there's no "elegant" way of distributing the CUDA-enabled packages using the same recipe, given the current way jax and jaxlib are distributed on PyPI. If we could build jaxlib on conda-forge, that would greatly simplify the conda-based distribution story. I think @ocefpaf has more details than I do, as he knows the issues he ran into trying to build jaxlib on conda-forge at the sprints.

lgsmith commented Sep 24, 2019

Sorry if this is the wrong place to bring this up; I'm a fairly naive user just trying to see if some of the numpy fitting code I've written for my dissertation can be accelerated by putting some jax in the right spots.

I have found that when installing from conda the import for jax currently fails with a ModuleNotFoundError: No module named 'fastcache'. I was able to rectify this by simply calling conda install fastcache in the correct env, which makes me think that the conda recipe may not have the correct dependencies.

I'm on Ubuntu 18.04, using the most recent conda (4.7.12). I can replicate the issue by creating a new env, installing jax from the conda-forge channel, then opening the interpreter and trying to import jax. When I exit, use conda to install fastcache, then re-enter the interpreter, I can run snippets of example code from the main GitHub page with no issue.

I'm happy to make this a separate issue, but it seemed intimately related to the conversation here. Thanks for going to the trouble of trying to get this on conda.
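A quick way to check for a missing runtime dependency like fastcache, without going through the full import jax, is to ask Python's import machinery whether each top-level module can be found. A minimal sketch using importlib.util.find_spec; the module list passed in is illustrative and is not jax's actual dependency list.

```python
import importlib.util
from typing import Iterable, List


def missing_modules(names: Iterable[str]) -> List[str]:
    """Return the top-level module names that cannot be found on sys.path."""
    # find_spec returns None for a top-level module that is not installed.
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    # Illustrative list; fastcache is the module from the report above.
    print(missing_modules(["numpy", "fastcache"]))
```

Any name it reports can be installed into the same environment as a workaround (as with conda install fastcache above); the lasting fix is listing it under the recipe's run requirements.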

machineko commented Aug 20, 2020

> why is this still not officially supported?
>
> Well, we have a small team, and so even though there's a lot of work worth doing, we have to choose what to prioritize.
>
> I believe JAX works great with WSL, so if that works on your Windows setup you might want to give it a try.
>
> @ericmjl I love your attitude! We've benefitted a huge amount from open-source contributions, and I hope we can get a lot more over time! We're all on the same team here, doing our best to push things forward.

WSL GPU ops are 2-3 times slower than on native Linux (TF2 and JAX) or on Windows 10 (TF2), so it would still be nice to have Windows 10 support :P

hawkinsp added the contributions welcome label Dec 2, 2020
jotaf98 commented Nov 7, 2021

Hi, I second the previous posters' opinion that a no-frills conda install <whatever>, with pre-packaged CUDA binaries on the 3 major OSes, seems like a must for a library to win over data scientists.

Perhaps we've been spoiled by PyTorch?

How do they manage to make this so painless? It is a single conda command and no other instructions :)

After trying a lot of different approaches unsuccessfully, the one that did it was @cloudhan's excellent repo with Windows wheels: https://github.com/cloudhan/jax-windows-builder

You still need to get the right CUDA installed, and in my case I had to uninstall previous versions as they were still being wrongly used by JAX (ptxas errors and warnings). Also, the repo provides specific jaxlib versions, so I had to go through the setup.py of previous jax versions to see which one was compatible with the wheel's jaxlib version (jax 0.2.22 for jaxlib 0.1.72). That soured me a little, since the library seems very well put together.
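Instead of reading old setup.py files, the installed distribution metadata can be queried directly. A minimal sketch using the standard-library importlib.metadata (Python 3.8+): it reports the installed jax and jaxlib versions and any jaxlib requirement lines that jax's metadata declares. Whether a given jax release declares its jaxlib constraint there (for example under an extra) varies, so an empty list is possible; parse_version is a naive helper for purely numeric versions.

```python
from importlib import metadata
from typing import List, Optional, Tuple


def installed_version(dist: str) -> Optional[str]:
    """Version string of an installed distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def jaxlib_requirements(dist: str = "jax") -> List[str]:
    """Requirement strings in dist's metadata that mention jaxlib."""
    try:
        reqs = metadata.requires(dist) or []
    except metadata.PackageNotFoundError:
        return []
    return [r for r in reqs if r.lower().startswith("jaxlib")]


def parse_version(v: str) -> Tuple[int, ...]:
    """Naive parser for purely numeric dotted versions like '0.1.72'."""
    return tuple(int(part) for part in v.split("."))


if __name__ == "__main__":
    print("jax:", installed_version("jax"), "jaxlib:", installed_version("jaxlib"))
    print("jaxlib requirements declared by jax:", jaxlib_requirements())
```

The tuple comparison matters: as plain strings, "0.1.72" sorts before "0.1.9", while parse_version("0.1.72") > parse_version("0.1.9") is True.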

cloudhan (Contributor) commented Nov 8, 2021

> and in my case I had to uninstall previous versions as they were still being wrongly used by JAX

@jotaf98 Actually, you could work around it by manually setting up the command-line environment variables: remove all CUDA-related paths from PATH and the CUDA_* variables, and reset them with the proper version.

I have that in my powershell profile:
https://gist.github.com/cloudhan/97db3c1e57895a09a80ec1f30c471cb3#file-profile-ps1-L309-L338
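The same cleanup can be sketched outside PowerShell. The Python sketch below builds an environment for a child process that drops every PATH entry containing "cuda" and every CUDA_* variable, then puts back a single toolkit. Matching on the substring "cuda", the example paths, and using CUDA_PATH as the variable to restore are assumptions for illustration; it does not reproduce cloudhan's profile. Because CUDA libraries are located when a process loads them, the cleaned environment has to be applied before launching Python, for instance via subprocess.run(..., env=new_env).

```python
import os
from typing import Dict, Mapping, Optional


def clean_cuda_env(env: Mapping[str, str], keep_cuda_bin: str,
                   cuda_root: Optional[str] = None,
                   sep: str = os.pathsep) -> Dict[str, str]:
    """Return a copy of env that points at a single CUDA installation.

    Assumes the search path is stored under the key "PATH" (os.environ
    uses upper-case keys on Windows).
    """
    # Drop every CUDA_* variable (CUDA_PATH, CUDA_PATH_V10_0, ...).
    cleaned = {k: v for k, v in env.items() if not k.upper().startswith("CUDA_")}
    # Drop every PATH entry that mentions cuda, case-insensitively.
    kept = [p for p in env.get("PATH", "").split(sep) if p and "cuda" not in p.lower()]
    # Put the one toolkit we want first on PATH.
    cleaned["PATH"] = sep.join([keep_cuda_bin] + kept)
    if cuda_root is not None:
        cleaned["CUDA_PATH"] = cuda_root
    return cleaned


if __name__ == "__main__":
    # Hypothetical install location of the CUDA version to keep.
    new_env = clean_cuda_env(os.environ, r"C:\CUDA\v11.2\bin",
                             cuda_root=r"C:\CUDA\v11.2")
    print(new_env["PATH"].split(os.pathsep)[0])
```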

jotaf98 commented Nov 11, 2021

Ah good point, it's very likely that the profusion of paths did not help, and removing the previous ones would've been effective.

Still, a "You have the wrong CUDA version" message would have been a lot more informative than ptxas errors at runtime. And having to go through the setup.py of every version to figure out the jax-jaxlib pairing was when I realized this was not user-friendly at all :)

ma-sadeghi commented

Hey there, is there a GPU-compatible conda installation for jax yet? Thanks!

proleu commented Mar 31, 2022

To echo @ma-sadeghi, it would be really great if there were a CUDA build available on conda. Pip installing leaves my conda environment very inconsistent, and it feels like I have to do surgery every time I want to upgrade something.

ngam (Contributor) commented May 14, 2022

Please see #10708 for deets, cuda-full builds should be ready very soon (pending more reviews)

proleu commented May 14, 2022

@ngam incredible, looking forward to testing this

ngam (Contributor) commented May 24, 2022

Should be already up on anaconda.org, try installing jaxlib==0.3.10=*cuda* from conda-forge. Jax team, this issue is now resolved :)

(note CONDA_OVERRIDE_CUDA="11.2" is necessary only if you're on a machine without a GPU)

~$ CONDA_OVERRIDE_CUDA="11.2" mamba create -n jaxlibcuda jaxlib=*=*cuda* jax -c conda-forge

                  __    __    __    __
                 /  \  /  \  /  \  /  \
                /    \/    \/    \/    \
███████████████/  /██/  /██/  /██/  /████████████████████████
              /  / \   / \   / \   / \  \____
             /  /   \_/   \_/   \_/   \    o \__,
            / _/                       \_____/  `
            |/
        ███╗   ███╗ █████╗ ███╗   ███╗██████╗  █████╗
        ████╗ ████║██╔══██╗████╗ ████║██╔══██╗██╔══██╗
        ██╔████╔██║███████║██╔████╔██║██████╔╝███████║
        ██║╚██╔╝██║██╔══██║██║╚██╔╝██║██╔══██╗██╔══██║
        ██║ ╚═╝ ██║██║  ██║██║ ╚═╝ ██║██████╔╝██║  ██║
        ╚═╝     ╚═╝╚═╝  ╚═╝╚═╝     ╚═╝╚═════╝ ╚═╝  ╚═╝

        mamba (0.23.3) supported by @QuantStack

        GitHub:  https://github.com/mamba-org/mamba
        Twitter: https://twitter.com/QuantStack

█████████████████████████████████████████████████████████████


Looking for: ['jaxlib=[build=*cuda*]', 'jax']

conda-forge/noarch                                   8.3MB @   2.7MB/s  3.2s
conda-forge/linux-64                                23.2MB @   3.2MB/s  7.9s
Transaction

  Prefix: /home/ngam/.Mambaforge-Linux-x86_64/envs/jaxlibcuda

  Updating specs:

   - jaxlib=*[build=*cuda*]
   - jax


  Package                   Version  Build                   Channel                    Size
──────────────────────────────────────────────────────────────────────────────────────────────
  Install:
──────────────────────────────────────────────────────────────────────────────────────────────

  + _libgcc_mutex               0.1  conda_forge             conda-forge/linux-64        3kB
  + _openmp_mutex               4.5  2_gnu                   conda-forge/linux-64       24kB
  + abseil-cpp           20211102.0  h27087fc_1              conda-forge/linux-64        1MB
  + absl-py                   1.0.0  pyhd8ed1ab_0            conda-forge/noarch       Cached
  + bzip2                     1.0.8  h7f98852_4              conda-forge/linux-64     Cached
  + c-ares                   1.18.1  h7f98852_0              conda-forge/linux-64     Cached
  + ca-certificates     2022.5.18.1  ha878542_0              conda-forge/linux-64     Cached
  + cudatoolkit              11.7.0  hd8887f6_10             conda-forge/linux-64      872MB
  + cudnn                  8.2.1.32  h86fa8c9_0              conda-forge/linux-64      707MB
  + grpc-cpp                 1.46.3  h0b91f02_0              conda-forge/linux-64        5MB
  + jax                      0.3.13  pyhd8ed1ab_0            conda-forge/noarch        815kB
  + jaxlib                   0.3.10  cuda112py310h9def920_0  conda-forge/linux-64       66MB
  + ld_impl_linux-64         2.36.1  hea4e1c9_2              conda-forge/linux-64     Cached
  + libblas                   3.9.0  14_linux64_openblas     conda-forge/linux-64       13kB
  + libcblas                  3.9.0  14_linux64_openblas     conda-forge/linux-64       13kB
  + libffi                    3.4.2  h7f98852_5              conda-forge/linux-64     Cached
  + libgcc-ng                12.1.0  h8d9b700_16             conda-forge/linux-64     Cached
  + libgfortran-ng           12.1.0  h69a702a_16             conda-forge/linux-64       23kB
  + libgfortran5             12.1.0  hdcd56e2_16             conda-forge/linux-64     Cached
  + libgomp                  12.1.0  h8d9b700_16             conda-forge/linux-64     Cached
  + liblapack                 3.9.0  14_linux64_openblas     conda-forge/linux-64       13kB
  + libnsl                    2.0.0  h7f98852_0              conda-forge/linux-64     Cached
  + libopenblas              0.3.20  pthreads_h78a6416_0     conda-forge/linux-64     Cached
  + libprotobuf              3.20.1  h6239696_0              conda-forge/linux-64        3MB
  + libstdcxx-ng             12.1.0  ha89aaad_16             conda-forge/linux-64     Cached
  + libuuid                  2.32.1  h7f98852_1000           conda-forge/linux-64     Cached
  + libzlib                  1.2.11  h166bdaf_1014           conda-forge/linux-64     Cached
  + nccl                  2.12.12.1  h0800d71_0              conda-forge/linux-64      142MB
  + ncurses                     6.3  h27087fc_1              conda-forge/linux-64     Cached
  + numpy                    1.22.4  py310h4ef5377_0         conda-forge/linux-64        7MB
  + openssl                   3.0.3  h166bdaf_0              conda-forge/linux-64     Cached
  + opt_einsum                3.3.0  pyhd8ed1ab_1            conda-forge/noarch       Cached
  + pip                      22.1.1  pyhd8ed1ab_0            conda-forge/noarch       Cached
  + python                   3.10.4  h2660328_0_cpython      conda-forge/linux-64       30MB
  + python-flatbuffers          2.0  pyhd8ed1ab_0            conda-forge/noarch       Cached
  + python_abi                 3.10  2_cp310                 conda-forge/linux-64        4kB
  + re2                  2022.04.01  h27087fc_0              conda-forge/linux-64      217kB
  + readline                    8.1  h46c0cb4_0              conda-forge/linux-64     Cached
  + scipy                     1.8.1  py310h7612f91_0         conda-forge/linux-64     Cached
  + setuptools               62.3.2  py310hff52083_0         conda-forge/linux-64     Cached
  + six                      1.16.0  pyh6c4a22f_0            conda-forge/noarch       Cached
  + sqlite                   3.38.5  h4ff8645_0              conda-forge/linux-64     Cached
  + tk                       8.6.12  h27826a3_0              conda-forge/linux-64     Cached
  + typing_extensions         4.2.0  pyha770c72_1            conda-forge/noarch       Cached
  + tzdata                    2022a  h191b570_0              conda-forge/noarch       Cached
  + wheel                    0.37.1  pyhd8ed1ab_0            conda-forge/noarch       Cached
  + xz                        5.2.5  h516909a_1              conda-forge/linux-64     Cached
  + zlib                     1.2.11  h166bdaf_1014           conda-forge/linux-64     Cached

  Summary:

  Install: 48 packages

  Total download: 2GB

──────────────────────────────────────────────────────────────────────────────────────────────

Confirm changes: [Y/n] 

ilemhadri commented May 24, 2022

Looks good!

TLDR: conda create -n <name_your_conda_env> jaxlib=*=*cuda* jax -c conda-forge creates a functional jax environment.
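In the spec jaxlib=*=*cuda*, the =-separated fields are package name, version and build string: * accepts any version and *cuda* keeps only builds whose build string contains "cuda", such as cuda112py310h9def920_0 in the mamba transcript above. The Python sketch below mimics just that glob filtering with fnmatch; it is an illustration, not conda's MatchSpec implementation, and the CPU build string is invented for contrast.

```python
from fnmatch import fnmatchcase
from typing import List, Tuple

# (version, build string) pairs. The first comes from the transcript above;
# the second is a made-up CPU build string for contrast.
CANDIDATES: List[Tuple[str, str]] = [
    ("0.3.10", "cuda112py310h9def920_0"),
    ("0.3.10", "cpupy310h0000000_0"),
]


def filter_candidates(candidates: List[Tuple[str, str]], version_glob: str,
                      build_glob: str) -> List[Tuple[str, str]]:
    """Keep candidates whose version and build string match the globs."""
    return [(v, b) for v, b in candidates
            if fnmatchcase(v, version_glob) and fnmatchcase(b, build_glob)]


if __name__ == "__main__":
    # Equivalent of the build-string part of jaxlib=*=*cuda*
    print(filter_candidates(CANDIDATES, "*", "*cuda*"))
```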

falconair commented

If this is a done deal, can the installation instructions on the readme.md be updated?

ngam (Contributor) commented May 30, 2022

> If this is a done deal, can the installation instructions on the readme.md be updated?

This is up to the jax team. I personally do not feel comfortable proposing a readme.md edit (it is unclear to me whether the jax team wants this publicized).

I believe they may be wary of adopting this in their docs as they may not have the bandwidth for troubleshooting potential issues that may come up if one installs via conda. The conda-forge activities are (mostly) community-driven volunteer efforts and volunteers (like me) come and go. So while it may be more convenient to install from conda-forge, one would need to know it is not the "official" way. Potentially (I haven't tested this yet), the conda-forge build will be faster than the PyPI build --- this is the case with tensorflow for example, on both CPUs and GPUs, and so I predict it will be the case with jax; someone should run some benchmarks :)

cossio commented Jul 16, 2022

It might be a bit confusing to have to navigate this issue to understand how to install jax through conda. Thanks @ilemhadri for the TLDR. So if I understand correctly, conda install jaxlib=*=*cuda* jax -c conda-forge should work.

Just one question. If one wants a CPU-only jax (say to install on a system that doesn't have a CUDA device), should I still use this command?

cossio commented Jul 16, 2022

I tried installing via the following (in a clean environment):

conda install jaxlib=*=*cuda* jax -c conda-forge

But eventually I get the following error:

>>> import jax.numpy as jnp
>>> from jax import grad, jit, vmap
>>> from jax import random
>>> key = random.PRNGKey(0)
2022-07-16 18:06:25.936045: W external/org_tensorflow/tensorflow/stream_executor/gpu/asm_compiler.cc:80] Couldn't get ptxas version string: INTERNAL: Couldn't invoke ptxas --version
2022-07-16 18:06:25.937165: F external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:449] ptxas returned an error during compilation of ptx to sass: 'INTERNAL: Failed to launch ptxas'  If the error message indicates that a file could not be written, please verify that sufficient filesystem space is provided.
Aborted

Not sure if this should be a separate issue?

ngam (Contributor) commented Jul 16, 2022

@cossio if your machine doesn't have a GPU, please try conda install jaxlib jax -c conda-forge for the non-cuda version. If you want to be super certain you get the (latest) cpu version, try conda install jaxlib==*=*cpu* jax -c conda-forge

The error INTERNAL: Failed to launch ptxas makes sense if you don't have a GPU... but I would have expected something more like this:

>>> key = random.PRNGKey(0)
2022-07-16 23:09:37.072391: W external/org_tensorflow/tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcuda.so.1'; dlerror: libcuda.so.1: cannot open shared object file: No such file or directory
2022-07-16 23:09:37.072441: W external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:263] failed call to cuInit: UNKNOWN ERROR (303)
WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
>>>
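After installing either variant, one way to see whether jaxlib actually found a GPU (rather than falling back to CPU as in the warning above) is to ask JAX which backend and devices it sees. A minimal sketch: jax.default_backend() and jax.devices() are public JAX functions in recent releases, while the wrapper itself is only illustrative and returns None when jax is not installed.

```python
from typing import List, Optional, Tuple


def jax_backend_report() -> Optional[Tuple[str, List[str]]]:
    """Return (default backend, device platforms), or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend(), [d.platform for d in jax.devices()]


if __name__ == "__main__":
    print(jax_backend_report())
```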

ngam (Contributor) commented Jul 16, 2022

The above instructions were strictly for the GPU version, because the CPU versions have long been available. Generally, you should just run conda install jaxlib jax -c conda-forge, which will get you the most appropriate version for your machine, even a GPU version if a GPU is available on the machine (though you may need to request it explicitly if you're on a login node without a GPU, etc.).
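conda detects CUDA through its __cuda virtual package, which is absent on a machine without a visible GPU driver, so on a login node the CUDA build has to be requested explicitly; that is what the CONDA_OVERRIDE_CUDA="11.2" prefix in the earlier mamba command does. A minimal Python sketch that assembles the corresponding conda create command and environment without running it; the environment name and CUDA version are placeholders.

```python
import os
from typing import Dict, List, Optional, Tuple


def conda_create_command(env_name: str, cuda_version: Optional[str] = None
                         ) -> Tuple[List[str], Dict[str, str]]:
    """Build argv and environment for creating a jax env from conda-forge.

    If cuda_version is given, request a CUDA build of jaxlib and set
    CONDA_OVERRIDE_CUDA so the solver assumes that CUDA version even on a
    machine where no GPU is visible.
    """
    jaxlib_spec = "jaxlib=*=*cuda*" if cuda_version else "jaxlib"
    argv = ["conda", "create", "-y", "-n", env_name, jaxlib_spec, "jax",
            "-c", "conda-forge"]
    env = dict(os.environ)
    if cuda_version:
        env["CONDA_OVERRIDE_CUDA"] = cuda_version
    return argv, env


if __name__ == "__main__":
    argv, env = conda_create_command("jaxlibcuda", cuda_version="11.2")
    print(" ".join(argv))
    # To actually run it: subprocess.run(argv, env=env, check=True)
```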

ngam (Contributor) commented Jul 16, 2022

Also, if people have issues with conda/conda-forge installations or getting errors from conda-forge packages, please remember you can get better help if you open an issue elsewhere:

cossio commented Jul 17, 2022

The computer where I got the INTERNAL: Failed to launch ptxas error has GPU.
I also note that if I install through pip (following the "official" instructions on the Readme), then this error does not occur.

ngam (Contributor) commented Jul 17, 2022

> The computer where I got the INTERNAL: Failed to launch ptxas error has GPU. I also note that if I install through pip (following the "official" instructions on the Readme), then this error does not occur.

Please open an issue here https://github.com/conda-forge/jaxlib-feedstock and include the output of conda info and conda list of the env where you install jaxlib and jax. There have been some issues with packaging ptxas in the past, but let's move the discussion to the jaxlib-feedstock to resolve this
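For a report like this, it may also help to include whether ptxas, the CUDA assembler that XLA says it could not launch, is reachable from the active environment at all. A minimal sketch using shutil.which and subprocess to run ptxas --version, the same invocation the warning mentions; it returns None if the binary is missing or fails to run.

```python
import shutil
import subprocess
from typing import Optional


def ptxas_version(executable: str = "ptxas") -> Optional[str]:
    """Return the output of `<executable> --version`, or None if it can't run."""
    path = shutil.which(executable)
    if path is None:
        return None  # not on PATH in this environment
    try:
        out = subprocess.run([path, "--version"], capture_output=True,
                             text=True, check=True, timeout=30)
    except (OSError, subprocess.SubprocessError):
        return None  # found, but failed to launch or exited with an error
    return out.stdout.strip() or None


if __name__ == "__main__":
    print(ptxas_version())
```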

sudhakarsingh27 added the NVIDIA GPU, P1 (soon) and P2 (eventual) labels and removed the P1 (soon) label Aug 10, 2022