module 'jax.random' has no attribute 'KeyArray' #1186

Closed
spatts14 opened this issue Feb 10, 2024 · 20 comments
Labels
bug Something isn't working

Comments

@spatts14

spatts14 commented Feb 10, 2024

I cannot import scvelo because it is not compatible with the latest jax (0.4.24) and jaxlib (0.4.24).
Which versions do I need for compatibility?
...

import scvelo as scv
Error output
AttributeError: module 'jax.random' has no attribute 'KeyArray'
Versions
Version: 0.3.1
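When reporting this kind of import error, it helps to list the exact versions of every package in the failing import chain. Below is a minimal sketch using only the standard library's `importlib.metadata`; the package list is an assumption based on the tracebacks in this thread. It reads installed metadata without importing the packages, so it still works when `import scvelo` fails.

```python
from importlib import metadata
from typing import Dict, Iterable, Optional

# Distribution names assumed relevant to this thread (adjust as needed).
RELEVANT = ("jax", "jaxlib", "chex", "scvi-tools", "scvelo", "matplotlib", "pandas", "scanpy")


def report_versions(dists: Iterable[str] = RELEVANT) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if missing."""
    versions: Dict[str, Optional[str]] = {}
    for name in dists:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in report_versions().items():
        print(f"{name:12} {version or 'not installed'}")
```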
@spatts14 spatts14 added the bug Something isn't working label Feb 10, 2024
@christophechu

same question

@Zethson
Member

Zethson commented Feb 11, 2024

Install scvi-tools from main. We'll make a new release soon that fixes this.

@Zethson Zethson closed this as completed Feb 11, 2024
@hvgogogo

Install scvi-tools from main. We'll make a new release soon that fixes this.

Could you explain how to install scvi-tools from main? Thanks so much.

@Zethson
Member

Zethson commented Feb 12, 2024

@hvgogogo

  1. git clone the repository and cd into it
  2. run pip install -U .

If this is unclear, please consult your favorite search engine or LLM
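The two steps above can also be scripted. A minimal sketch, assuming `git` is on the PATH and that the install should go into the interpreter running the script (`sys.executable -m pip`). The `./` prefix on the install target matters: current pip versions only treat an argument as a local project when it looks like a path, so a bare `scvi-tools` would be resolved against PyPI instead of the clone. The function names are illustrative, not part of any tool.

```python
import os
import subprocess
import sys
from typing import List

REPO_URL = "https://github.com/scverse/scvi-tools.git"


def install_from_main_commands(repo_url: str = REPO_URL, clone_dir: str = "scvi-tools") -> List[List[str]]:
    """Build the clone + install commands without running them."""
    # "./scvi-tools" (".\\scvi-tools" on Windows) looks like a path to pip,
    # so pip installs the local checkout rather than the PyPI release.
    local_path = os.path.join(os.curdir, clone_dir)
    return [
        ["git", "clone", repo_url, clone_dir],
        [sys.executable, "-m", "pip", "install", "-U", local_path],
    ]


def run_all(commands: List[List[str]]) -> None:
    """Run each command in order, stopping at the first failure."""
    for cmd in commands:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_all(install_from_main_commands())
```

pip can also install straight from the repository with `pip install -U git+https://github.com/scverse/scvi-tools.git`, which skips the manual clone.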

@christophechu

@hvgogogo @Zethson
Just use jax==0.4.19; the error comes from scvi.

@hvgogogo

@hvgogogo @Zethson just using jax == 0.4.19 its the error from scvi

Thanks so much, it works.
Just a reminder for anyone following: you need to downgrade jaxlib too (pip install jaxlib==0.4.19).
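Since the workaround only holds if jax and jaxlib are downgraded together, a minimal sketch that reads both installed versions and suggests pinning both to the same release when they differ. The default target of 0.4.19 is simply the version reported to work in this thread. Matching versions alone do not avoid the KeyArray error; that also needs a jax old enough to still provide `jax.random.KeyArray`, or an scvi-tools that no longer imports chex.

```python
from importlib import metadata
from typing import Optional


def installed(dist: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if absent."""
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def jax_pin_advice(jax_v: Optional[str], jaxlib_v: Optional[str], target: str = "0.4.19") -> Optional[str]:
    """Return a pip command suggestion if jax and jaxlib are out of step, else None."""
    if jax_v is not None and jax_v == jaxlib_v:
        return None  # already in lockstep
    return f"pip install jax=={target} jaxlib=={target}"


if __name__ == "__main__":
    advice = jax_pin_advice(installed("jax"), installed("jaxlib"))
    print(advice or "jax and jaxlib versions match")
```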

@christophechu

scvi-tools is being upgraded to 1.1.0, which should fix this problem soon.

@spatts14
Author

spatts14 commented Feb 13, 2024

The workaround works on my desktop; however, when I try to run it on the HPC after installing as instructed, I run into errors.

I installed with the following commands:

       git clone https://github.com/scverse/scvi-tools.git 
       pip install -U scvi-tools
       pip install jax==0.4.19

I am getting the following error after running import scvelo:


---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[3], line 1
----> 1 import scvelo

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/__init__.py:5
      2 from anndata import AnnData
      3 from scanpy import read, read_loom
----> 5 from scvelo import datasets, logging
      6 from scvelo import plotting as pl
      7 from scvelo import preprocessing as pp

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/datasets/__init__.py:1
----> 1 from ._datasets import (
      2     bonemarrow,
      3     dentategyrus,
      4     dentategyrus_lamanno,
      5     forebrain,
      6     gastrulation,
      7     gastrulation_e75,
      8     gastrulation_erythroid,
      9     pancreas,
     10     pancreatic_endocrinogenesis,
     11     pbmc68k,
     12     toy_data,
     13 )
     14 from ._simulate import simulation
     16 __all__ = [
     17     "bonemarrow",
     18     "dentategyrus",
   (...)
     28     "toy_data",
     29 ]

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/datasets/_datasets.py:10
      6 import pandas as pd
      8 from scanpy import read
---> 10 from scvelo.core import cleanup
     11 from scvelo.read_load import load
     13 url_datadir = "https://github.com/theislab/scvelo_notebooks/raw/master/"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/core/__init__.py:1
----> 1 from ._anndata import (
      2     clean_obs_names,
      3     cleanup,
      4     get_df,
      5     get_initial_size,
      6     get_modality,
      7     get_size,
      8     make_dense,
      9     make_sparse,
     10     merge,
     11     set_initial_size,
     12     set_modality,
     13     show_proportions,
     14 )
     15 from ._arithmetic import clipped_log, invert, multiply, prod_sum, sum
     16 from ._linear_models import LinearRegression

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/core/_anndata.py:15
     11 from scipy.sparse import csr_matrix, issparse, spmatrix
     13 from anndata import AnnData
---> 15 from scvelo import logging as logg
     16 from ._arithmetic import sum
     17 from ._utils import deprecated_arg_names

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/logging.py:12
      8 from packaging.version import parse
     10 from anndata.logging import get_memory_usage
---> 12 from scvelo import settings
     14 _VERBOSITY_LEVELS_FROM_STRINGS = {"error": 0, "warn": 1, "info": 2, "hint": 3}
     17 def info(*args, **kwargs):

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/settings.py:91
     83 """See set_figure_params.
     84 """
     87 # --------------------------------------------------------------------------------
     88 # Functions
     89 # --------------------------------------------------------------------------------
---> 91 warnings.filterwarnings("ignore", category=cbook.mplDeprecation)
     94 # default matplotlib 2.0 palette slightly modified.
     95 vega_10 = list(map(colors.to_hex, cm.tab10.colors))

AttributeError: module 'matplotlib.cbook' has no attribute 'mplDeprecation'

I've tried making a new conda environment and reinstalling as described, but I run into the same errors. I am also getting an error with import scvi:

/rds/general/user/sep22/home/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/_settings.py:63: UserWarning: Since v1.0.0, scvi-tools no longer uses a random seed by default. Run `scvi.settings.seed = 0` to reproduce results from previous versions.
  self.seed = seed
/rds/general/user/sep22/home/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/_settings.py:70: UserWarning: Setting `dl_pin_memory_gpu_training` is deprecated in v1.0 and will be removed in v1.1. Please pass in `pin_memory` to the data loaders instead.
  self.dl_pin_memory_gpu_training = (
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 import scvi

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
      9 import lightning.pytorch as pl
     10 import rich
---> 11 from chex import dataclass
     13 try:
     14     import ray

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/pytypes.py:54
     52 Numeric = Union[Array, Scalar]
     53 Shape = jax.core.Shape
---> 54 PRNGKey = jax.random.KeyArray
     55 PyTreeDef = jax.tree_util.PyTreeDef
     56 Device = jax.Device

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/deprecations.py:53, in deprecation_getattr.<locals>.getattr(name)
     51   warnings.warn(message, DeprecationWarning, stacklevel=2)
     52   return fn
---> 53 raise AttributeError(f"module {module!r} has no attribute {name!r}")

AttributeError: module 'jax.random' has no attribute 'KeyArray'

I then tried upgrading:

pip install --upgrade matplotlib scvelo

and now I'm getting this error:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 1
----> 1 import scvelo

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/__init__.py:6
      3 from scanpy import read, read_loom
      5 from scvelo import datasets, logging
----> 6 from scvelo import plotting as pl
      7 from scvelo import preprocessing as pp
      8 from scvelo import settings

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/__init__.py:3
      1 from scanpy.plotting import paga_compare, rank_genes_groups
----> 3 from .gridspec import gridspec
      4 from .heatmap import heatmap
      5 from .paga import paga

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/gridspec.py:6
      3 import matplotlib.pyplot as pl
      5 # todo: auto-complete and docs wrapper
----> 6 from .scatter import scatter
      7 from .utils import get_figure_params, hist
      8 from .velocity_embedding import velocity_embedding

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/scatter.py:16
     14 from scvelo.preprocessing.neighbors import get_connectivities
     15 from .docs import doc_params, doc_scatter
---> 16 from .utils import (
     17     default_basis,
     18     default_color,
     19     default_color_map,
     20     default_legend_loc,
     21     default_size,
     22     default_xkey,
     23     default_ykey,
     24     get_ax,
     25     get_components,
     26     get_figure_params,
     27     get_kwargs,
     28     get_obs_vector,
     29     get_value_counts,
     30     gets_vals_from_color_gradients,
     31     groups_to_bool,
     32     interpret_colorkey,
     33     is_categorical,
     34     is_int,
     35     is_list,
     36     is_list_of_int,
     37     is_list_of_list,
     38     is_list_of_str,
     39     make_dense,
     40     plot_density,
     41     plot_linfit,
     42     plot_outline,
     43     plot_polyfit,
     44     plot_rug,
     45     plot_velocity_fits,
     46     rgb_custom_colormap,
     47     savefig_or_show,
     48     set_colorbar,
     49     set_colors_for_categorical_obs,
     50     set_label,
     51     set_legend,
     52     set_margin,
     53     set_title,
     54     to_list,
     55     to_val,
     56     to_valid_bases_list,
     57     update_axes,
     58 )
     61 @doc_params(scatter=doc_scatter)
     62 def scatter(
     63     adata=None,
   (...)
    122     **kwargs,
    123 ):
    124     """Scatter plot along observations or variables axes.
    125 
    126     Arguments:
   (...)
    138     If `show==False` a `matplotlib.Axis`
    139     """

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/plotting/utils.py:24
     22 from scvelo import logging as logg
     23 from scvelo import settings
---> 24 from scvelo.tools.utils import strings_to_categoricals
     25 from . import palettes
     27 """helper functions"""

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/tools/__init__.py:14
      4 from ._em_model_core import (
      5     align_dynamics,
      6     differential_kinetic_test,
   (...)
     11     recover_latent_time,
     12 )
     13 from ._steady_state_model import SecondOrderSteadyStateModel, SteadyStateModel
---> 14 from ._vi_model import VELOVI
     15 from .paga import paga
     16 from .rank_velocity_genes import rank_velocity_genes, velocity_clusters

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvelo/tools/_vi_model.py:14
     11 from scipy.stats import ttest_ind
     13 from anndata import AnnData
---> 14 from scvi.data import AnnDataManager
     15 from scvi.data.fields import LayerField
     16 from scvi.dataloaders import DataSplitter

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/__init__.py:11
      8 from ._settings import settings
     10 # this import needs to come after prior imports to prevent circular import
---> 11 from . import autotune, data, model, external, utils, criticism
     13 from importlib.metadata import version
     15 package_name = "scvi-tools"

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/__init__.py:1
----> 1 from ._manager import TuneAnalysis, TunerManager
      2 from ._tuner import ModelTuner
      3 from ._types import Tunable, TunableMixin

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/scvi/autotune/_manager.py:11
      9 import lightning.pytorch as pl
     10 import rich
---> 11 from chex import dataclass
     13 try:
     14     import ray

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/__init__.py:17
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts.py:26
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/asserts_internal.py:34
     31 from typing import Any, Sequence, Union, Callable, List, Optional, Set, Tuple, Type
     33 from absl import logging
---> 34 from chex._src import pytypes
     35 import jax
     36 from jax.experimental import checkify

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/chex/_src/pytypes.py:19
     15 """Type definitions to use for type annotations."""
     17 from typing import Any, Iterable, Mapping, Union
---> 19 import jax
     20 import jax.numpy as jnp
     21 import numpy as np

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/__init__.py:39
     34 del _cloud_tpu_init
     36 # Confusingly there are two things named "config": the module and the class.
     37 # We want the exported object to be the class, so we first import the module
     38 # to make sure a later import doesn't overwrite the class.
---> 39 from jax import config as _config_module
     40 del _config_module
     42 # Force early import, allowing use of `jax.core` after importing `jax`.

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/config.py:17
      1 # Copyright 2018 The JAX Authors.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     14 
     15 # TODO(phawkins): fix users of this alias and delete this file.
---> 17 from jax._src.config import config  # noqa: F401

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/config.py:27
     24 import threading
     25 from typing import Any, Callable, Generic, NamedTuple, NoReturn, Optional, TypeVar
---> 27 from jax._src import lib
     28 from jax._src.lib import jax_jit
     29 from jax._src.lib import transfer_guard_lib

File ~/anaconda3/envs/scVelo/lib/python3.11/site-packages/jax/_src/lib/__init__.py:75
     70   return _jaxlib_version
     73 version_str = jaxlib.version.__version__
     74 version = check_jaxlib_version(
---> 75   jax_version=jax.version.__version__,
     76   jaxlib_version=jaxlib.version.__version__,
     77   minimum_jaxlib_version=jax.version._minimum_jaxlib_version)
     79 # Before importing any C compiled modules from jaxlib, first import the CPU
     80 # feature guard module to verify that jaxlib was compiled in a way that only
     81 # uses instructions that are present on this machine.
     82 import jaxlib.cpu_feature_guard as cpu_feature_guard

AttributeError: partially initialized module 'jax' has no attribute 'version' (most likely due to a circular import)

Any thoughts?
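One thing worth checking in the HPC install steps listed above: a bare `pip install -U scvi-tools` installs the release from PyPI rather than the cloned checkout, because pip only treats the argument as a local project when it looks like a path (for example `./scvi-tools`). A minimal sketch to see where an installed distribution came from, using the `direct_url.json` record that pip writes for URL, VCS and local-directory installs (PEP 610). A missing record most likely means the package came from an index such as PyPI, or from conda. The helper names are illustrative.

```python
import json
from importlib import metadata
from typing import Optional


def describe_origin(direct_url_json: Optional[str]) -> str:
    """Summarise a PEP 610 direct_url.json payload (None means no record)."""
    if direct_url_json is None:
        return "package index (no direct_url.json record)"
    record = json.loads(direct_url_json)
    url = record.get("url", "<unknown url>")
    if "vcs_info" in record:
        return f"VCS checkout: {url} @ {record['vcs_info'].get('commit_id', '?')}"
    if "dir_info" in record:
        return f"local directory: {url}"
    return f"archive/URL: {url}"


def install_origin(dist_name: str) -> Optional[str]:
    """Describe where dist_name was installed from, or return None if absent."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    return describe_origin(dist.read_text("direct_url.json"))


if __name__ == "__main__":
    print(install_origin("scvi-tools") or "scvi-tools is not installed")
```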

@aterceros

Hi,
I'm having the same issue. I downgraded to jax==0.4.19 but still get the "module 'jax.random' has no attribute 'KeyArray'" error. Any hints?

@christophechu

@aterceros @spatts14
Upgrading scvi-tools to 1.1.0 may solve this issue.
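To check whether the installed scvi-tools already meets that 1.1.0 threshold, a minimal sketch that compares only the leading numeric release segment, so `1.1.0.post2` counts as 1.1.0. It deliberately ignores pre-release and post-release tags; `packaging.version.Version` is the more robust choice if that package is available.

```python
import re
from importlib import metadata
from typing import Tuple


def release_tuple(version: str) -> Tuple[int, ...]:
    """Extract the leading numeric release, e.g. '1.1.0.post2' -> (1, 1, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"unparseable version: {version!r}")
    return tuple(int(part) for part in match.group(0).split("."))


def at_least(version: str, minimum: str) -> bool:
    """True if version's release segment is >= minimum's."""
    have, need = release_tuple(version), release_tuple(minimum)
    width = max(len(have), len(need))
    # Pad with zeros so that "1.1" compares equal to "1.1.0".
    have += (0,) * (width - len(have))
    need += (0,) * (width - len(need))
    return have >= need


if __name__ == "__main__":
    try:
        current = metadata.version("scvi-tools")
    except metadata.PackageNotFoundError:
        print("scvi-tools is not installed")
    else:
        print(current, "OK" if at_least(current, "1.1.0") else "-> upgrade to >= 1.1.0")
```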

@bio-la

bio-la commented Feb 14, 2024

Has anyone tested whether this issue with MultiVI is also fixed in scvi-tools 1.1.0?
https://discourse.scverse.org/t/error-when-training-model-on-m3-max-mps/1896
MultiVI breaks with 1.0.4 (no issues with scVI or totalVI).

@spatts14
Author

Solution:

I updated to scvi-tools 1.1.1 and upgraded pandas.

Current versions:
scanpy 1.9.8
scvelo 0.3.1
scvi-tools 1.1.1

Now running on the HPC.

Of note, I use Parse data. The Parse website says to use pandas==1.5.3; however, I needed to upgrade.

@aterceros

Hi,
Thank you for the comments. I upgraded to scvi-tools 1.1.0.post2 (which seems to be the latest; there is no 1.1.1 version), but the error persists. @spatts14, what version of pandas do you have?
Thanks a lot!

@spatts14
Author

I have pandas==2.2.0 and scvi-tools 1.1.1 (released yesterday). I would suggest setting up a new environment and installing from main.

@epignatelli

epignatelli commented Mar 7, 2024

I landed here by chance while searching for the error.

This might help:
https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

KeyArray was deprecated and has now been removed, use jax.Array instead.
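For code that has to run on both older and newer JAX releases, a minimal compatibility sketch: use `jax.Array`, which the linked JEP recommends for annotating both legacy and typed PRNG keys, and fall back to `jax.random.KeyArray` only on a JAX old enough to lack `jax.Array`. The helper takes the `jax` module as an argument so the fallback can be exercised without JAX installed; `prng_key_type` is a hypothetical name, not a JAX or chex API.

```python
from types import SimpleNamespace
from typing import Any


def prng_key_type(jax_module: Any) -> Any:
    """Pick a type for annotating PRNG keys on the given jax module.

    Prefers jax.Array (recommended by JEP 9263 for both old-style uint32 keys
    and new-style typed keys); otherwise falls back to jax.random.KeyArray,
    which newer releases have removed.
    """
    array_type = getattr(jax_module, "Array", None)
    if array_type is not None:
        return array_type
    return jax_module.random.KeyArray


# Usage with a real installation (not executed here):
#   import jax
#   PRNGKey = prng_key_type(jax)
#   def sample(key: PRNGKey) -> jax.Array:
#       return jax.random.normal(key, (3,))

if __name__ == "__main__":
    # Stand-ins for newer and older JAX module layouts (illustration only).
    new_jax = SimpleNamespace(Array=object, random=SimpleNamespace())
    old_jax = SimpleNamespace(random=SimpleNamespace(KeyArray=int))
    print(prng_key_type(new_jax), prng_key_type(old_jax))
```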

@dvagbear

dvagbear commented Mar 12, 2024

Just bumped here by total chance while searching for the error.

This might help: https://jax.readthedocs.io/en/latest/jep/9263-typed-keys.html#type-annotations-for-prng-keys

A quick fix is to change jax.random.KeyArray to jax.Array in chex/_src/pytypes.py, the file the traceback points to.
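If you need that stopgap before upgrading, a minimal sketch that applies the one-line substitution and keeps a `.bak` copy. It assumes the offending text is literally `jax.random.KeyArray`, as in the traceback, and locates chex with `importlib.util.find_spec`, which does not execute `chex/__init__.py` (so it works even though importing chex fails). Editing site-packages is fragile and is undone by reinstalling chex; upgrading scvi-tools as discussed in this thread is the cleaner fix. The function names are illustrative.

```python
import importlib.util
from pathlib import Path
from typing import Optional

OLD = "jax.random.KeyArray"
NEW = "jax.Array"


def locate_chex_pytypes() -> Optional[Path]:
    """Find chex/_src/pytypes.py without importing chex."""
    spec = importlib.util.find_spec("chex")  # top-level lookup: does not run __init__.py
    if spec is None or not spec.submodule_search_locations:
        return None
    package_dir = Path(next(iter(spec.submodule_search_locations)))
    return package_dir / "_src" / "pytypes.py"


def patched_source(text: str) -> Optional[str]:
    """Return text with OLD replaced by NEW, or None if OLD does not occur."""
    if OLD not in text:
        return None
    return text.replace(OLD, NEW)


def patch_key_array(path: Path) -> bool:
    """Patch the file in place, keeping a .bak copy. Returns False if nothing to do."""
    original = path.read_text(encoding="utf-8")
    new = patched_source(original)
    if new is None:
        return False
    path.with_name(path.name + ".bak").write_text(original, encoding="utf-8")
    path.write_text(new, encoding="utf-8")
    return True


if __name__ == "__main__":
    target = locate_chex_pytypes()
    if target is None or not target.exists():
        print("chex/_src/pytypes.py not found")
    else:
        print("patched" if patch_key_array(target) else "nothing to patch", target)
```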

@Xxianna

Xxianna commented Mar 26, 2024

The bug still exists in 1.1.2 😭 with jax==0.4.25.
Use PRNGKey = jax.random.PRNGKey instead of PRNGKey = jax.random.KeyArray 😀
It will probably work: as far as I can tell from the jax documentation linked above, they take the same parameters. At least my init worked. The change goes in
\site-packages\chex\_src\pytypes.py

@WeilerP
Member

WeilerP commented Mar 26, 2024

You can now pip install scvelo (scvelo>=0.3.2) without scvi and jax as a dependency.

@littlewhitesea

I ran into the same problem and solved it with the following command from Stack Overflow:

pip install "jax[cuda12_pip]==0.4.23" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

@canergen

canergen commented May 3, 2024

Just seeing this. The dependency on Chex was removed in scVI-tools 1.1.0, so this error is not supposed to happen. If you're still facing it, I would recommend setting up a new environment or otherwise reporting it directly at scVI-tools.
