diff --git a/jax/_src/pallas/mosaic/primitives.py b/jax/_src/pallas/mosaic/primitives.py index 9ea2b59f66f1..0a7bd371a639 100644 --- a/jax/_src/pallas/mosaic/primitives.py +++ b/jax/_src/pallas/mosaic/primitives.py @@ -615,7 +615,7 @@ def dma_start_discharge_rule(in_avals, out_avals, if device_id_len > 1 or len(nonempty_axes) > 1: raise NotImplementedError("Meshes with more than 1 named dimension not " "implemented in dma_start_p") - shard_axis = nonempty_axes[0].name + shard_axis = nonempty_axes[0] my_axis = jax.lax.axis_index(shard_axis) else: raise ValueError(f"Unknown device_id_type: {device_id_type}")