Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set up ReadTheDoc pages and add a few examples. #118

Merged
merged 1 commit into from
Nov 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
build/
dist/
venv/
docs/_build/

# Mac OS
.DS_Store
Expand Down
17 changes: 17 additions & 0 deletions .readthedocs.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Read the Docs configuration file
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details

version: 2

sphinx:
builder: html
configuration: docs/conf.py
fail_on_warning: false

python:
version: 3.7
install:
- requirements: requirements/requirements-docs.txt
- requirements: requirements/requirements.txt
- method: setuptools
path: .
18 changes: 9 additions & 9 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,17 @@ def clear_trace_counter():


def assert_max_traces(fn=None, n=None):
"""Checks if a function is traced at most n times (inclusively).
"""Checks if a function is traced at most `n` times (inclusively).

JAX re-traces JIT'ted function every time the structure of passed arguments
JAX re-traces jitted functions every time the structure of passed arguments
changes. Often this behavior is inadvertent and leads to a significant
performance drop which is hard to debug. This wrapper asserts that
the function is not re-traced more that `n` times during program execution.
the function is re-traced at most `n` times during program execution.

Examples:

```
.. code-block:: python

@jax.jit
@chex.assert_max_traces(n=1)
def fn_sum_jitted(x, y):
Expand All @@ -86,7 +87,6 @@ def fn_sub(x, y):
return x - y

fn_sub_pmapped = jax.pmap(chex.assert_max_retraces(fn_sub), n=10)
```

More about tracing:
https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html
Expand All @@ -96,7 +96,7 @@ def fn_sub(x, y):
n: maximum allowed number of retraces (non-negative).

Returns:
Decorated f-n that throws exception when max. number of re-traces exceeded.
Decorated function that raises exception when it is re-traced `n+1`-st time.
"""
if not callable(fn) and n is None:
# Passed n as a first argument.
Expand Down Expand Up @@ -978,12 +978,12 @@ def assert_devices_available(
n: required number of devices of a given type.
devtype: type of devices, one of {'cpu', 'gpu', 'tpu'}.
backend: type of backend to use (uses JAX default if `None`).
not_less_than: whether to check if number of devices _not less_ than
required `n` instead of precise comparison.
not_less_than: whether to check if the number of devices is not less than
required `n`, instead of precise comparison.

Raises:
AssertionError: if number of available device of a given type is not equal
or less than `n`.
or less than `n`.
"""
n_available = _ai.num_devices_available(devtype, backend=backend)
devs = jax.devices(backend)
Expand Down
11 changes: 9 additions & 2 deletions chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ def set_n_cpu_devices(n: Optional[int] = None):
See https://github.com/google/jax/issues/1408.

Args:
n: required number of CPU devices (`FLAGS.chex_n_cpu_devices` if `None`).
n: a required number of CPU devices (`FLAGS.chex_n_cpu_devices` is used by
default).

Raises:
RuntimeError: if XLA backends were already initialized.
Expand Down Expand Up @@ -175,18 +176,24 @@ def fake_jit(enable_patching: bool = True):

Can be used either as a context managed scope:

.. code-block:: python

with chex.fake_jit():
@jax.jit
def foo(x):
...

or by calling `start` and `stop`:

.. code-block:: python

fake_jit_context = chex.fake_jit()
fake_jit.context.start()

@jax.jit
def foo(x):
...
...

fake_jit.context.stop()

Args:
Expand Down
19 changes: 19 additions & 0 deletions docs/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Minimal makefile for Sphinx documentation
#

# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = .
BUILDDIR = _build

# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)

.PHONY: help Makefile

# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
46 changes: 46 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
Assertions
==========

.. currentmodule:: chex

.. autosummary::

assert_devices_available
assert_max_traces
assert_scalar


JaxAssertions
~~~~~~~~~~~~~

.. autofunction:: assert_devices_available
.. autofunction:: assert_max_traces


GenericAssertions
~~~~~~~~~~~~~~~~~

.. autofunction:: assert_scalar


Fakes
=====

.. currentmodule:: chex

.. autosummary::

set_n_cpu_devices
fake_jit


Devices
~~~~~~~

.. autofunction:: set_n_cpu_devices


