Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump CI to python 3.10 #1890

Merged
merged 4 commits into from
Oct 17, 2024
Merged

Bump CI to python 3.10 #1890

merged 4 commits into from
Oct 17, 2024

Conversation

fehiepsi
Copy link
Member

@fehiepsi fehiepsi commented Oct 16, 2024

Jax currently dropped support for Python 3.9, so our CI didn't catch the issues introduced in new jax versions.

Fixes #1876

@tare
Copy link
Contributor

tare commented Oct 17, 2024

BinomialLogits should be fixed too

$ python --version
Python 3.11.8
$ pip list
Package                Version
---------------------- -----------
absl-py                2.1.0
certifi                2024.8.30
charset-normalizer     3.4.0
chex                   0.1.87
cloudpickle            3.1.0
contourpy              1.3.0
cycler                 0.12.1
decorator              5.1.1
dm-haiku               0.0.13
dm-tree                0.1.8
etils                  1.10.0
flax                   0.10.0
fonttools              4.54.1
fsspec                 2024.9.0
funsor                 0.4.5
gast                   0.6.0
graphviz               0.20.3
humanize               4.11.0
idna                   3.10
importlib-metadata     4.13.0
importlib_resources    6.4.5
iniconfig              2.0.0
jax                    0.4.34
jaxlib                 0.4.34
jaxns                  2.4.8
jaxopt                 0.8.3
jmp                    0.0.4
joblib                 1.4.2
kiwisolver             1.4.7
makefun                1.15.6
markdown-it-py         3.0.0
matplotlib             3.9.2
mdurl                  0.1.2
ml_dtypes              0.5.0
msgpack                1.1.0
multipledispatch       1.0.0
nest-asyncio           1.6.0
numpy                  2.1.2
numpyro                0.15.3
opt_einsum             3.4.0
optax                  0.2.3
orbax-checkpoint       0.7.0
packaging              24.1
pillow                 11.0.0
pip                    24.0
pluggy                 1.5.0
protobuf               5.28.2
Pygments               2.18.0
pylab-sdk              1.7.2
pyparsing              3.2.0
pyro-api               0.1.2
pytest                 8.3.3
python-dateutil        2.9.0.post0
PyYAML                 6.0.2
requests               2.32.3
rich                   13.9.2
ruff                   0.6.9
scikit-learn           1.5.2
scipy                  1.14.1
setuptools             65.5.0
six                    1.16.0
tabulate               0.9.0
tensorflow-probability 0.24.0
tensorstore            0.1.66
threadpoolctl          3.5.0
toolz                  1.0.0
tqdm                   4.66.5
typing_extensions      4.12.2
urllib3                2.2.3
zipp                   3.20.2
$ pytest test/test_distributions.py::test_log_prob_gradient
⋮
FAILED test/test_distributions.py::test_log_prob_gradient[BinomialLogits-<lambda>-params158] - TypeError: Called multiply with a float0 array. float0s do not support any ...
============ 1 failed, 202 passed, 13 skipped in 183.74s (0:03:03) =============

@fehiepsi
Copy link
Member Author

Thanks, I need to see why CI does not trigger the error.

@fehiepsi
Copy link
Member Author

Oh, jax dropped support for python 3.10

@fehiepsi fehiepsi changed the title Fix grad of discrete log prob Bump CI to python 3.10 Oct 17, 2024
Copy link
Member

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is my understanding correct that you are always casting to float in just case the value is an integer?

Can there be a case when jnp.result_type(float) returns float32 but the value's dtype is float64 and it gets downcasted?

@fehiepsi
Copy link
Member Author

casting to float in just case the value is an integer

that's right.

it gets downcasted?

That might not be possible. In float32 world, I don't think we can create float64 arrays. But we can cast float32 to float64.

Copy link
Member

@ordabayevy ordabayevy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@fehiepsi fehiepsi merged commit d867c54 into pyro-ppl:master Oct 17, 2024
4 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Test failure due to float0 array in gradient computation for BernoulliProbs distribution in jax==0.4.34
4 participants