
Incorrect resolution for JAX #2369

Closed
1 task done
ethanluoyc opened this issue Nov 4, 2023 · 2 comments · Fixed by #2371
Labels
🐛 bug Something isn't working 🧩 dependency resolution Resolution failures

Comments

@ethanluoyc
Contributor

ethanluoyc commented Nov 4, 2023

  • I have searched the issue tracker and believe that this is not a duplicate.

I am trying to specify JAX as a dependency for my project, but currently it is not resolved correctly by PDM.

For some additional context: JAX depends on a native dependency, jaxlib, for correct operation. The user has to specify an extra (e.g., jax[cpu], jax[cuda]) to install the correct jaxlib variant. The CPU variant uses a plain version, while the GPU wheels use a local version identifier to differentiate them from the CPU one (and the user needs to use JAX's find-links page).
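To make the versioning scheme concrete, the minimal sketch below splits a PEP 440 version into its public part and its local label, and orders versions so that a local build sorts after the plain release with the same number. It is stdlib-only and deliberately simplified (plain dotted releases, no pre/post/dev segments, local labels split on "." only); the helper names `split_local` and `sort_key` are hypothetical, and real tooling relies on the `packaging` library instead. The CUDA version string is the one that appears in the lock error quoted later in this thread.

```python
from __future__ import annotations


def split_local(version: str) -> tuple[str, str | None]:
    """Split 'public+local' into (public, local); local is None when absent."""
    public, sep, local = version.partition("+")
    return public, (local if sep else None)


def sort_key(version: str) -> tuple:
    """Simplified PEP 440 ordering key for plain release versions.

    - Release segments compare numerically; trailing zeros are ignored.
    - A version without a local label sorts before the same version with one.
    - Local segments made of ASCII digits compare as integers and sort above
      alphanumeric segments, which compare case-insensitively as strings.
    """
    public, local = split_local(version)
    release = [int(part) for part in public.split(".")]
    while len(release) > 1 and release[-1] == 0:
        release.pop()
    if local is None:
        return (tuple(release), 0, ())
    segments = tuple(
        (1, int(seg)) if seg.isascii() and seg.isdigit() else (0, seg.lower())
        for seg in local.split(".")
    )
    return (tuple(release), 1, segments)


# CPU-style and CUDA-style jaxlib versions, printed in ascending order;
# the local CUDA build sorts after the plain 0.4.20 release.
for v in sorted(["0.4.20+cuda12.cudnn89", "0.4.17", "0.4.20"], key=sort_key):
    print(v, split_local(v))
```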

Steps to reproduce

The following pyproject.toml resolves incorrectly:

[project]
name = "pdm-example"
version = "0.1.0"
description = ""
authors = [
    {name = "Yicheng Luo", email = "[email protected]"},
]
dependencies = [
    "jax==0.4.17",
]
requires-python = ">=3.10,<3.11"
readme = "README.md"
license = {text = "MIT"}

[project.optional-dependencies]
cpu = ["jax[cpu]"]
cuda = ["jax[cuda12_pip]"]

[build-system]
requires = ["pdm-backend"]
build-backend = "pdm.backend"

# [[tool.pdm.source]]
# name = "jax_cuda"
# url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
# verify_ssl = true
# type = "find_links"

[tool.pdm.resolution]
respect-source-order = true

If I run pdm lock -G cpu, the lockfile looks like this:

# This file is @generated by PDM.
# It is not intended for manual editing.

[metadata]
groups = ["default", "cpu"]
strategy = ["cross_platform"]
lock_version = "4.4"
content_hash = "sha256:c7c792eac140bf630ef85a5d708263e9598ace33bba32675d6aade1ac6a881b4"

[[package]]
name = "jax"
version = "0.4.17"
requires_python = ">=3.9"
summary = "Differentiate, compile, and transform Numpy code."
dependencies = [
    "ml-dtypes>=0.2.0",
    "numpy>=1.22",
    "opt-einsum",
    "scipy>=1.7",
]
files = [
    {file = "jax-0.4.17-py3-none-any.whl", hash = "sha256:c3ab72ea2f1c5d8ccf2561e79f6562fb2964629f3e55b3ac1c11c48b64c20336"},
    {file = "jax-0.4.17.tar.gz", hash = "sha256:d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d"},
]

[[package]]
name = "jax"
version = "0.4.17"
extras = ["cpu"]
requires_python = ">=3.9"
summary = "Differentiate, compile, and transform Numpy code."
dependencies = [
    "jax==0.4.17",
    "jaxlib==0.4.17",
]
files = [
    {file = "jax-0.4.17-py3-none-any.whl", hash = "sha256:c3ab72ea2f1c5d8ccf2561e79f6562fb2964629f3e55b3ac1c11c48b64c20336"},
    {file = "jax-0.4.17.tar.gz", hash = "sha256:d7508a69e87835f534cb07a2f21d79cc1cb8c4cfdcf7fb010927267ef7355f1d"},
]

[[package]]
name = "jaxlib"
version = "0.4.17"
requires_python = ">=3.9"
summary = "XLA library for JAX"
dependencies = [
    "ml-dtypes>=0.2.0",
    "numpy>=1.22",
    "scipy>=1.7",
]
files = [
    {file = "jaxlib-0.4.17-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:d4be1ac4bf1be1ae1cd8f5f4da414a6d0de8de36cf2effdb5758d4d677896078"},
    {file = "jaxlib-0.4.17-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:392c779f902c43e1a0af49159daffef9b5af952aba001463f98cf95a59ef17ff"},
    {file = "jaxlib-0.4.17-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:160fce68b82a79a6c522652e8dd9a10aac9c00d1599cb7e166671ad909aa139e"},
    {file = "jaxlib-0.4.17-cp310-cp310-win_amd64.whl", hash = "sha256:61b3788c6cfe46f307e6e67d4a942de72cf34711ff349f4f11500cdf6dc67199"},
]

[[package]]
name = "ml-dtypes"
version = "0.3.1"
requires_python = ">=3.9"
summary = ""
dependencies = [
    "numpy>1.20",
    "numpy>=1.21.2; python_version > \"3.9\"",
]
files = [
    {file = "ml_dtypes-0.3.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:510d249a91face47211762eb294d6fe64f325356b965fb6388c1bf51bd339267"},
    {file = "ml_dtypes-0.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f83ff080df8910c0f987f615b03e4f8198638e0c00c6e679ea8892dda909763b"},
    {file = "ml_dtypes-0.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fcae2c69715410d96906e1dfe8f017d9f78a0d10e0df91aae52e91f51fdfe45e"},
    {file = "ml_dtypes-0.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:da274599e4950a9b488d21571061f49a185537cc77f2d3f8121151d58a9e9f16"},
    {file = "ml_dtypes-0.3.1.tar.gz", hash = "sha256:60778f99194b4c4f36ba42da200b35ef851ce4d4af698aaf70f5b91fe70fc611"},
]

[[package]]
name = "numpy"
version = "1.26.1"
requires_python = "<3.13,>=3.9"
summary = "Fundamental package for array computing in Python"
files = [
    {file = "numpy-1.26.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82e871307a6331b5f09efda3c22e03c095d957f04bf6bc1804f30048d0e5e7af"},
    {file = "numpy-1.26.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cdd9ec98f0063d93baeb01aad472a1a0840dee302842a2746a7a8e92968f9575"},
    {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d78f269e0c4fd365fc2992c00353e4530d274ba68f15e968d8bc3c69ce5f5244"},
    {file = "numpy-1.26.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8ab9163ca8aeb7fd32fe93866490654d2f7dda4e61bc6297bf72ce07fdc02f67"},
    {file = "numpy-1.26.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:78ca54b2f9daffa5f323f34cdf21e1d9779a54073f0018a3094ab907938331a2"},
    {file = "numpy-1.26.1-cp310-cp310-win32.whl", hash = "sha256:d1cfc92db6af1fd37a7bb58e55c8383b4aa1ba23d012bdbba26b4bcca45ac297"},
    {file = "numpy-1.26.1-cp310-cp310-win_amd64.whl", hash = "sha256:d2984cb6caaf05294b8466966627e80bf6c7afd273279077679cb010acb0e5ab"},
    {file = "numpy-1.26.1-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:06934e1a22c54636a059215d6da99e23286424f316fddd979f5071093b648668"},
    {file = "numpy-1.26.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76ff661a867d9272cd2a99eed002470f46dbe0943a5ffd140f49be84f68ffc42"},
    {file = "numpy-1.26.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:6965888d65d2848e8768824ca8288db0a81263c1efccec881cb35a0d805fcd2f"},
    {file = "numpy-1.26.1.tar.gz", hash = "sha256:c8c6c72d4a9f831f328efb1312642a1cafafaa88981d9ab76368d50d07d93cbe"},
]

[[package]]
name = "opt-einsum"
version = "3.3.0"
requires_python = ">=3.5"
summary = "Optimizing numpys einsum function"
dependencies = [
    "numpy>=1.7",
]
files = [
    {file = "opt_einsum-3.3.0-py3-none-any.whl", hash = "sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147"},
    {file = "opt_einsum-3.3.0.tar.gz", hash = "sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549"},
]

[[package]]
name = "scipy"
version = "1.11.3"
requires_python = "<3.13,>=3.9"
summary = "Fundamental algorithms for scientific computing in Python"
dependencies = [
    "numpy<1.28.0,>=1.21.6",
]
files = [
    {file = "scipy-1.11.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:370f569c57e1d888304052c18e58f4a927338eafdaef78613c685ca2ea0d1fa0"},
    {file = "scipy-1.11.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:9885e3e4f13b2bd44aaf2a1a6390a11add9f48d5295f7a592393ceb8991577a3"},
    {file = "scipy-1.11.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e04aa19acc324a1a076abb4035dabe9b64badb19f76ad9c798bde39d41025cdc"},
    {file = "scipy-1.11.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e1a8a4657673bfae1e05e1e1d6e94b0cabe5ed0c7c144c8aa7b7dbb774ce5c1"},
    {file = "scipy-1.11.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:7abda0e62ef00cde826d441485e2e32fe737bdddee3324e35c0e01dee65e2a88"},
    {file = "scipy-1.11.3-cp310-cp310-win_amd64.whl", hash = "sha256:033c3fd95d55012dd1148b201b72ae854d5086d25e7c316ec9850de4fe776929"},
    {file = "scipy-1.11.3.tar.gz", hash = "sha256:bba4d955f54edd61899776bad459bf7326e14b9fa1c552181f0479cc60a568cd"},
]

For pdm export -G cpu, I get:

# This file is @generated by PDM.
# Please do not edit it manually.

jax==0.4.17
jax==0.4.17
ml-dtypes==0.3.1
numpy==1.26.1
opt-einsum==3.3.0
scipy==1.11.3

Actual behavior

There are a few issues with the resolution above:

  1. Duplicate entries of JAX: jax==0.4.17 appears twice in the exported requirements.
  2. When I selectively export the CPU group, jax[cpu] should pull in jaxlib as a dependency, but jaxlib is missing from the exported requirements. This also breaks pdm install, because jaxlib is not installed either.

This is not expected.
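Both symptoms can be checked mechanically on the export output. The small `check_export` helper below is hypothetical: it assumes every non-comment line is a `name==version` pin (optionally with extras or a `;` marker), as in the export above, and the `required=("jaxlib",)` default simply encodes the expectation from this report.

```python
from collections import Counter

# The output of `pdm export -G cpu` quoted above.
EXPORTED = """\
# This file is @generated by PDM.
# Please do not edit it manually.

jax==0.4.17
jax==0.4.17
ml-dtypes==0.3.1
numpy==1.26.1
opt-einsum==3.3.0
scipy==1.11.3
"""


def check_export(text: str, required: tuple[str, ...] = ("jaxlib",)):
    """Return (duplicate names, missing required names) for a pinned requirements file."""
    names = []
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        # Keep only the project name: drop the marker, the pin and any extras.
        names.append(line.split(";")[0].split("==")[0].split("[")[0].strip().lower())
    counts = Counter(names)
    duplicates = sorted(n for n, c in counts.items() if c > 1)
    missing = sorted(n for n in required if n not in counts)
    return duplicates, missing


print(check_export(EXPORTED))  # (['jax'], ['jaxlib'])
```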

Expected behavior

I expect that when the CPU group is used to install dependencies, jaxlib is included as a dependency.

Environment Information

# Paste the output of `pdm info && pdm info --env` below:
PDM version:
  2.10.0
Python Interpreter:
  /workspace/.venv/bin/python (3.10)
Project Root:
  /workspace
Local Packages:

{
  "implementation_name": "cpython",
  "implementation_version": "3.10.13",
  "os_name": "posix",
  "platform_machine": "x86_64",
  "platform_release": "5.15.0-87-generic",
  "platform_system": "Linux",
  "platform_version": "#97~20.04.1-Ubuntu SMP Thu Oct 5 08:25:28 UTC 2023",
  "python_full_version": "3.10.13",
  "platform_python_implementation": "CPython",
  "python_version": "3.10",
  "sys_platform": "linux"
}

Updates

I tried a few things; it seems that if I also pin the version in the extra (i.e., cpu = ["jax[cpu]==0.4.17"]), things work, but I find that surprising.

@ethanluoyc ethanluoyc added the 🐛 bug Something isn't working label Nov 4, 2023
@frostming frostming added the 🧩 dependency resolution Resolution failures label Nov 6, 2023
@frostming frostming self-assigned this Nov 6, 2023
frostming added a commit that referenced this issue Nov 6, 2023
@mthiboust

Not sure if it is related, but I get this error with "jax[cuda12_pip]==0.4.20" (I also tried the latest version, 2.10.1, which contains your fix):

$ pdm install
Lock file does not exist
Updating the lock file...
🔒 Lock failed
Unable to find a resolution for jaxlib
because of the following conflicts:
  jaxlib==0.4.20+cuda12.cudnn89 (from [email protected])
To fix this, you could loosen the dependency version constraints in pyproject.toml. See https://pdm-project.org/latest/usage/dependency/#solve-the-locking-failure for more details.
See /tmp/pdm-lock-dfwb8ku8.log for detailed debug log.
[ResolutionImpossible]: Unable to find a resolution
Add '-v' to see the detailed traceback

In this toy example, jax[cuda12_pip] is the only dependency in my project, so I don't understand how it can be in conflict with its own dependencies.
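The conflict message lists a single requirement, which is why it looks self-contradictory. One PEP 440 rule is relevant background here (this is not a diagnosis of the failure): when an == pin carries a local label, as jaxlib==0.4.20+cuda12.cudnn89 does, only a candidate with that same label can satisfy it, whereas a pin without a label ignores the candidate's label. A plain 0.4.20 wheel therefore never satisfies this pin; only a local-version build, like the GPU wheels served from JAX's find-links page, can. The stdlib-only sketch below illustrates the rule; `matches_exact` and the candidate lists are hypothetical, and the check ignores pre/post/dev releases and the "-"/"_" separators that PEP 440 also allows in local labels.

```python
def split_local(version: str):
    """Split 'public+local' into (public, local); local is None when absent."""
    public, sep, local = version.partition("+")
    return public, (local if sep else None)


def release(public: str) -> tuple[int, ...]:
    """Numeric release segments with trailing zeros dropped (0.4.20 == 0.4.20.0)."""
    parts = [int(p) for p in public.split(".")]
    while len(parts) > 1 and parts[-1] == 0:
        parts.pop()
    return tuple(parts)


def matches_exact(candidate: str, pin: str) -> bool:
    """Simplified PEP 440 '==' check for plain release versions.

    A pin without a local label ignores the candidate's label; a pin with
    one requires the candidate to carry the same label (case-insensitive).
    """
    cand_public, cand_local = split_local(candidate)
    pin_public, pin_local = split_local(pin)
    if release(cand_public) != release(pin_public):
        return False
    if pin_local is None:
        return True
    return cand_local is not None and cand_local.lower() == pin_local.lower()


pin = "0.4.20+cuda12.cudnn89"  # the jaxlib pin from the error above
# Hypothetical candidate lists: an index offering only the plain release,
# versus one that also offers the local CUDA build.
print([c for c in ["0.4.20"] if matches_exact(c, pin)])  # []
print([c for c in ["0.4.20", "0.4.20+cuda12.cudnn89"] if matches_exact(c, pin)])  # ['0.4.20+cuda12.cudnn89']
```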

My current workaround is to manually specify the necessary dependencies of jax[cuda12_pip] as listed here: https://github.com/google/jax/blob/c5d6df4557c1f0bf543a1315fca3bf203bc2201d/setup.py#L128

Environment information

# Paste the output of `pdm info && pdm info --env` below:
PDM version:
  2.10.1
Python Interpreter:
  /home/myuser/path/.venv/bin/python (3.11)
Project Root:
  /home/myuser/path/
Local Packages:
  
{
  "implementation_name": "cpython",
  "implementation_version": "3.11.6",
  "os_name": "posix",
  "platform_machine": "x86_64",
  "platform_release": "6.5.0-10-generic",
  "platform_system": "Linux",
  "platform_version": "#10-Ubuntu SMP PREEMPT_DYNAMIC Fri Oct 13 13:49:38 UTC 2023",
  "python_full_version": "3.11.6",
  "platform_python_implementation": "CPython",
  "python_version": "3.11",
  "sys_platform": "linux"
}

@wbthomason

I'm experiencing the same issue as @mthiboust on PDM 2.11. The workaround of adding the other dependencies explicitly did not seem to work for me; instead, I resorted to adding the jaxlib dependency directly via its URL, although this is unfortunately brittle.

4 participants