Transformations
~~~~~~~~~~~~~~~

.. autofunction:: fake_jit
192 changes: 192 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Configuration file for the Sphinx documentation builder."""

# This file only contains a selection of the most common options. For a full
# list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.

# pylint: disable=g-bad-import-order
# pylint: disable=g-import-not-at-top
import inspect
import os
import sys
import typing


def _add_annotations_import(path):
"""Appends a future annotations import to the file at the given path."""
with open(path) as f:
contents = f.read()
if contents.startswith('from __future__ import annotations'):
# If we run sphinx multiple times then we will append the future import
# multiple times too.
return

assert contents.startswith('#'), (path, contents.split('\n')[0])
with open(path, 'w') as f:
# NOTE: This is subtle and not unit tested, we're prefixing the first line
# in each Python file with this future import. It is important to prefix
# not insert a newline such that source code locations are accurate (we link
# to GitHub). The assertion above ensures that the first line in the file is
# a comment so it is safe to prefix it.
f.write('from __future__ import annotations ')
f.write(contents)


def _recursive_add_annotations_import():
for path, _, files in os.walk('../chex/'):
for file in files:
if file.endswith('.py'):
_add_annotations_import(os.path.abspath(os.path.join(path, file)))

if 'READTHEDOCS' in os.environ:
_recursive_add_annotations_import()

typing.get_type_hints = lambda obj, *unused: obj.__annotations__
sys.path.insert(0, os.path.abspath('../'))
sys.path.append(os.path.abspath('ext'))

import chex
import sphinxcontrib.katex as katex

# -- Project information -----------------------------------------------------

project = 'Chex'
copyright = '2021, DeepMind' # pylint: disable=redefined-builtin
author = 'Chex Contributors'

# -- General configuration ---------------------------------------------------

master_doc = 'index'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.autosummary',
'sphinx.ext.doctest',
'sphinx.ext.inheritance_diagram',
'sphinx.ext.intersphinx',
'sphinx.ext.linkcode',
'sphinx.ext.napoleon',
'sphinxcontrib.bibtex',
'sphinxcontrib.katex',
'sphinx_autodoc_typehints',
'sphinx_rtd_theme',
'coverage_check',
'myst_nb', # This is used for the .ipynb notebooks
]

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# -- Options for autodoc -----------------------------------------------------

autodoc_default_options = {
'member-order': 'bysource',
'special-members': True,
'exclude-members': '__repr__, __str__, __weakref__',
}

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []
# html_favicon = '_static/favicon.ico'

# -- Options for myst -------------------------------------------------------

jupyter_execute_notebooks = 'force'
execution_allow_errors = False

# -- Options for katex ------------------------------------------------------

# See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html
latex_macros = r"""
\def \d #1{\operatorname{#1}}
"""

# Translate LaTeX macros to KaTeX and add to options for HTML builder
katex_macros = katex.latex_defs_to_katex_macros(latex_macros)
katex_options = 'macros: {' + katex_macros + '}'

# Add LaTeX macros for LATEX builder
latex_elements = {'preamble': latex_macros}

# -- Source code links -------------------------------------------------------


def linkcode_resolve(domain, info):
"""Resolve a GitHub URL corresponding to Python object."""
if domain != 'py':
return None

try:
mod = sys.modules[info['module']]
except ImportError:
return None

obj = mod
try:
for attr in info['fullname'].split('.'):
obj = getattr(obj, attr)
except AttributeError:
return None
else:
obj = inspect.unwrap(obj)

try:
filename = inspect.getsourcefile(obj)
except TypeError:
return None

try:
source, lineno = inspect.getsourcelines(obj)
except OSError:
return None

# TODO(slebedev): support tags after we release an initial version.
return 'https://github.com/deepmind/chex/tree/master/chex/%s#L%d#L%d' % (
os.path.relpath(filename, start=os.path.dirname(
chex.__file__)), lineno, lineno + len(source) - 1)


# -- Intersphinx configuration -----------------------------------------------

intersphinx_mapping = {
'jax': ('https://jax.readthedocs.io/en/latest/', None),
}

source_suffix = ['.rst', '.md', '.ipynb']
Loading