Enable parallel ops in fake contexts
PiperOrigin-RevId: 426865572
stompchicken authored and ChexDev committed Mar 7, 2022
1 parent 1e255f8 commit 395e02a
Showing 2 changed files with 123 additions and 31 deletions.
95 changes: 64 additions & 31 deletions chex/_src/fake.py
@@ -29,7 +29,7 @@
from unittest import mock
from absl import flags
import jax
import jax.numpy as jnp


FLAGS = flags.FLAGS
flags.DEFINE_integer('chex_n_cpu_devices', 1,
@@ -92,6 +92,36 @@ def _fake_jit(fn, *unused_args, **unused_kwargs):
  return fn


def _ignore_axis_index_groups(fn):
  """Wrapper that forces `axis_index_groups` to be `None`.

  This is to avoid problems within fake_pmap, where parallel operations are
  performed with vmap rather than pmap. Parallel operations where
  `axis_index_groups` is not `None` are not currently supported under vmap.

  Args:
    fn: the function to wrap.

  Returns:
    A wrapped function that forces any keyword argument named
    `axis_index_groups` to be `None`.
  """
  @functools.wraps(fn)
  def _fake(*args, axis_index_groups=None, **kwargs):
    del axis_index_groups
    return fn(*args, axis_index_groups=None, **kwargs)
  return _fake
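
For illustration, a minimal sketch of the wrapper's effect (editor's example, not part of the diff; the `[[0, 1], [2, 3]]` grouping is made up): any `axis_index_groups` argument is silently discarded, so the collective acts over the whole axis.

import jax
import jax.numpy as jnp

# Wrap psum so that axis_index_groups is dropped before the real call.
fake_psum = _ignore_axis_index_groups(jax.lax.psum)

per_element_sum = jax.vmap(
    lambda x: fake_psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]]),
    axis_name='i')

# The grouping is ignored: every element sees the sum over all four entries.
print(per_element_sum(jnp.arange(4.)))  # [6. 6. 6. 6.]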


_fake_all_gather = _ignore_axis_index_groups(jax.lax.all_gather)
_fake_all_to_all = _ignore_axis_index_groups(jax.lax.all_to_all)
_fake_psum = _ignore_axis_index_groups(jax.lax.psum)
_fake_pmean = _ignore_axis_index_groups(jax.lax.pmean)
_fake_pmax = _ignore_axis_index_groups(jax.lax.pmax)
_fake_pmin = _ignore_axis_index_groups(jax.lax.pmin)
_fake_pswapaxes = _ignore_axis_index_groups(jax.lax.pswapaxes)
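
These wrapped collectives are the real `jax.lax` primitives rather than the identity fakes they replace (deleted further down), relying on `jax.vmap`'s support for named-axis collectives. A quick standalone sanity check of that premise, assuming a recent JAX version:

import jax
import jax.numpy as jnp

# vmap resolves collectives over the mapped axis when given an axis name,
# so psum/all_gather keep their cross-device semantics on a single device.
summed = jax.vmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(jnp.arange(4.))
gathered = jax.vmap(lambda x: jax.lax.all_gather(x, 'i'), axis_name='i')(jnp.arange(4.))
print(summed)          # [6. 6. 6. 6.]
print(gathered.shape)  # (4, 4) -- every position holds the full vector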


@functools.wraps(jax.pmap)
def _fake_pmap(fn,
               axis_name: Optional[Any] = None,
@@ -136,27 +166,12 @@ def wrapped_fn(*args, **kwargs):
      vmapped_fn = jax.jit(vmapped_fn)

    output = vmapped_fn(*call_args)

    return output

  return wrapped_fn


-def _identity(x, *unused_args, **unused_kwargs):
-  return x


-_fake_psum = functools.wraps(jax.lax.psum)(_identity)
-_fake_pmean = functools.wraps(jax.lax.pmean)(_identity)
-_fake_pmax = functools.wraps(jax.lax.pmax)(_identity)
-_fake_pmin = functools.wraps(jax.lax.pmin)(_identity)


-@functools.wraps(jax.lax.all_gather)
-def _fake_all_gather(x, *unused_args, **unused_kwargs):
-  add_leading_dim = lambda t: t[jnp.newaxis]
-  return jax.tree_map(add_leading_dim, x)


class FakeContext(contextlib.ExitStack):

  def start(self):
@@ -223,15 +238,16 @@ def _jax_disable_jit():
  return stack


-def fake_pmap(enable_patching: bool = True,
-              jit_result: bool = False) -> FakeContext:
def fake_pmap(
    enable_patching: bool = True,
    jit_result: bool = False,
    axis_name: Optional[Any] = None,
    ignore_axis_index_groups: bool = False,
) -> FakeContext:
"""Context manager for patching `jax.pmap` with `jax.vmap`.
This is intended to be used as a debugging tool to programmatically replace
pmap transformations with a non-parallel vmap transformation. Beware that the
output is *not* guaranteed to be identical with `jax.pmap`! In particular, all
`jax.lax.p*` operations are replaced with identity maps when `fake_pmap` is
used.
pmap transformations with a non-parallel vmap transformation.
Can be used either as a context managed scope:
@@ -257,21 +273,38 @@ def foo(x):
    enable_patching: Whether to patch `jax.pmap`.
    jit_result: Whether the transformed function should be jitted despite not
      being pmapped.
    axis_name: axis name to use for parallel operations.
    ignore_axis_index_groups: Whether to force any parallel operation within
      the context to set `axis_index_groups` to `None`. This is a
      compatibility option to allow users of the `axis_index_groups` parameter
      to run under the fake_pmap context. This feature is not currently
      supported in vmap, and will fail, so we force the parameter to be
      `None`. *Warning*: this will produce different results from running
      under `jax.pmap`.

  Returns:
    Context where `jax.pmap` is patched with `jax.vmap`.
  """
  # Improve implementation to automatically track JAX collectives development.
  stack = FakeContext()
  if enable_patching:
-    stack.enter_context(
-        mock.patch('jax.pmap',
-                   functools.partial(_fake_pmap, jit_result=jit_result)))
-    stack.enter_context(mock.patch('jax.lax.psum', _fake_psum))
-    stack.enter_context(mock.patch('jax.lax.pmean', _fake_pmean))
-    stack.enter_context(mock.patch('jax.lax.pmax', _fake_pmax))
-    stack.enter_context(mock.patch('jax.lax.pmin', _fake_pmin))
-    stack.enter_context(mock.patch('jax.lax.all_gather', _fake_all_gather))
    patched_pmap = functools.partial(
        _fake_pmap,
        axis_name=axis_name,
        jit_result=jit_result)
    stack.enter_context(mock.patch('jax.pmap', patched_pmap))

    if ignore_axis_index_groups:
      stack.enter_context(mock.patch('jax.lax.all_gather', _fake_all_gather))
      stack.enter_context(mock.patch('jax.lax.all_to_all', _fake_all_to_all))
      stack.enter_context(mock.patch('jax.lax.psum', _fake_psum))
      stack.enter_context(mock.patch('jax.lax.pmean', _fake_pmean))
      stack.enter_context(mock.patch('jax.lax.pmax', _fake_pmax))
      stack.enter_context(mock.patch('jax.lax.pmin', _fake_pmin))
      stack.enter_context(mock.patch('jax.lax.pswapaxes', _fake_pswapaxes))
    else:
      # Use default implementations
      pass

  return stack
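
For context, a usage sketch (editor's illustration, not part of the diff; the grouping values are made up): collectives that pass `axis_index_groups` only run under the fake context when `ignore_axis_index_groups=True`, at the cost of ignoring the grouping.

import chex
import jax
import jax.numpy as jnp

def grouped_sum(x):
  # axis_index_groups is unsupported under vmap; the fake context drops it.
  return jax.lax.psum(x, 'i', axis_index_groups=[[0, 1], [2, 3]])

with chex.fake_pmap(ignore_axis_index_groups=True):
  # jax.pmap is patched with vmap here, so this runs on a single device.
  out = jax.pmap(grouped_sum, axis_name='i')(jnp.arange(4.))
  print(out)  # [6. 6. 6. 6.] -- the grouping is ignored.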


59 changes: 59 additions & 0 deletions chex/_src/fake_test.py
@@ -312,6 +312,65 @@ def foo(x, y, flag=True):
    asserts.assert_trees_all_close(
        overidden_foo(x=inputs, y=inputs), expected)

  def test_parallel_ops_equivalence(self):
    """Test equivalence between parallel operations using pmap and vmap."""
    num_devices = len(jax.devices())
    inputs = jax.random.uniform(shape=(num_devices, num_devices, 2),
                                key=jax.random.PRNGKey(1))

    def test_equivalence(fn):
      with fake.fake_pmap(enable_patching=False):
        outputs1 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      with fake.fake_pmap(enable_patching=True):
        outputs2 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      with fake.fake_pmap(enable_patching=True, jit_result=True):
        outputs3 = jax.pmap(fn, axis_name='i', axis_size=num_devices)(inputs)
      asserts.assert_trees_all_close(outputs1, outputs2, outputs3)

    parallel_ops_and_kwargs = [
        (jax.lax.psum, {}),
        (jax.lax.pmax, {}),
        (jax.lax.pmin, {}),
        (jax.lax.pmean, {}),
        (jax.lax.all_gather, {}),
        (jax.lax.all_to_all, {
            'split_axis': 0,
            'concat_axis': 1
        }),
        (jax.lax.ppermute, {
            'perm': [(x, (x + 1) % num_devices) for x in range(num_devices)]
        }),
    ]

    def fn(op, kwargs, x, y=2.0):
      return op(x * y, axis_name='i', **kwargs)

    partial_fn = functools.partial(fn, y=4.0)
    lambda_fn = lambda op, kwargs, x: fn(op, kwargs, x, y=5.0)

    for op, kwargs in parallel_ops_and_kwargs:
      test_equivalence(functools.partial(fn, op, kwargs))
      test_equivalence(functools.partial(fn, op, kwargs, y=3.0))
      test_equivalence(functools.partial(partial_fn, op, kwargs))
      test_equivalence(functools.partial(lambda_fn, op, kwargs))

  # def test_fake_parallel_axis(self):
  #   inputs = jnp.ones(shape=(2, 2))
  #   with fake.fake_pmap(fake_parallel_axis=False):
  #     @jax.pmap
  #     def fn(x):
  #       asserts.assert_shape(x, (2,))
  #       return 2.0 * x
  #
  #     outputs = fn(inputs)
  #
  #   with fake.fake_pmap(fake_parallel_axis=True):
  #     @jax.pmap
  #     def fn(x):
  #       asserts.assert_shape(x, (2, 2,))
  #       return 2.0 * x
  #
  #     outputs = fn(inputs)


class _Counter():
"""Counts how often an instance is called."""