Description

I encountered this problem while implementing a ring attention algorithm using `custom_partitioning`. A custom partitioning of `jnp.sum` using the following contrived partitioning implementation (see the full self-contained code at the end of this report) throws an error saying the axis name cannot be found:
```python
def part_func(x, axis_name):
    def f(carry, part):
        carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
        return carry, None
    return jax.lax.scan(f, 0, x)[0]

# NameError: unbound axis name: x. The following axis names (e.g. defined by pmap) are available to collective operations: []
```
However, it works fine if `part_func` does not contain a scan, such as:
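```python
def part_func(x, axis_name):
    return jax.lax.psum(jnp.sum(x), axis_name=axis_name)
```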
This also works if `shard_map` is used to call `part_func` directly instead of via custom partitioning, but only with `check_rep=False` (`check_rep=True` throws `Scan carry input and output got mismatched replication types [None] and [{'x'}]`, which seems like a different issue, but I am not sure).
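For reference, here is a minimal sketch of the direct `shard_map` call that succeeds. It is the same as `test_shard_map` in the full code below, except that `check_rep` is set to `False`; it assumes the scan-based `part_func` from the first snippet and the two-CPU-device setup (`--xla_force_host_platform_device_count=2`) used in the full code:

```python
from functools import partial

import jax
import jax.numpy as jnp
import jax.experimental.shard_map
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
    array = jax.device_put(jnp.ones([4, 4]), NamedSharding(mesh, P(None, 'x')))
    # Call part_func directly under shard_map; the psum inside the scan can
    # resolve the 'x' axis name here, as long as replication checking is off.
    sharded = jax.jit(jax.experimental.shard_map.shard_map(
        partial(part_func, axis_name='x'), mesh,
        in_specs=(P(None, 'x'),), out_specs=P(),
        check_rep=False))
    assert int(sharded(array)) == 16
```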
Full code:
```python
import os
from functools import partial, reduce

os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax
import jax.numpy as jnp
import numpy as np
import jax.experimental.shard_map  # explicit import needed for test_shard_map below
from jax.experimental.custom_partitioning import custom_partitioning
from jax.sharding import PartitionSpec as P
from jax.sharding import Mesh
from jax.sharding import NamedSharding
from jax.sharding import PositionalSharding
from jax.tree_util import tree_map
import pytest


def make_custom_partitioning(part_func):
    @custom_partitioning
    def my_func(x):
        return jnp.sum(x)

    def partition(mesh, arg_shapes, result_shape):
        result_shardings = tree_map(lambda x: x.sharding, result_shape)
        arg_shardings = tree_map(lambda x: x.sharding, arg_shapes)
        assert isinstance(arg_shardings[0], NamedSharding)
        assert (None, 'x') == arg_shardings[0].spec
        return mesh, partial(part_func, axis_name='x'), result_shardings, arg_shardings

    def infer_sharding(mesh, arg_shapes, result_shape):
        return NamedSharding(mesh, P())

    def propagate_user_sharding(mesh, user_shape):
        return user_shape.sharding

    my_func.def_partition(partition, infer_sharding,
                          propagate_user_sharding=propagate_user_sharding)
    return my_func


# This works fine:
def test_simple():
    def part_func(x, axis_name):
        return jax.lax.psum(jnp.sum(x), axis_name=axis_name)

    my_func = jax.jit(make_custom_partitioning(part_func))
    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4, 4])
        assert int(my_func(array)) == 16
        array = jax.device_put(array, NamedSharding(mesh, P(None, 'x')))
        assert int(my_func(array)) == 16


# This doesn't:
def test_scan():
    def part_func(x, axis_name):
        def f(carry, part):
            carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
            return carry, None
        return jax.lax.scan(f, 0, x)[0]

    my_func = jax.jit(make_custom_partitioning(part_func))
    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4, 4])
        assert int(my_func(array)) == 16
        array = jax.device_put(array, NamedSharding(mesh, P(None, 'x')))
        # Crashes here with NameError: unbound axis name: x. The following axis
        # names (e.g. defined by pmap) are available to collective operations: []
        assert int(my_func(array)) == 16


# It works under shard_map, as long as check_rep is False.
def test_shard_map():
    def part_func(x, axis_name):
        def f(carry, part):
            carry += jax.lax.psum(jnp.sum(part), axis_name=axis_name)
            return carry, None
        return jax.lax.scan(f, 0, x)[0]

    with Mesh(jax.devices(backend='cpu'), axis_names=('x',)) as mesh:
        array = jnp.ones([4, 4])
        array = jax.device_put(array, NamedSharding(mesh, P(None, 'x')))
        sharded = jax.jit(jax.experimental.shard_map.shard_map(
            partial(part_func, axis_name='x'), mesh,
            in_specs=(P(None, 'x'),), out_specs=P(),
            # With check_rep=True this raises: Scan carry input and output got
            # mismatched replication types [None] and [{'x'}]
            check_rep=True))
        assert int(sharded(array)) == 16
```
System info (python version, jaxlib version, accelerator, etc.)