Skip to content

Commit

Permalink
Add asserts, variants, and pytypes modules to the RTD docs.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 414789319
  • Loading branch information
hbq1 authored and ChexDev committed Dec 8, 2021
1 parent 06d6b2e commit bd11ace
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 89 deletions.
22 changes: 11 additions & 11 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@


def mappable_dataclass(cls):
"""Exposes dataclass as `collections.abc.Mapping` descendent.
"""Exposes dataclass as ``collections.abc.Mapping`` descendent.
Allows to traverse dataclasses in methods from `dm-tree` library.
NOTE: changes dataclasses constructor to dict-type
(i.e. positional args aren't supported; however can use generators/iterables).
Args:
cls: dataclass to mutate.
cls: A dataclass to mutate.
Returns:
Mutated dataclass implementing `collections.abc.Mapping` interface.
Mutated dataclass implementing ``collections.abc.Mapping`` interface.
"""
if not dataclasses.is_dataclass(cls):
raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?)")
raise ValueError(f"Expected dataclass, got {cls} (change wrappers order?).")

# Define methods for compatibility with `collections.abc.Mapping`.
setattr(cls, "__getitem__", lambda self, x: self.__dict__[x])
Expand Down Expand Up @@ -100,13 +100,13 @@ def dataclass(
to operate on the class when made immutable (frozen=True).
Args:
cls: class to decorate
init: See :py:func:`dataclasses.dataclass`
repr: See :py:func:`dataclasses.dataclass`
eq: See :py:func:`dataclasses.dataclass`
order: See :py:func:`dataclasses.dataclass`
unsafe_hash: See :py:func:`dataclasses.dataclass`
frozen: See :py:func:`dataclasses.dataclass`
cls: A class to decorate.
init: See :py:func:`dataclasses.dataclass`.
repr: See :py:func:`dataclasses.dataclass`.
eq: See :py:func:`dataclasses.dataclass`.
order: See :py:func:`dataclasses.dataclass`.
unsafe_hash: See :py:func:`dataclasses.dataclass`.
frozen: See :py:func:`dataclasses.dataclass`.
mappable_dataclass: If True (the default), methods to make the class
implement the :py:class:`collections.abc.Mappable` interface will be
generated and the class will include :py:class:`collections.abc.Mappable`
Expand Down
151 changes: 87 additions & 64 deletions chex/_src/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import functools
import inspect
import itertools
from typing import Any, Sequence
import unittest

from absl import flags
Expand All @@ -29,18 +30,12 @@
import jax.numpy as jnp
import toolz


FLAGS = flags.FLAGS
flags.DEFINE_bool(
"chex_skip_pmap_variant_if_single_device", True,
"Whether to skip pmap variant if only one device is available.")


# `@chex.variants` returns a generator producing one test per variant.
# Therefore, users' TestCase class must support dynamic unrolling of these
# generators during module import. It is implemented and well-tested in
# `parameterized.TestCase`, hence we alias it as `variants.TestCase`.
#
# We choose to subclass instead of a simple alias, as Python doesn't allow
# multiple inheritance from the same class, and users may want to subclass their
# tests from both `chex.TestCase` and `parameterized.TestCase`.
Expand All @@ -49,6 +44,15 @@
# instead of `variants.TestCase` or `parameterized.TestCase`. If a base class
# doesn't support this feature variant test fails with a corresponding error.
class TestCase(parameterized.TestCase):
"""A class for Chex tests that use variants.

See the docstring for ``chex.variants`` for more information.

Note: ``chex.variants`` returns a generator producing one test per variant.
Therefore, the used test class must support dynamic unrolling of these
generators during module import. It is implemented (and battle-tested) in
``absl.parameterized.TestCase``, and here we subclass from it.
"""

def variant(self, *args, **kwargs):
"""Raises a RuntimeError if not overriden or redefined."""
Expand All @@ -57,6 +61,12 @@ def variant(self, *args, **kwargs):


class ChexVariantType(enum.Enum):
"""An enumeration of available Chex variants.

Use ``self.variant.type`` to get type of the current test variant.
See the docstring of ``chex.variants`` for more information.
"""

WITH_JIT = 1
WITHOUT_JIT = 2
WITH_DEVICE = 3
Expand All @@ -70,8 +80,20 @@ def __str__(self) -> str:
tree_map = tree_util.tree_map


def params_product(*params_lists, named=False):
"""Generates a cartesian product of params_lists."""
def params_product(*params_lists: Sequence[Sequence[Any]],
named: bool = False) -> Sequence[Sequence[Any]]:
"""Generates a cartesian product of `params_lists`.

See tests from ``variants_test.py`` for examples of usage.

Args:
*params_lists: A list of params combinations.
named: Whether to generate test names (for
`absl.parameterized.named_parameters(...)`).

Returns:
A cartesian product of `params_lists` combinations.
"""

def generate():
for combination in itertools.product(*params_lists):
Expand Down Expand Up @@ -212,12 +234,13 @@ def __iter__(self):


@toolz.curry
def _variants_fn(test_object, **which_variants):
def _variants_fn(test_object, **which_variants) -> VariantsTestCaseGenerator:
"""Implements `variants` and `all_variants`."""

# Convert keys to enum entries.
which_variants = {
ChexVariantType[name.upper()]: var for name, var in which_variants.items()
ChexVariantType[name.upper()]: var
for name, var in which_variants.items()
}
if isinstance(test_object, VariantsTestCaseGenerator):
# Merge variants for nested wrappers.
Expand All @@ -235,58 +258,60 @@ def variants(test_method,
without_jit: bool = False,
with_device: bool = False,
without_device: bool = False,
with_pmap: bool = False):
with_pmap: bool = False) -> VariantsTestCaseGenerator:
# pylint: enable=redefined-outer-name
"""Decorates a test to expose Chex variants.

The decorated test has access to a decorator called `self.variant`, which
The decorated test has access to a decorator called ``self.variant``, which
may be applied to functions to test different JAX behaviors. Consider:

```python
@chex.variants(with_jit=True, without_jit=True)
def test(self):
@self.variant
def f(x, y):
return x + y
.. code-block:: python

@chex.variants(with_jit=True, without_jit=True)
def test(self):
@self.variant
def f(x, y):
return x + y

self.assertEqual(f(1, 2), 3)
```
self.assertEqual(f(1, 2), 3)

In this example, the function `test` will be called twice: once with `f`
In this example, the function ``test`` will be called twice: once with `f`
jitted (i.e. using `jax.jit`) and another where `f` is not jitted.

Variants `with_jit=True` and `with_pmap=True` accept additional specific to
them arguments. Example:
```python
@chex.variants(with_jit=True)
def test(self):
@self.variant(static_argnums=(1,))
def f(x, y):
# `y` is not traced.
return x + y

self.assertEqual(f(1, 2), 3)
```
.. code-block:: python

@chex.variants(with_jit=True)
def test(self):
@self.variant(static_argnums=(1,))
def f(x, y):
# `y` is not traced.
return x + y

self.assertEqual(f(1, 2), 3)

Variant `with_pmap=True` also accepts `broadcast_args_to_devices`
(whether to broadcast each input argument to all participating devices),
`reduce_fn` (a function to apply to results of pmapped `fn`), and
`n_devices` (number of devices to use in the `pmap` computation).
See the docstring of `_with_pmap` for more details (including default values).

If used with `absl.testing.parameterized`, @chex.variants must wrap it:
```python
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
...
```
If used with ``absl.testing.parameterized``, `@chex.variants` must wrap it:

.. code-block:: python

Tests that use this wrapper must be inherited from `parameterized.TestCase`.
For more examples see 'variants_test.py'.
@chex.variants(with_jit=True, without_jit=True)
@parameterized.named_parameters('test', *args)
def test(self, *args):
...

Tests that use this wrapper must be inherited from ``parameterized.TestCase``.
For more examples see ``variants_test.py``.

Args:
test_method: Test method to decorate.
test_method: A test method to decorate.
with_jit: Whether to test with `jax.jit`.
without_jit: Whether to test without `jax.jit`. Any jit compilation done
within the test method will not be affected.
Expand All @@ -298,7 +323,7 @@ def test(self, *args):
across devices.

Returns:
Decorated test_method.
A decorated ``test_method``.
"""
return _variants_fn(
test_method,
Expand All @@ -316,9 +341,9 @@ def all_variants(test_method,
without_jit: bool = True,
with_device: bool = True,
without_device: bool = True,
with_pmap: bool = True):
with_pmap: bool = True) -> VariantsTestCaseGenerator:
# pylint: enable=redefined-outer-name
"""Equivalent to `variants` but with flipped defaults."""
"""Equivalent to ``chex.variants`` but with flipped defaults."""
return _variants_fn(
test_method,
with_jit=with_jit,
Expand Down Expand Up @@ -373,10 +398,7 @@ def wrapper(*args, **kwargs):

@toolz.curry
@check_variant_arguments
def _with_device(fn,
ignore_argnums=(),
static_argnums=(),
**unused_kwargs):
def _with_device(fn, ignore_argnums=(), static_argnums=(), **unused_kwargs):
"""Variant that applies `jax.device_put` to the args of fn."""

if isinstance(ignore_argnums, int):
Expand Down Expand Up @@ -439,30 +461,30 @@ def _with_pmap(fn,
"""Variant that applies `jax.pmap` to fn.

Args:
fn: a function to wrap.
broadcast_args_to_devices: whether to broadcast `fn` args to pmap format
fn: A function to wrap.
broadcast_args_to_devices: Whether to broadcast `fn` args to pmap format
(i.e. pmapped axes' sizes == a number of devices).
reduce_fn: a function to apply to outputs of `fn`.
n_devices: a number of devices to use (can specify a `backend` if required).
axis_name: passed to `pmap`.
devices: passed to `pmap`.
in_axes: passed to `pmap`.
static_broadcasted_argnums: passed to `pmap`.
static_argnums: alias of static_broadcasted_argnums.
backend: passed to `pmap`.
**unused_kwargs: unused kwargs (e.g. related to other variants).
reduce_fn: A function to apply to outputs of `fn`.
n_devices: A number of devices to use (can specify a `backend` if required).
axis_name: An argument for `pmap`.
devices: An argument for `pmap`.
in_axes: An argument for `pmap`.
static_broadcasted_argnums: An argument for `pmap`.
static_argnums: An alias of ``static_broadcasted_argnums``.
backend: An argument for `pmap`.
**unused_kwargs: Unused kwargs (e.g. related to other variants).

Returns:
Wrapped `fn` that accepts `args` and `kwargs` and returns a superposition of
`reduce_fn` and `fn` applied to them.

Raises:
ValueError: if `broadcast_args_to_devices` used with `in_axes` or
`static_broadcasted_argnums`;
if number of available devices is less than required;
if pmappable arg axes' sizes are not equal to the number of devices.
SkipTest: if the flag chex_skip_pmap_variant_if_single_device is set and
there is only one device available.
ValueError: If `broadcast_args_to_devices` used with `in_axes` or
`static_broadcasted_argnums`; if number of available devices is less than
required; if pmappable arg axes' sizes are not equal to the number of
devices.
SkipTest: If the flag ``chex_skip_pmap_variant_if_single_device`` is set and
there is only one device available.
"""
if (FLAGS["chex_skip_pmap_variant_if_single_device"].value and
jax.device_count() < 2):
Expand Down Expand Up @@ -547,6 +569,7 @@ def wrapper(*args: pytypes.ArrayTree, **kwargs: pytypes.ArrayTree):

return wrapper


_variant_decorators = dict({
ChexVariantType.WITH_JIT: _with_jit,
ChexVariantType.WITHOUT_JIT: _without_jit,
Expand Down
Loading

0 comments on commit bd11ace

Please sign in to comment.