From 2c36aadb7c4d8ba7370e9a498dd7dfadef5dea7f Mon Sep 17 00:00:00 2001 From: Iurii Kemaev Date: Fri, 26 Nov 2021 06:06:02 -0800 Subject: [PATCH] Set up ReadTheDoc pages and add a few examples. Full documentation will be added later. PiperOrigin-RevId: 412437934 --- .gitignore | 1 + .readthedocs.yaml | 17 +++ chex/_src/asserts.py | 18 +-- chex/_src/fake.py | 11 +- docs/Makefile | 19 +++ docs/api.rst | 46 +++++++ docs/conf.py | 192 +++++++++++++++++++++++++++++ docs/ext/coverage_check.py | 81 ++++++++++++ docs/index.rst | 53 ++++++++ requirements/requirements-docs.txt | 11 ++ 10 files changed, 438 insertions(+), 11 deletions(-) create mode 100644 .readthedocs.yaml create mode 100644 docs/Makefile create mode 100644 docs/api.rst create mode 100644 docs/conf.py create mode 100644 docs/ext/coverage_check.py create mode 100644 docs/index.rst create mode 100644 requirements/requirements-docs.txt diff --git a/.gitignore b/.gitignore index c815ebbb..5fadbb99 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ build/ dist/ venv/ +docs/_build/ # Mac OS .DS_Store diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..f003f390 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,17 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +version: 2 + +sphinx: + builder: html + configuration: docs/conf.py + fail_on_warning: false + +python: + version: 3.7 + install: + - requirements: requirements/requirements-docs.txt + - requirements: requirements/requirements.txt + - method: setuptools + path: . diff --git a/chex/_src/asserts.py b/chex/_src/asserts.py index f25b9a0f..cab11990 100644 --- a/chex/_src/asserts.py +++ b/chex/_src/asserts.py @@ -67,16 +67,17 @@ def clear_trace_counter(): def assert_max_traces(fn=None, n=None): - """Checks if a function is traced at most n times (inclusively). + """Checks if a function is traced at most `n` times (inclusively). - JAX re-traces JIT'ted function every time the structure of passed arguments + JAX re-traces jitted functions every time the structure of passed arguments changes. Often this behavior is inadvertent and leads to a significant performance drop which is hard to debug. This wrapper asserts that - the function is not re-traced more that `n` times during program execution. + the function is re-traced at most `n` times during program execution. Examples: - ``` + .. code-block:: python + @jax.jit @chex.assert_max_traces(n=1) def fn_sum_jitted(x, y): @@ -86,7 +87,6 @@ def fn_sub(x, y): return x - y fn_sub_pmapped = jax.pmap(chex.assert_max_retraces(fn_sub), n=10) - ``` More about tracing: https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html @@ -96,7 +96,7 @@ def fn_sub(x, y): n: maximum allowed number of retraces (non-negative). Returns: - Decorated f-n that throws exception when max. number of re-traces exceeded. + Decorated function that raises exception when it is re-traced `n+1`-st time. """ if not callable(fn) and n is None: # Passed n as a first argument. @@ -978,12 +978,12 @@ def assert_devices_available( n: required number of devices of a given type. devtype: type of devices, one of {'cpu', 'gpu', 'tpu'}. backend: type of backend to use (uses JAX default if `None`). - not_less_than: whether to check if number of devices _not less_ than - required `n` instead of precise comparison. + not_less_than: whether to check if the number of devices is not less than + required `n`, instead of precise comparison. Raises: AssertionError: if number of available device of a given type is not equal - or less than `n`. + or less than `n`. """ n_available = _ai.num_devices_available(devtype, backend=backend) devs = jax.devices(backend) diff --git a/chex/_src/fake.py b/chex/_src/fake.py index 8e502b6b..f8467306 100644 --- a/chex/_src/fake.py +++ b/chex/_src/fake.py @@ -60,7 +60,8 @@ def set_n_cpu_devices(n: Optional[int] = None): See https://github.com/google/jax/issues/1408. Args: - n: required number of CPU devices (`FLAGS.chex_n_cpu_devices` if `None`). + n: a required number of CPU devices (`FLAGS.chex_n_cpu_devices` is used by + default). Raises: RuntimeError: if XLA backends were already initialized. @@ -175,6 +176,8 @@ def fake_jit(enable_patching: bool = True): Can be used either as a context managed scope: + .. code-block:: python + with chex.fake_jit(): @jax.jit def foo(x): @@ -182,11 +185,15 @@ def foo(x): or by calling `start` and `stop`: + .. code-block:: python + fake_jit_context = chex.fake_jit() fake_jit.context.start() + @jax.jit def foo(x): - ... + ... + fake_jit.context.stop() Args: diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..51285967 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,19 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line. +SPHINXOPTS = +SPHINXBUILD = sphinx-build +SOURCEDIR = . +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/api.rst b/docs/api.rst new file mode 100644 index 00000000..0e223663 --- /dev/null +++ b/docs/api.rst @@ -0,0 +1,46 @@ +Assertions +========== + +.. currentmodule:: chex + +.. autosummary:: + + assert_devices_available + assert_max_traces + assert_scalar + + +JaxAssertions +~~~~~~~~~~~~~ + +.. autofunction:: assert_devices_available +.. autofunction:: assert_max_traces + + +GenericAssertions +~~~~~~~~~~~~~~~~~ + +.. autofunction:: assert_scalar + + +Fakes +===== + +.. currentmodule:: chex + +.. autosummary:: + + set_n_cpu_devices + fake_jit + + +Devices +~~~~~~~ + +.. autofunction:: set_n_cpu_devices + + +Transformations +~~~~~~~~~~~~~~~ + +.. autofunction:: fake_jit diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 00000000..ec3db293 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,192 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Configuration file for the Sphinx documentation builder.""" + +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# http://www.sphinx-doc.org/en/master/config + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. + +# pylint: disable=g-bad-import-order +# pylint: disable=g-import-not-at-top +import inspect +import os +import sys +import typing + + +def _add_annotations_import(path): + """Appends a future annotations import to the file at the given path.""" + with open(path) as f: + contents = f.read() + if contents.startswith('from __future__ import annotations'): + # If we run sphinx multiple times then we will append the future import + # multiple times too. + return + + assert contents.startswith('#'), (path, contents.split('\n')[0]) + with open(path, 'w') as f: + # NOTE: This is subtle and not unit tested, we're prefixing the first line + # in each Python file with this future import. It is important to prefix + # not insert a newline such that source code locations are accurate (we link + # to GitHub). The assertion above ensures that the first line in the file is + # a comment so it is safe to prefix it. + f.write('from __future__ import annotations ') + f.write(contents) + + +def _recursive_add_annotations_import(): + for path, _, files in os.walk('../chex/'): + for file in files: + if file.endswith('.py'): + _add_annotations_import(os.path.abspath(os.path.join(path, file))) + +if 'READTHEDOCS' in os.environ: + _recursive_add_annotations_import() + +typing.get_type_hints = lambda obj, *unused: obj.__annotations__ +sys.path.insert(0, os.path.abspath('../')) +sys.path.append(os.path.abspath('ext')) + +import chex +import sphinxcontrib.katex as katex + +# -- Project information ----------------------------------------------------- + +project = 'Chex' +copyright = '2021, DeepMind' # pylint: disable=redefined-builtin +author = 'Chex Contributors' + +# -- General configuration --------------------------------------------------- + +master_doc = 'index' + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.doctest', + 'sphinx.ext.inheritance_diagram', + 'sphinx.ext.intersphinx', + 'sphinx.ext.linkcode', + 'sphinx.ext.napoleon', + 'sphinxcontrib.bibtex', + 'sphinxcontrib.katex', + 'sphinx_autodoc_typehints', + 'sphinx_rtd_theme', + 'coverage_check', + 'myst_nb', # This is used for the .ipynb notebooks +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# -- Options for autodoc ----------------------------------------------------- + +autodoc_default_options = { + 'member-order': 'bysource', + 'special-members': True, + 'exclude-members': '__repr__, __str__, __weakref__', +} + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = [] +# html_favicon = '_static/favicon.ico' + +# -- Options for myst ------------------------------------------------------- + +jupyter_execute_notebooks = 'force' +execution_allow_errors = False + +# -- Options for katex ------------------------------------------------------ + +# See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html +latex_macros = r""" + \def \d #1{\operatorname{#1}} +""" + +# Translate LaTeX macros to KaTeX and add to options for HTML builder +katex_macros = katex.latex_defs_to_katex_macros(latex_macros) +katex_options = 'macros: {' + katex_macros + '}' + +# Add LaTeX macros for LATEX builder +latex_elements = {'preamble': latex_macros} + +# -- Source code links ------------------------------------------------------- + + +def linkcode_resolve(domain, info): + """Resolve a GitHub URL corresponding to Python object.""" + if domain != 'py': + return None + + try: + mod = sys.modules[info['module']] + except ImportError: + return None + + obj = mod + try: + for attr in info['fullname'].split('.'): + obj = getattr(obj, attr) + except AttributeError: + return None + else: + obj = inspect.unwrap(obj) + + try: + filename = inspect.getsourcefile(obj) + except TypeError: + return None + + try: + source, lineno = inspect.getsourcelines(obj) + except OSError: + return None + + # TODO(slebedev): support tags after we release an initial version. + return 'https://github.com/deepmind/chex/tree/master/chex/%s#L%d#L%d' % ( + os.path.relpath(filename, start=os.path.dirname( + chex.__file__)), lineno, lineno + len(source) - 1) + + +# -- Intersphinx configuration ----------------------------------------------- + +intersphinx_mapping = { + 'jax': ('https://jax.readthedocs.io/en/latest/', None), +} + +source_suffix = ['.rst', '.md', '.ipynb'] diff --git a/docs/ext/coverage_check.py b/docs/ext/coverage_check.py new file mode 100644 index 00000000..7de468c1 --- /dev/null +++ b/docs/ext/coverage_check.py @@ -0,0 +1,81 @@ +# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Asserts all public symbols are covered in the docs.""" + +import inspect +import types +from typing import Any, Mapping, Sequence, Tuple + +import chex as _module +from sphinx import application +from sphinx import builders +from sphinx import errors + + +def find_internal_python_modules( + root_module: types.ModuleType,) -> Sequence[Tuple[str, types.ModuleType]]: + """Returns `(name, module)` for all submodules under `root_module`.""" + modules = set([(root_module.__name__, root_module)]) + visited = set() + to_visit = [root_module] + + while to_visit: + mod = to_visit.pop() + visited.add(mod) + + for name in dir(mod): + obj = getattr(mod, name) + if inspect.ismodule(obj) and obj not in visited: + if obj.__name__.startswith(_module.__name__): + if "_src" not in obj.__name__: + to_visit.append(obj) + modules.add((obj.__name__, obj)) + + return sorted(modules) + + +def get_public_symbols() -> Sequence[Tuple[str, types.ModuleType]]: + names = set() + for module_name, module in find_internal_python_modules(_module): + for name in module.__all__: + names.add(module_name + "." + name) + return tuple(names) + + +class CoverageCheck(builders.Builder): + """Builder that checks all public symbols are included.""" + + name = "coverage_check" + + def get_outdated_docs(self) -> str: + return "coverage_check" + + def write(self, *ignored: Any) -> None: + pass + + def finish(self) -> None: + documented_objects = frozenset(self.env.domaindata["py"]["objects"]) + undocumented_objects = set(get_public_symbols()) - documented_objects + if undocumented_objects: + undocumented_objects = tuple(sorted(undocumented_objects)) + raise errors.SphinxError( + "All public symbols must be included in our documentation, did you " + "forget to add an entry to `api.rst`?\n" + f"Undocumented symbols: {undocumented_objects}") + + +def setup(app: application.Sphinx) -> Mapping[str, Any]: + app.add_builder(CoverageCheck) + return dict(version=_module.__version__, parallel_read_safe=True) diff --git a/docs/index.rst b/docs/index.rst new file mode 100644 index 00000000..00d8fcfb --- /dev/null +++ b/docs/index.rst @@ -0,0 +1,53 @@ +:github_url: https://github.com/deepmind/chex/tree/master/docs + +Chex +----- + +Chex is a library of utilities for helping to write reliable JAX code. + +This includes utils to help: + +* Instrument your code (e.g. assertions) +* Debug (e.g. transforming pmaps in vmaps within a context manager). +* Test JAX code across many variants (e.g. jitted vs non-jitted). + +Installation +------------ + +Chex can be installed with pip directly from github, with the following command: + +``pip install git+git://github.com/deepmind/chex.git`` + +or from PyPI: + +``pip install chex`` + +.. toctree:: + :caption: API Documentation + :maxdepth: 2 + + api + + +Contribute +---------- + +- `Issue tracker `_ +- `Source code `_ + +Support +------- + +If you are having issues, please let us know by filing an issue on our +`issue tracker `_. + +License +------- + +Chex is licensed under the Apache 2.0 License. + + +Indices and Tables +================== + +* :ref:`genindex` diff --git a/requirements/requirements-docs.txt b/requirements/requirements-docs.txt new file mode 100644 index 00000000..ef048202 --- /dev/null +++ b/requirements/requirements-docs.txt @@ -0,0 +1,11 @@ +sphinx==3.3.0 +sphinx_rtd_theme==0.5.0 +sphinxcontrib-katex==0.7.1 +sphinxcontrib-bibtex==1.0.0 +sphinx-autodoc-typehints==1.11.1 +IPython==7.16.1 +ipykernel==5.3.4 +pandoc==1.0.2 +myst_nb==0.13.1 +docutils==0.16 +matplotlib==3.5.0