
jax-metal support appears broken #864

Open
jder opened this issue Jan 23, 2025 · 0 comments

jder commented Jan 23, 2025

On macOS 15.1.1 using Python 3.10.14 on an M2 Max, following the installation instructions and running the sample test command doesn't work for me. This is with a checkout of the latest levanter (237851b) in a fresh venv. I tried a few variants:

Following the instructions directly, running `pip install jax-metal==0.0.5` followed by `pip install -e .` produces a dependency conflict:

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jax-metal 0.0.5 requires jax==0.4.20, but you have jax 0.5.0 which is incompatible.
jax-metal 0.0.5 requires jaxlib==0.4.20, but you have jaxlib 0.5.0 which is incompatible.
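A minimal sketch for spotting this conflict before running anything, assuming only the standard-library `importlib.metadata`: it compares the installed `jax`/`jaxlib` against the `==0.4.20` pins that pip reports for jax-metal 0.0.5 above. The helper names `find_pin_conflicts` and `installed_versions` are hypothetical and not part of levanter or jax-metal.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins reported by pip's resolver for jax-metal 0.0.5 in the error above.
JAX_METAL_005_PINS = {"jax": "0.4.20", "jaxlib": "0.4.20"}


def find_pin_conflicts(installed, pins):
    """Return {package: (installed, required)} for every exact pin not met.

    `installed` maps a package name to its version string, or None if missing.
    """
    conflicts = {}
    for name, required in pins.items():
        have = installed.get(name)
        if have != required:
            conflicts[name] = (have, required)
    return conflicts


def installed_versions(names):
    """Look up the installed version of each distribution in this environment."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    found = installed_versions(JAX_METAL_005_PINS)
    for pkg, (have, need) in find_pin_conflicts(found, JAX_METAL_005_PINS).items():
        print(f"{pkg}: installed {have}, jax-metal 0.0.5 requires {need}")
```

Run it inside the venv; no output means both pins are satisfied.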

Doing both in one go, i.e. `pip install jax-metal==0.0.5 -e .`, installs successfully, but running the demo command `python -m levanter.main.train_lm --config config/gpt2_nano.yaml` then fails immediately with what looks like a NumPy version incompatibility, presumably because resolving against the old jax-metal pins downgraded other dependencies:

full error
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.

If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.

Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 187, in _run_module_as_main
    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 110, in _get_module_details
    __import__(pkg_name)
  File "/Users/jder/oa/levanter/src/levanter/__init__.py", line 1, in <module>
    import levanter.checkpoint as checkpoint
  File "/Users/jder/oa/levanter/src/levanter/checkpoint.py", line 15, in <module>
    import equinox
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/__init__.py", line 3, in <module>
    from . import debug as debug, internal as internal, nn as nn
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/debug/__init__.py", line 1, in <module>
    from ._announce_transform import announce_transform as announce_transform
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/debug/_announce_transform.py", line 4, in <module>
    import jax
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/__init__.py", line 39, in <module>
    from jax import config as _config_module
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/config.py", line 15, in <module>
    from jax._src.config import config as _deprecated_config  # noqa: F401
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
    from jax._src import lib
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 90, in <module>
    import jaxlib.xla_client as xla_client
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jaxlib/xla_client.py", line 29, in <module>
    from . import xla_extension as _xla
AttributeError: _ARRAY_API not found
Traceback (most recent call last):
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 187, in _run_module_as_main
    mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 110, in _get_module_details
    __import__(pkg_name)
  File "/Users/jder/oa/levanter/src/levanter/__init__.py", line 5, in <module>
    import levanter.eval as eval
  File "/Users/jder/oa/levanter/src/levanter/eval.py", line 19, in <module>
    from levanter.callbacks import StepInfo
  File "/Users/jder/oa/levanter/src/levanter/callbacks.py", line 29, in <module>
    from levanter.trainer_state import TrainerState
  File "/Users/jder/oa/levanter/src/levanter/trainer_state.py", line 10, in <module>
    from optax import GradientTransformation, OptState
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/__init__.py", line 17, in <module>
    from optax import contrib
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/contrib/__init__.py", line 21, in <module>
    from optax.contrib._dadapt_adamw import dadapt_adamw
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/contrib/_dadapt_adamw.py", line 27, in <module>
    from optax._src import utils
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/_src/utils.py", line 25, in <module>
    import jax.scipy.stats.norm as multivariate_normal
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/scipy/stats/__init__.py", line 40, in <module>
    from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/scipy/stats/kde.py", line 26, in <module>
    from jax.scipy import linalg, special
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/scipy/linalg.py", line 18, in <module>
    from jax._src.scipy.linalg import (
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/scipy/linalg.py", line 408, in <module>
    @_wraps(scipy.linalg.tril)
AttributeError: module 'scipy.linalg' has no attribute 'tril'
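The two tracebacks point at two separate problems: jaxlib's `xla_extension` was built against NumPy 1.x while NumPy 2.2.2 is installed, and this older jax wraps `scipy.linalg.tril` at import time while the installed SciPy no longer provides it. A hypothetical check for both conditions, written as a heuristic sketch (the `diagnose` helper and its messages are illustrative, and no specific SciPy version bound is assumed):

```python
def numpy_major(version_string):
    """Return the major component of a NumPy version string such as '2.2.2'."""
    return int(version_string.split(".")[0])


def diagnose(numpy_version, scipy_has_tril):
    """Heuristically map the two import failures above to likely causes."""
    problems = []
    if numpy_major(numpy_version) >= 2:
        problems.append(
            f"NumPy {numpy_version} is 2.x; an extension built against NumPy 1.x "
            "(jaxlib's xla_extension in the log) fails with '_ARRAY_API not found'. "
            "NumPy's own message suggests trying 'numpy<2'."
        )
    if not scipy_has_tril:
        problems.append(
            "scipy.linalg has no 'tril', but this jax wraps scipy.linalg.tril at "
            "import time, so it needs an older SciPy that still provides it."
        )
    return problems


if __name__ == "__main__":
    try:
        import numpy
        import scipy.linalg
    except ImportError as exc:
        print(f"cannot import: {exc}")
    else:
        for problem in diagnose(numpy.__version__, hasattr(scipy.linalg, "tril")):
            print("-", problem)
```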

Running instead with the latest jax-metal gets further, but ends with `LLVM ERROR: Failed to infer result type(s).`

full transcript
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:2025-01-23 13:35:57,240:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-23 13:35:57,240:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
WARNING:2025-01-23 13:35:57,240:jax._src.xla_bridge:1018: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING:jax._src.xla_bridge:Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737657357.241082 136846672 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

I0000 00:00:1737657357.255176 136846672 service.cc:145] XLA service 0x600000347500 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737657357.255195 136846672 service.cc:153]   StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737657357.256698 136846672 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737657357.256711 136846672 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
INFO:levanter.trainer:Setting run id to 7arqdyx3
2025-01-23T13:35:57 - 0 - levanter.tracker.wandb - wandb.py:233 - INFO :: Setting wandb code_dir to .
2025-01-23T13:35:57 - 0 - levanter.tracker.wandb - wandb.py:251 - WARNING :: Could not find git repo at .
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id 7arqdyx3.
wandb: Tracking run with wandb version 0.19.4
wandb: W&B syncing is set to `offline` in this directory.  
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2025-01-23T13:36:14 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2025-01-23T13:36:14 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = _posixsubprocess.fork_exec(
2025-01-23 13:36:15,077	INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265 
2025-01-23T13:36:15 - 0 - levanter.tracker.wandb - wandb.py:233 - INFO :: Setting wandb code_dir to .
2025-01-23T13:36:15 - 0 - levanter.tracker.wandb - wandb.py:251 - WARNING :: Could not find git repo at .
train:   0%|                                                                                                                           | 0/100 [00:00<?, ?it/s]2025-01-23T13:36:15 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/validation
2025-01-23T13:36:15 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/train
2025-01-23T13:36:17 - 0 - levanter.data.text - text.py:1105 - INFO :: Building cache for train...
2025-01-23T13:36:17 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/train
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py:128: RuntimeWarning: Sharding constraints are not supported in jit on metal
  warnings.warn("Sharding constraints are not supported in jit on metal", RuntimeWarning)
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1153: UserWarning: Some donated buffers were not usable: ShapedArray(uint32[2]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
cache/train: tokenizing:   0%|          | 0/8 [00:00<?, ?shard/s]
(writer::cache/train pid=26972) 2025-01-23 13:36:20,572 - INFO - Starting writer task
(writer::cache/train pid=26972) 2025-01-23 13:36:20,606 - INFO - Waiting for first group 0 to finish
cache/train: tokenizing:  12%|█▎        | 1/8 [00:03<00:22,  3.25s/shard]
(writer::cache/train pid=26972) 2025-01-23 13:36:22,257 - INFO - First group 0 finished. Copying other groups into permanent cache.
(tokenize::cache/train/___temp::0 pid=26983) 2025-01-23 13:36:22,253 - INFO - Shard 0 already processed.
2025-01-23T13:36:22 - 0 - __main__ - train_lm.py:195 - INFO :: No checkpoint found. Starting from scratch.
(tokenize::cache/train/___temp::1 pid=26983) huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
(tokenize::cache/train/___temp::1 pid=26983) To disable this warning, you can either:
(tokenize::cache/train/___temp::1 pid=26983) 	- Avoid using `tokenizers` before the fork if possible
(tokenize::cache/train/___temp::1 pid=26983) 	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1153: UserWarning: Some donated buffers were not usable: ShapedArray(int32[], weak_type=True), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(uint32[2]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
  warnings.warn("Some donated buffers were not usable:"
LLVM ERROR: Failed to infer result type(s).
*** SIGABRT received at time=1737657385 ***
PC: @        0x19db7a600  (unknown)  __pthread_kill
    @        0x12d165398  (unknown)  absl::lts_20230802::AbslFailureSignalHandler()
    @        0x19dbe8184  (unknown)  _sigtramp
    @        0x19dbb2f70  (unknown)  pthread_kill
    @        0x19dabf908  (unknown)  abort
    @        0x30df1688c  (unknown)  llvm::report_fatal_error()
    @        0x30df166c4  (unknown)  llvm::report_fatal_error()
    @        0x3099cb7d4  (unknown)  mlir::mps::PermuteOp::build()
    @        0x30983a0b0  (unknown)  mlir::OpBuilder::create<>()
    @        0x309839a30  (unknown)  mlir::mps::(anonymous namespace)::BroadcastInDimConverter::matchAndRewrite()
    @        0x3098393cc  (unknown)  mlir::OpConversionPattern<>::matchAndRewrite()
    @        0x30db09834  (unknown)  mlir::ConversionPattern::matchAndRewrite()
    @        0x30db4d018  (unknown)  llvm::function_ref<>::callback_fn<>()
    @        0x30db4a930  (unknown)  mlir::PatternApplicator::matchAndRewrite()
    @        0x30db09e8c  (unknown)  (anonymous namespace)::OperationLegalizer::legalize()
    @        0x30db09900  (unknown)  mlir::OperationConverter::convert()
    @        0x30db0a05c  (unknown)  mlir::OperationConverter::convertOperations()
    @        0x30db11008  (unknown)  mlir::applyFullConversion()
    @        0x30980b404  (unknown)  mlir::mps::(anonymous namespace)::ConvertHLOToMPSPass::runOnOperation()
    @        0x30dd5e78c  (unknown)  mlir::detail::OpToOpPassAdaptor::run()
    @        0x30dd5ec98  (unknown)  mlir::detail::OpToOpPassAdaptor::runPipeline()
    @        0x30dd5fdd8  (unknown)  mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl()
    @        0x30dd5e63c  (unknown)  mlir::detail::OpToOpPassAdaptor::run()
    @        0x30dd60ca4  (unknown)  mlir::PassManager::runPasses()
    @        0x30dd60b28  (unknown)  mlir::PassManager::run()
    @        0x309809fd8  (unknown)  compileMlirHLOToMPS
    @        0x3097fb95c  (unknown)  xla::mps::MetalStreamExecutorClient::Compile()
    @        0x30bdad940  (unknown)  std::__1::__variant_detail::__visitation::__base::__dispatcher<>::__dispatch[abi:ne180100]<>()
    @        0x30bda06a4  (unknown)  pjrt::PJRT_Client_Compile()
    @        0x138150d34  (unknown)  xla::InitializeArgsAndCompile()
    @        0x1381514d4  (unknown)  xla::PjRtCApiClient::Compile()
    @        0x13c20d1d8  (unknown)  xla::ifrt::PjRtLoadedExecutable::Create()
    @        0x13c2093ec  (unknown)  xla::ifrt::PjRtCompiler::Compile()
    @ ... and at least 29 more frames
[2025-01-23 13:36:25,264 E 26616 136846672] logging.cc:460: *** SIGABRT received at time=1737657385 ***
[2025-01-23 13:36:25,264 E 26616 136846672] logging.cc:460: PC: @        0x19db7a600  (unknown)  __pthread_kill
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460:     @        0x12d1654b8  (unknown)  absl::lts_20230802::AbslFailureSignalHandler()
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460:     @        0x19dbe8184  (unknown)  _sigtramp
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460:     @        0x19dbb2f70  (unknown)  pthread_kill
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460:     @        0x19dabf908  (unknown)  abort
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460:     @        0x30df1688c  (unknown)  llvm::report_fatal_error()
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460:     @        0x30df166c4  (unknown)  llvm::report_fatal_error()
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460:     @        0x3099cb7d4  (unknown)  mlir::mps::PermuteOp::build()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460:     @        0x30983a0b0  (unknown)  mlir::OpBuilder::create<>()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460:     @        0x309839a30  (unknown)  mlir::mps::(anonymous namespace)::BroadcastInDimConverter::matchAndRewrite()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460:     @        0x3098393cc  (unknown)  mlir::OpConversionPattern<>::matchAndRewrite()
[2025-01-23 13:36:25,268 E 26616 136846672] logging.cc:460:     @        0x30db09834  (unknown)  mlir::ConversionPattern::matchAndRewrite()
[2025-01-23 13:36:25,268 E 26616 136846672] logging.cc:460:     @        0x30db4d018  (unknown)  llvm::function_ref<>::callback_fn<>()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460:     @        0x30db4a930  (unknown)  mlir::PatternApplicator::matchAndRewrite()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460:     @        0x30db09e8c  (unknown)  (anonymous namespace)::OperationLegalizer::legalize()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460:     @        0x30db09900  (unknown)  mlir::OperationConverter::convert()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460:     @        0x30db0a05c  (unknown)  mlir::OperationConverter::convertOperations()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460:     @        0x30db11008  (unknown)  mlir::applyFullConversion()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460:     @        0x30980b404  (unknown)  mlir::mps::(anonymous namespace)::ConvertHLOToMPSPass::runOnOperation()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460:     @        0x30dd5e78c  (unknown)  mlir::detail::OpToOpPassAdaptor::run()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460:     @        0x30dd5ec98  (unknown)  mlir::detail::OpToOpPassAdaptor::runPipeline()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460:     @        0x30dd5fdd8  (unknown)  mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl()
[2025-01-23 13:36:25,272 E 26616 136846672] logging.cc:460:     @        0x30dd5e63c  (unknown)  mlir::detail::OpToOpPassAdaptor::run()
[2025-01-23 13:36:25,272 E 26616 136846672] logging.cc:460:     @        0x30dd60ca4  (unknown)  mlir::PassManager::runPasses()
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460:     @        0x30dd60b28  (unknown)  mlir::PassManager::run()
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460:     @        0x309809fd8  (unknown)  compileMlirHLOToMPS
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460:     @        0x3097fb95c  (unknown)  xla::mps::MetalStreamExecutorClient::Compile()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460:     @        0x30bdad940  (unknown)  std::__1::__variant_detail::__visitation::__base::__dispatcher<>::__dispatch[abi:ne180100]<>()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460:     @        0x30bda06a4  (unknown)  pjrt::PJRT_Client_Compile()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460:     @        0x138150d34  (unknown)  xla::InitializeArgsAndCompile()
[2025-01-23 13:36:25,275 E 26616 136846672] logging.cc:460:     @        0x1381514d4  (unknown)  xla::PjRtCApiClient::Compile()
[2025-01-23 13:36:25,275 E 26616 136846672] logging.cc:460:     @        0x13c20d1d8  (unknown)  xla::ifrt::PjRtLoadedExecutable::Create()
[2025-01-23 13:36:25,276 E 26616 136846672] logging.cc:460:     @        0x13c2093ec  (unknown)  xla::ifrt::PjRtCompiler::Compile()
[2025-01-23 13:36:25,276 E 26616 136846672] logging.cc:460:     @ ... and at least 29 more frames
Fatal Python error: Aborted

Stack (most recent call first):
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 315 in backend_compile
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 333 in wrapper
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 388 in compile_or_get_cached
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2723 in _cached_compilation
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2922 in from_hlo
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2419 in compile
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1669 in _pjit_call_impl_python
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 198 in _python_pjit_helper
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 340 in cache_miss
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337 in _call
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/_module.py", line 1096 in __call__
  File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261 in __call__
  File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 401 in train_step
  File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 424 in training_steps
  File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 435 in train
  File "/Users/jder/oa/levanter/src/levanter/main/train_lm.py", line 292 in main
  File "/Users/jder/oa/levanter/src/levanter/config.py", line 84 in wrapper_inner
  File "/Users/jder/oa/levanter/src/levanter/main/train_lm.py", line 305 in <module>
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86 in _run_code
  File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196 in _run_module_as_main

Extension modules: jaxlib.cpu_feature_guard, numpy._core._multiarray_umath, numpy.linalg._umath_linalg, zstandard.backend_c, pyarrow.lib, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, yaml._yaml, pyarrow._parquet, pyarrow._fs, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, xxhash._xxhash, pyarrow._json, pyarrow._acero, pyarrow._csv, pyarrow._substrait, pyarrow._dataset, pyarrow._dataset_orc, pyarrow._parquet_encryption, pyarrow._dataset_parquet_encryption, pyarrow._dataset_parquet, 
google._upb._message, grpc._cython.cygrpc, msgpack._cmsgpack, psutil._psutil_osx, psutil._psutil_posix, setproctitle, ray._raylet, PIL._imaging, kiwisolver._cext, regex._regex, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils (total: 124)
[1]    26616 abort      python -m levanter.main.train_lm --config config/gpt2_nano.yaml
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
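The native frames show the abort happening while jax-metal converts HLO to MPS, inside `BroadcastInDimConverter::matchAndRewrite()` as it builds an `mps::PermuteOp`. A hypothetical starting point for a smaller reproduction is to jit a tiny function that combines `lax.broadcast_in_dim` with a transpose and dump its StableHLO via `jax.jit(...).lower(...).as_text()`. This is a guess, not a confirmed repro: the op pattern, the shapes and the helper name `broadcast_then_permute` are assumptions, and the op that actually fails inside the levanter train step may differ.

```python
import jax
import jax.numpy as jnp
from jax import lax


def broadcast_then_permute(x):
    # x: [2, 4] -> broadcast to [2, 3, 4] by inserting a new middle axis,
    # then reverse the axis order to get [4, 3, 2].
    y = lax.broadcast_in_dim(x, shape=(2, 3, 4), broadcast_dimensions=(0, 2))
    return jnp.transpose(y, (2, 1, 0))


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    x = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
    f = jax.jit(broadcast_then_permute)
    # Dump the StableHLO module the backend will be asked to compile.
    print(f.lower(x).as_text())
    # Compilation (and, on Metal, the HLO-to-MPS conversion) happens here.
    print(f(x).shape)
```

On the CPU backend this should simply print the shape `(4, 3, 2)`; if it also aborts on Metal, the printed StableHLO is a compact artifact to attach to an upstream report.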

Maybe this is just worth noting in the README? Happy to provide any more info that would be helpful, thanks!
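If more information is useful, one way to rerun the demo command non-interactively is sketched below. `WANDB_MODE` and `TOKENIZERS_PARALLELISM` are the settings named in the log above; setting `JAX_PLATFORMS=cpu` for a comparison run is an assumption about how to confirm that the abort is specific to the Metal backend. `build_env` and `run` are hypothetical helpers.

```python
import os
import subprocess
import sys

# The demo command from this issue, run with the current interpreter.
TRAIN_CMD = [sys.executable, "-m", "levanter.main.train_lm",
             "--config", "config/gpt2_nano.yaml"]


def build_env(base, force_cpu=False):
    """Copy `base` and add settings intended to make reruns non-interactive."""
    env = dict(base)
    env["WANDB_MODE"] = "offline"            # assumed to skip the W&B account prompt
    env["TOKENIZERS_PARALLELISM"] = "false"  # the setting the fork warning suggests
    if force_cpu:
        env["JAX_PLATFORMS"] = "cpu"         # assumption: bypass the Metal plugin
    return env


def run(force_cpu=False):
    """Run the demo command and return its exit code."""
    return subprocess.run(TRAIN_CMD, env=build_env(os.environ, force_cpu)).returncode


if __name__ == "__main__":
    # Compare a Metal run against a CPU-only run of the same config.
    print("metal exit code:", run(force_cpu=False))
    print("cpu exit code:", run(force_cpu=True))
```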
