Unbound axis names within scan within custom partitioning #20864

Closed
nshepperd opened this issue Apr 22, 2024 · 0 comments

Description

I encountered this problem while implementing a ring attention algorithm using custom_partitioning. The contrived example below custom-partitions jnp.sum with a partitioning implementation that calls psum inside a scan (full self-contained code is at the end of this report). It fails with an error saying the axis name cannot be found:

    def part_func(x, axis_name):
        def f(carry, part):
            carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
            return carry, None
        return jax.lax.scan(f, 0, x)[0]
    # NameError: unbound axis name: x. The following axis names (e.g. defined by pmap) are available to collective operations: []
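For reference, the value this part_func is meant to produce can be checked with a NumPy-only simulation. This is a sketch of the intended arithmetic, not of how JAX executes collectives: `simulate_part_func` and `num_shards` are hypothetical names introduced here, and the column split stands in for a `(None, 'x')` sharding.

```python
import numpy as np

def simulate_part_func(x, num_shards):
    """Simulate part_func on an array sharded along axis 1, i.e. spec (None, 'x')."""
    # Each simulated device holds one contiguous block of columns.
    shards = np.split(x, num_shards, axis=1)
    carry = 0.0
    # lax.scan iterates over the leading axis, so at step i every device
    # sees its own slice of row i.
    for i in range(x.shape[0]):
        # Per-device jnp.sum(part), then psum adds the partial sums across devices.
        psum = sum(float(np.sum(shard[i])) for shard in shards)
        carry += psum
    return carry

x = np.ones((4, 4))
result = simulate_part_func(x, num_shards=2)
```

Under these assumptions the scan accumulates the full sum of the array, which is why the tests in the full code expect 16 for a 4x4 array of ones.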

However, it works fine if the part_func does not contain a scan, such as:

    def part_func_without_scan(x, axis_name):
        return jax.lax.psum(jnp.sum(x), axis_name=axis_name)

It also works if shard_map is used to call part_func directly instead of going through custom partitioning, but only with check_rep=False. With check_rep=True it fails with the error below, which looks like a different issue, though I am unsure:

    Scan carry input and output got mismatched replication types [None] and [{'x'}]

Full code:

import os
from functools import partial, reduce
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax
import jax.experimental.shard_map  # explicit import so test_shard_map can reference the submodule
import jax.numpy as jnp
import numpy as np

from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.tree_util import tree_map
import pytest

def make_custom_partitioning(part_func):
    @custom_partitioning
    def my_func(x):
        return jnp.sum(x)

    def partition(mesh, arg_shapes, result_shape):
        result_shardings = tree_map(lambda x: x.sharding, result_shape)
        arg_shardings = tree_map(lambda x: x.sharding, arg_shapes)
        assert isinstance(arg_shardings[0], NamedSharding)
        assert (None, 'x') == arg_shardings[0].spec
        return mesh, partial(part_func, axis_name='x'), result_shardings, arg_shardings

    def infer_sharding(mesh, arg_shapes, result_shape):
        return NamedSharding(mesh, P())

    def propagate_user_sharding(mesh, user_shape):
        return user_shape.sharding

    my_func.def_partition(partition, infer_sharding, propagate_user_sharding=propagate_user_sharding)
    return my_func


# This works fine:
def test_simple():
    def part_func(x, axis_name):
        return jax.lax.psum(jnp.sum(x), axis_name=axis_name)
    my_func = jax.jit(make_custom_partitioning(part_func))
    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4,4])
        assert int(my_func(array)) == 16
        array = jax.device_put(array, NamedSharding(mesh, P(None,'x')))
        assert int(my_func(array)) == 16


# This doesn't:
def test_scan():
    def part_func(x, axis_name):
        def f(carry, part):
            carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
            return carry, None
        return jax.lax.scan(f, 0, x)[0]
    my_func = jax.jit(make_custom_partitioning(part_func))
    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4,4])
        assert int(my_func(array)) == 16
        array = jax.device_put(array, NamedSharding(mesh, P(None,'x')))
        # Crashes here with NameError: unbound axis name: x. The following axis names (e.g. defined by pmap) are available to collective operations: []
        assert int(my_func(array)) == 16

# It works under shard_map, as long as check_rep is False.
def test_shard_map():
    def part_func(x, axis_name):
        def f(carry, part):
            carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
            return carry, None
        return jax.lax.scan(f, 0, x)[0]
    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4,4])
        array = jax.device_put(array, NamedSharding(mesh, P(None,'x')))
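        # check_rep=True instead fails with: Scan carry input and output got
        # mismatched replication types [None] and [{'x'}].
        sharded = jax.jit(jax.experimental.shard_map.shard_map(
            partial(part_func, axis_name='x'), mesh,
            in_specs=(P(None,'x'),), out_specs=P(), check_rep=False))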
        assert int(sharded(array)) == 16

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.2
python: 3.11.8 (main, Feb 12 2024, 14:50:05) [GCC 13.2.1 20230801]
jax.devices (1 total, 1 local): [cuda(id=0)]
process_count: 1
platform: uname_result(system='Linux', node='phenex', release='6.8.2-arch2-1', version='#1 SMP PREEMPT_DYNAMIC Thu, 28 Mar 2024 17:06:35 +0000', machine='x86_64')


$ nvidia-smi
Tue Apr 23 00:02:35 2024
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.67                 Driver Version: 550.67         CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|=========================================+========================+======================|
|   0  NVIDIA GeForce RTX 3090        Off |   00000000:01:00.0  On |                  N/A |
| 34%   38C    P2             33W /  350W |    1864MiB /  24576MiB |      2%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
@nshepperd nshepperd added the bug Something isn't working label Apr 22, 2024
@superbobry superbobry self-assigned this Apr 23, 2024
superbobry added a commit to superbobry/jax that referenced this issue Apr 23, 2024
See jax-ml#20864 for more context and the added test for a reproducer.