On macOS 15.1.1 using Python 3.10.14 with an M2 Max processor, running through the installation instructions and the sample test command doesn't work for me. This is with a checkout of the latest levanter (237851b) in a fresh venv. I tried a few variants:
Following the instructions directly, pip install jax-metal==0.0.5 followed by pip install -e . produces a dependency conflict:
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
jax-metal 0.0.5 requires jax==0.4.20, but you have jax 0.5.0 which is incompatible.
jax-metal 0.0.5 requires jaxlib==0.4.20, but you have jaxlib 0.5.0 which is incompatible.
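In case it helps triage, here is a minimal stdlib-only sketch that checks an environment against the two exact pins pip reports above. The names `JAX_METAL_PINS`, `installed_versions`, and `find_conflicts` are made up for illustration and are not part of levanter or jax-metal. It only handles exact `==` pins, which is all the error mentions.

```python
from importlib import metadata

# Exact pins copied from the pip error above for jax-metal 0.0.5.
JAX_METAL_PINS = {"jax": "0.4.20", "jaxlib": "0.4.20"}


def installed_versions(names):
    """Return {name: installed version, or None if the package is missing}."""
    found = {}
    for name in names:
        try:
            found[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            found[name] = None
    return found


def find_conflicts(installed, pins):
    """Return (name, required, actual) for every pin that is not met exactly."""
    return [
        (name, required, installed.get(name))
        for name, required in pins.items()
        if installed.get(name) != required
    ]


if __name__ == "__main__":
    # Hypothetical helper output format; prints nothing when all pins are met.
    current = installed_versions(JAX_METAL_PINS)
    for name, required, actual in find_conflicts(current, JAX_METAL_PINS):
        print(f"{name}: jax-metal 0.0.5 requires {required}, found {actual}")
```

Run inside the venv after `pip install -e .`, this should list the same two mismatches that pip prints.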
Doing them in one go, i.e. pip install jax-metal==0.0.5 -e ., installs successfully but then immediately errors out with what looks like incompatible NumPy versions when trying to run the demo command python -m levanter.main.train_lm --config config/gpt2_nano.yaml, presumably because other dependencies were downgraded to match the old jax-metal version:
full error
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.2 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.
If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.
Traceback (most recent call last):
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 187, in _run_module_as_main
mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 110, in _get_module_details
__import__(pkg_name)
File "/Users/jder/oa/levanter/src/levanter/__init__.py", line 1, in <module>
import levanter.checkpoint as checkpoint
File "/Users/jder/oa/levanter/src/levanter/checkpoint.py", line 15, in <module>
import equinox
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/__init__.py", line 3, in <module>
from . import debug as debug, internal as internal, nn as nn
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/debug/__init__.py", line 1, in <module>
from ._announce_transform import announce_transform as announce_transform
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/debug/_announce_transform.py", line 4, in <module>
import jax
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/__init__.py", line 39, in <module>
from jax import config as _config_module
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/config.py", line 15, in <module>
from jax._src.config import config as _deprecated_config # noqa: F401
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/config.py", line 28, in <module>
from jax._src import lib
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/lib/__init__.py", line 90, in <module>
import jaxlib.xla_client as xla_client
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jaxlib/xla_client.py", line 29, in <module>
from . import xla_extension as _xla
AttributeError: _ARRAY_API not found
Traceback (most recent call last):
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 187, in _run_module_as_main
mod_name, mod_spec, code = _get_module_details(mod_name, _Error)
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 110, in _get_module_details
__import__(pkg_name)
File "/Users/jder/oa/levanter/src/levanter/__init__.py", line 5, in <module>
import levanter.eval as eval
File "/Users/jder/oa/levanter/src/levanter/eval.py", line 19, in <module>
from levanter.callbacks import StepInfo
File "/Users/jder/oa/levanter/src/levanter/callbacks.py", line 29, in <module>
from levanter.trainer_state import TrainerState
File "/Users/jder/oa/levanter/src/levanter/trainer_state.py", line 10, in <module>
from optax import GradientTransformation, OptState
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/__init__.py", line 17, in <module>
from optax import contrib
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/contrib/__init__.py", line 21, in <module>
from optax.contrib._dadapt_adamw import dadapt_adamw
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/contrib/_dadapt_adamw.py", line 27, in <module>
from optax._src import utils
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/optax/_src/utils.py", line 25, in <module>
import jax.scipy.stats.norm as multivariate_normal
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/scipy/stats/__init__.py", line 40, in <module>
from jax._src.scipy.stats.kde import gaussian_kde as gaussian_kde
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/scipy/stats/kde.py", line 26, in <module>
from jax.scipy import linalg, special
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/scipy/linalg.py", line 18, in <module>
from jax._src.scipy.linalg import (
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/scipy/linalg.py", line 408, in <module>
@_wraps(scipy.linalg.tril)
AttributeError: module 'scipy.linalg' has no attribute 'tril'
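The first message above itself suggests downgrading to 'numpy<2'. As a rough sketch of that check (the function names are hypothetical, and `built_for_major=1` is an assumption taken from the "compiled using NumPy 1.x" wording in the error), the following derives the pin to try from the installed NumPy version. The second traceback, where scipy.linalg has no attribute tril, looks like a separate SciPy version mismatch and is not covered here.

```python
def numpy_major(version: str) -> int:
    """Return the major component of a version string such as '2.2.2'."""
    return int(version.split(".", 1)[0])


def suggested_pin(numpy_version: str, built_for_major: int = 1):
    """Return a pip requirement to try, or None if no downgrade looks necessary.

    built_for_major is the NumPy major version the compiled module targets;
    1 is assumed here because the error says "compiled using NumPy 1.x".
    """
    if numpy_major(numpy_version) > built_for_major:
        return f"numpy<{built_for_major + 1}"
    return None
```

For the version in the error, `suggested_pin("2.2.2")` gives `"numpy<2"`, matching the hint in the message. Whether the rest of the dependency set resolves cleanly with that pin is something I have not verified.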
Running instead with the latest jax-metal gets further, but ends with an LLVM ERROR: Failed to infer result type(s).
full transcript
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
None of PyTorch, TensorFlow >= 2.0, or Flax have been found. Models won't be available and only tokenizers, configuration and file/data utilities can be used.
INFO:levanter.distributed:Not initializing jax.distributed because no distributed config was provided, and no cluster was detected.
INFO:2025-01-23 13:35:57,240:jax._src.xla_bridge:945: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:jax._src.xla_bridge:Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-01-23 13:35:57,240:jax._src.xla_bridge:945: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
INFO:jax._src.xla_bridge:Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/homebrew/lib/libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/opt/homebrew/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file), '/usr/local/lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache)
WARNING:2025-01-23 13:35:57,240:jax._src.xla_bridge:1018: Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING:jax._src.xla_bridge:Platform 'METAL' is experimental and not all JAX functionality may be correctly supported!
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
W0000 00:00:1737657357.241082 136846672 mps_client.cc:510] WARNING: JAX Apple GPU support is experimental and not all JAX functionality is correctly supported!
Metal device set to: Apple M2 Max
systemMemory: 64.00 GB
maxCacheSize: 24.00 GB
I0000 00:00:1737657357.255176 136846672 service.cc:145] XLA service 0x600000347500 initialized for platform METAL (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1737657357.255195 136846672 service.cc:153] StreamExecutor device (0): Metal, <undefined>
I0000 00:00:1737657357.256698 136846672 mps_client.cc:406] Using Simple allocator.
I0000 00:00:1737657357.256711 136846672 mps_client.cc:384] XLA backend will use up to 51539214336 bytes on device 0 for SimpleAllocator.
INFO:levanter.trainer:Setting run id to 7arqdyx3
2025-01-23T13:35:57 - 0 - levanter.tracker.wandb - wandb.py:233 - INFO :: Setting wandb code_dir to .
2025-01-23T13:35:57 - 0 - levanter.tracker.wandb - wandb.py:251 - WARNING :: Could not find git repo at .
wandb: (1) Create a W&B account
wandb: (2) Use an existing W&B account
wandb: (3) Don't visualize my results
wandb: Enter your choice: 3
wandb: You chose "Don't visualize my results"
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: WARNING `resume` will be ignored since W&B syncing is set to `offline`. Starting a new run with run id 7arqdyx3.
wandb: Tracking run with wandb version 0.19.4
wandb: W&B syncing is set to `offline` in this directory.
wandb: Run `wandb online` or set WANDB_MODE=online to enable cloud syncing.
2025-01-23T13:36:14 - 0 - levanter.distributed - distributed.py:215 - INFO :: No auto-discovered ray address found. Using ray.init('local').
2025-01-23T13:36:14 - 0 - levanter.distributed - distributed.py:267 - INFO :: ray.init(address='local', namespace='levanter', **{})
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/subprocess.py:1796: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
self.pid = _posixsubprocess.fork_exec(
2025-01-23 13:36:15,077 INFO worker.py:1832 -- Started a local Ray instance. View the dashboard at 127.0.0.1:8265
2025-01-23T13:36:15 - 0 - levanter.tracker.wandb - wandb.py:233 - INFO :: Setting wandb code_dir to .
2025-01-23T13:36:15 - 0 - levanter.tracker.wandb - wandb.py:251 - WARNING :: Could not find git repo at .
train: 0%| | 0/100 [00:00<?, ?it/s]2025-01-23T13:36:15 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/validation
2025-01-23T13:36:15 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/train
2025-01-23T13:36:17 - 0 - levanter.data.text - text.py:1105 - INFO :: Building cache for train...
2025-01-23T13:36:17 - 0 - levanter.store.cache - cache.py:302 - INFO :: Loading cache from cache/train
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py:128: RuntimeWarning: Sharding constraints are not supported in jit on metal
warnings.warn("Sharding constraints are not supported in jit on metal", RuntimeWarning)
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1153: UserWarning: Some donated buffers were not usable: ShapedArray(uint32[2]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
cache/train: tokenizing: 0%| | 0/8 [00:00<?, ?shard/s]
(writer::cache/train pid=26972) 2025-01-23 13:36:20,572 - INFO - Starting writer task
(writer::cache/train pid=26972) 2025-01-23 13:36:20,606 - INFO - Waiting for first group 0 to finish
cache/train: tokenizing: 12%|█▎ | 1/8 [00:03<00:22, 3.25s/shard]
(writer::cache/train pid=26972) 2025-01-23 13:36:22,257 - INFO - First group 0 finished. Copying other groups into permanent cache.
(tokenize::cache/train/___temp::0 pid=26983) 2025-01-23 13:36:22,253 - INFO - Shard 0 already processed.
2025-01-23T13:36:22 - 0 - __main__ - train_lm.py:195 - INFO :: No checkpoint found. Starting from scratch.
(tokenize::cache/train/___temp::1 pid=26983) huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
(tokenize::cache/train/___temp::1 pid=26983) To disable this warning, you can either:
(tokenize::cache/train/___temp::1 pid=26983) - Avoid using `tokenizers` before the fork if possible
(tokenize::cache/train/___temp::1 pid=26983) - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/mlir.py:1153: UserWarning: Some donated buffers were not usable: ShapedArray(int32[], weak_type=True), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(int32[]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,3,4,8]), ShapedArray(float32[2,3,4,8]), ShapedArray(float32[2,4,8,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32]), ShapedArray(float32[2,32,128]), ShapedArray(float32[2,128]), ShapedArray(float32[2,128,32]), ShapedArray(float32[2,32]), ShapedArray(float32[32]), ShapedArray(float32[32]), ShapedArray(float32[50257,32]), ShapedArray(float32[1024,32]), ShapedArray(uint32[2]).
Donation is not implemented for ('METAL',).
See an explanation at https://jax.readthedocs.io/en/latest/faq.html#buffer-donation.
warnings.warn("Some donated buffers were not usable:"
LLVM ERROR: Failed to infer result type(s).
*** SIGABRT received at time=1737657385 ***
PC: @ 0x19db7a600 (unknown) __pthread_kill
@ 0x12d165398 (unknown) absl::lts_20230802::AbslFailureSignalHandler()
@ 0x19dbe8184 (unknown) _sigtramp
@ 0x19dbb2f70 (unknown) pthread_kill
@ 0x19dabf908 (unknown) abort
@ 0x30df1688c (unknown) llvm::report_fatal_error()
@ 0x30df166c4 (unknown) llvm::report_fatal_error()
@ 0x3099cb7d4 (unknown) mlir::mps::PermuteOp::build()
@ 0x30983a0b0 (unknown) mlir::OpBuilder::create<>()
@ 0x309839a30 (unknown) mlir::mps::(anonymous namespace)::BroadcastInDimConverter::matchAndRewrite()
@ 0x3098393cc (unknown) mlir::OpConversionPattern<>::matchAndRewrite()
@ 0x30db09834 (unknown) mlir::ConversionPattern::matchAndRewrite()
@ 0x30db4d018 (unknown) llvm::function_ref<>::callback_fn<>()
@ 0x30db4a930 (unknown) mlir::PatternApplicator::matchAndRewrite()
@ 0x30db09e8c (unknown) (anonymous namespace)::OperationLegalizer::legalize()
@ 0x30db09900 (unknown) mlir::OperationConverter::convert()
@ 0x30db0a05c (unknown) mlir::OperationConverter::convertOperations()
@ 0x30db11008 (unknown) mlir::applyFullConversion()
@ 0x30980b404 (unknown) mlir::mps::(anonymous namespace)::ConvertHLOToMPSPass::runOnOperation()
@ 0x30dd5e78c (unknown) mlir::detail::OpToOpPassAdaptor::run()
@ 0x30dd5ec98 (unknown) mlir::detail::OpToOpPassAdaptor::runPipeline()
@ 0x30dd5fdd8 (unknown) mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl()
@ 0x30dd5e63c (unknown) mlir::detail::OpToOpPassAdaptor::run()
@ 0x30dd60ca4 (unknown) mlir::PassManager::runPasses()
@ 0x30dd60b28 (unknown) mlir::PassManager::run()
@ 0x309809fd8 (unknown) compileMlirHLOToMPS
@ 0x3097fb95c (unknown) xla::mps::MetalStreamExecutorClient::Compile()
@ 0x30bdad940 (unknown) std::__1::__variant_detail::__visitation::__base::__dispatcher<>::__dispatch[abi:ne180100]<>()
@ 0x30bda06a4 (unknown) pjrt::PJRT_Client_Compile()
@ 0x138150d34 (unknown) xla::InitializeArgsAndCompile()
@ 0x1381514d4 (unknown) xla::PjRtCApiClient::Compile()
@ 0x13c20d1d8 (unknown) xla::ifrt::PjRtLoadedExecutable::Create()
@ 0x13c2093ec (unknown) xla::ifrt::PjRtCompiler::Compile()
@ ... and at least 29 more frames
[2025-01-23 13:36:25,264 E 26616 136846672] logging.cc:460: *** SIGABRT received at time=1737657385 ***
[2025-01-23 13:36:25,264 E 26616 136846672] logging.cc:460: PC: @ 0x19db7a600 (unknown) __pthread_kill
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460: @ 0x12d1654b8 (unknown) absl::lts_20230802::AbslFailureSignalHandler()
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460: @ 0x19dbe8184 (unknown) _sigtramp
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460: @ 0x19dbb2f70 (unknown) pthread_kill
[2025-01-23 13:36:25,265 E 26616 136846672] logging.cc:460: @ 0x19dabf908 (unknown) abort
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460: @ 0x30df1688c (unknown) llvm::report_fatal_error()
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460: @ 0x30df166c4 (unknown) llvm::report_fatal_error()
[2025-01-23 13:36:25,266 E 26616 136846672] logging.cc:460: @ 0x3099cb7d4 (unknown) mlir::mps::PermuteOp::build()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460: @ 0x30983a0b0 (unknown) mlir::OpBuilder::create<>()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460: @ 0x309839a30 (unknown) mlir::mps::(anonymous namespace)::BroadcastInDimConverter::matchAndRewrite()
[2025-01-23 13:36:25,267 E 26616 136846672] logging.cc:460: @ 0x3098393cc (unknown) mlir::OpConversionPattern<>::matchAndRewrite()
[2025-01-23 13:36:25,268 E 26616 136846672] logging.cc:460: @ 0x30db09834 (unknown) mlir::ConversionPattern::matchAndRewrite()
[2025-01-23 13:36:25,268 E 26616 136846672] logging.cc:460: @ 0x30db4d018 (unknown) llvm::function_ref<>::callback_fn<>()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460: @ 0x30db4a930 (unknown) mlir::PatternApplicator::matchAndRewrite()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460: @ 0x30db09e8c (unknown) (anonymous namespace)::OperationLegalizer::legalize()
[2025-01-23 13:36:25,269 E 26616 136846672] logging.cc:460: @ 0x30db09900 (unknown) mlir::OperationConverter::convert()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460: @ 0x30db0a05c (unknown) mlir::OperationConverter::convertOperations()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460: @ 0x30db11008 (unknown) mlir::applyFullConversion()
[2025-01-23 13:36:25,270 E 26616 136846672] logging.cc:460: @ 0x30980b404 (unknown) mlir::mps::(anonymous namespace)::ConvertHLOToMPSPass::runOnOperation()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460: @ 0x30dd5e78c (unknown) mlir::detail::OpToOpPassAdaptor::run()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460: @ 0x30dd5ec98 (unknown) mlir::detail::OpToOpPassAdaptor::runPipeline()
[2025-01-23 13:36:25,271 E 26616 136846672] logging.cc:460: @ 0x30dd5fdd8 (unknown) mlir::detail::OpToOpPassAdaptor::runOnOperationAsyncImpl()
[2025-01-23 13:36:25,272 E 26616 136846672] logging.cc:460: @ 0x30dd5e63c (unknown) mlir::detail::OpToOpPassAdaptor::run()
[2025-01-23 13:36:25,272 E 26616 136846672] logging.cc:460: @ 0x30dd60ca4 (unknown) mlir::PassManager::runPasses()
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460: @ 0x30dd60b28 (unknown) mlir::PassManager::run()
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460: @ 0x309809fd8 (unknown) compileMlirHLOToMPS
[2025-01-23 13:36:25,273 E 26616 136846672] logging.cc:460: @ 0x3097fb95c (unknown) xla::mps::MetalStreamExecutorClient::Compile()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460: @ 0x30bdad940 (unknown) std::__1::__variant_detail::__visitation::__base::__dispatcher<>::__dispatch[abi:ne180100]<>()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460: @ 0x30bda06a4 (unknown) pjrt::PJRT_Client_Compile()
[2025-01-23 13:36:25,274 E 26616 136846672] logging.cc:460: @ 0x138150d34 (unknown) xla::InitializeArgsAndCompile()
[2025-01-23 13:36:25,275 E 26616 136846672] logging.cc:460: @ 0x1381514d4 (unknown) xla::PjRtCApiClient::Compile()
[2025-01-23 13:36:25,275 E 26616 136846672] logging.cc:460: @ 0x13c20d1d8 (unknown) xla::ifrt::PjRtLoadedExecutable::Create()
[2025-01-23 13:36:25,276 E 26616 136846672] logging.cc:460: @ 0x13c2093ec (unknown) xla::ifrt::PjRtCompiler::Compile()
[2025-01-23 13:36:25,276 E 26616 136846672] logging.cc:460: @ ... and at least 29 more frames
Fatal Python error: Aborted
Stack (most recent call first):
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 315 in backend_compile
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/profiler.py", line 333 in wrapper
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/compiler.py", line 388 in compile_or_get_cached
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2723 in _cached_compilation
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2922 in from_hlo
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 2419 in compile
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 1669 in _pjit_call_impl_python
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 198 in _python_pjit_helper
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/pjit.py", line 340 in cache_miss
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/jax/_src/traceback_util.py", line 180 in reraise_with_filtered_traceback
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py", line 337 in _call
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/equinox/_module.py", line 1096 in __call__
File "/Users/jder/oa/levanter/venv/lib/python3.10/site-packages/haliax/partitioning.py", line 261 in __call__
File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 401 in train_step
File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 424 in training_steps
File "/Users/jder/oa/levanter/src/levanter/trainer.py", line 435 in train
File "/Users/jder/oa/levanter/src/levanter/main/train_lm.py", line 292 in main
File "/Users/jder/oa/levanter/src/levanter/config.py", line 84 in wrapper_inner
File "/Users/jder/oa/levanter/src/levanter/main/train_lm.py", line 305 in <module>
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86 in _run_code
File "/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196 in _run_module_as_main
Extension modules: jaxlib.cpu_feature_guard, numpy._core._multiarray_umath, numpy.linalg._umath_linalg, zstandard.backend_c, pyarrow.lib, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, yaml._yaml, pyarrow._parquet, pyarrow._fs, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, multidict._multidict, yarl._quoting_c, propcache._helpers_c, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket.mask, aiohttp._websocket.reader_c, frozenlist._frozenlist, xxhash._xxhash, pyarrow._json, pyarrow._acero, pyarrow._csv, pyarrow._substrait, pyarrow._dataset, pyarrow._dataset_orc, pyarrow._parquet_encryption, pyarrow._dataset_parquet_encryption, pyarrow._dataset_parquet, 
google._upb._message, grpc._cython.cygrpc, msgpack._cmsgpack, psutil._psutil_osx, psutil._psutil_posix, setproctitle, ray._raylet, PIL._imaging, kiwisolver._cext, regex._regex, scipy._lib._ccallback_c, scipy.sparse._sparsetools, _csparsetools, scipy.sparse._csparsetools, scipy.linalg._fblas, scipy.linalg._flapack, scipy.linalg.cython_lapack, scipy.linalg._cythonized_array_utils, scipy.linalg._solve_toeplitz, scipy.linalg._decomp_lu_cython, scipy.linalg._matfuncs_sqrtm_triu, scipy.linalg._matfuncs_expm, scipy.linalg._linalg_pythran, scipy.linalg.cython_blas, scipy.linalg._decomp_update, scipy.sparse.linalg._dsolve._superlu, scipy.sparse.linalg._eigen.arpack._arpack, scipy.sparse.linalg._propack._spropack, scipy.sparse.linalg._propack._dpropack, scipy.sparse.linalg._propack._cpropack, scipy.sparse.linalg._propack._zpropack, scipy.sparse.csgraph._tools, scipy.sparse.csgraph._shortest_path, scipy.sparse.csgraph._traversal, scipy.sparse.csgraph._min_spanning_tree, scipy.sparse.csgraph._flow, scipy.sparse.csgraph._matching, scipy.sparse.csgraph._reordering, scipy.io.matlab._mio_utils, scipy.io.matlab._streams, scipy.io.matlab._mio5_utils (total: 124)
[1] 26616 abort python -m levanter.main.train_lm --config config/gpt2_nano.yaml
/opt/homebrew/Cellar/python@3.10/3.10.14_1/Frameworks/Python.framework/Versions/3.10/lib/python3.10/multiprocessing/resource_tracker.py:224: UserWarning: resource_tracker: There appear to be 1 leaked semaphore objects to clean up at shutdown
warnings.warn('resource_tracker: There appear to be %d '
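Since the abort happens while compiling for the METAL backend (the native stack goes through ConvertHLOToMPSPass), one thing that might be worth trying is running the same demo command on the CPU backend. The sketch below only builds the command and environment. `cpu_fallback_invocation` is a made-up helper name, and it assumes that JAX's `JAX_PLATFORMS` environment variable is honored with the jax-metal plugin installed, which I have not tested on this setup.

```python
import os
import sys


def cpu_fallback_invocation(config_path="config/gpt2_nano.yaml", base_env=None):
    """Build (cmd, env) for the demo command with JAX restricted to CPU.

    Nothing is executed here; the caller decides whether to run it.
    """
    # Copy the environment so the current process is left unchanged.
    env = dict(os.environ if base_env is None else base_env)
    # Assumption: this makes JAX skip the experimental METAL platform.
    env["JAX_PLATFORMS"] = "cpu"
    cmd = [sys.executable, "-m", "levanter.main.train_lm", "--config", config_path]
    return cmd, env
```

To actually launch it, pass the result to `subprocess.run(cmd, env=env)` from the levanter checkout. If the run completes on CPU, that would point at the Metal lowering rather than at levanter itself.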
Maybe this is just worth noting in the readme? Happy to provide any more info that would be helpful, thanks!