Skip to content

Commit

Permalink
docs: improve docstrings of function module (#428)
Browse files Browse the repository at this point in the history
* docs: improve docstring create_(parametrized_)function
* docs: remove __eq__ and __call__ from default API
* docs: remove links to General index etc
  • Loading branch information
redeboer authored Apr 11, 2022
1 parent a6406d6 commit ddaf485
Show file tree
Hide file tree
Showing 9 changed files with 96 additions and 21 deletions.
1 change: 1 addition & 0 deletions .cspell.json
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@
"lambdification",
"lambdified",
"lambdify",
"lambdifygenerated",
"linestyle",
"linewidth",
"linkcheck",
Expand Down
1 change: 1 addition & 0 deletions .flake8
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ rst-roles =
mod
ref
rst-directives =
automethod
deprecated
envvar
exception
Expand Down
6 changes: 0 additions & 6 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ def fetch_logo(url: str, output_path: str) -> None:
"members": True,
"undoc-members": True,
"show-inheritance": True,
"special-members": ", ".join(
[
"__call__",
"__eq__",
]
),
}
autodoc_member_order = "bysource"
autodoc_type_aliases = {
Expand Down
4 changes: 0 additions & 4 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@ Upcoming features <https://github.com/ComPWA/tensorwaves/milestones?direction=as
Help developing <https://compwa-org.rtfd.io/en/stable/develop.html>
```

- {ref}`Python API <modindex>`
- {ref}`General Index <genindex>`
- {ref}`Search <search>`

```{toctree}
---
caption: Related projects
Expand Down
4 changes: 2 additions & 2 deletions docs/usage/basics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The {meth}`.ParametrizedBackendFunction.__call__` takes a {class}`dict` of variable names (here, `\"x\"` only) to the value(s) that should be used in their place."
"The {meth}`.ParametrizedFunction.__call__` takes a {class}`dict` of variable names (here, `\"x\"` only) to the value(s) that should be used in their place."
]
},
{
Expand Down Expand Up @@ -503,7 +503,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"First of all, we need to randomly generate values of $x$. In this simple, 1-dimensional example, we could just use a random generator like {class}`numpy.random.Generator` feed its output to the {meth}`.ParametrizedBackendFunction.__call__`. Generally, though, we want to cover $n$-dimensional cases. The class {class}`.NumpyDomainGenerator` allows us to generate such a **uniform** distribution for each variable within a certain range. It requires a {class}`.RealNumberGenerator` (here we use {class}`.NumpyUniformRNG`) and it also requires us to define boundaries for each variable in the resulting {obj}`.DataSample`."
"First of all, we need to randomly generate values of $x$. In this simple, 1-dimensional example, we could just use a random generator like {class}`numpy.random.Generator` feed its output to the {meth}`.ParametrizedFunction.__call__`. Generally, though, we want to cover $n$-dimensional cases. The class {class}`.NumpyDomainGenerator` allows us to generate such a **uniform** distribution for each variable within a certain range. It requires a {class}`.RealNumberGenerator` (here we use {class}`.NumpyUniformRNG`) and it also requires us to define boundaries for each variable in the resulting {obj}`.DataSample`."
]
},
{
Expand Down
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ addopts =
--ignore=docs/conf.py
-k "not benchmark"
-m "not slow"
doctest_optionflags = NORMALIZE_WHITESPACE
filterwarnings =
error
ignore:.* is deprecated and will be removed in Pillow 10.*:DeprecationWarning
Expand Down
24 changes: 20 additions & 4 deletions src/tensorwaves/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,13 @@ def _to_tuple(argument_order: Iterable[str]) -> Tuple[str, ...]:
class PositionalArgumentFunction(Function):
"""Wrapper around a function with positional arguments.
This class provides a :meth:`__call__` that can take a `.DataSample` for a
function with `positional arguments
This class provides a :meth:`~.Function.__call__` that can take a
`.DataSample` for a function with `positional arguments
<https://docs.python.org/3/glossary.html#term-positional-argument>`_. Its
:attr:`argument_order` redirect the keys in the `.DataSample` to the
argument positions in its underlying :attr:`function`.
.. seealso:: :func:`.create_function`
"""

function: Callable[..., np.ndarray] = field(validator=_validate_arguments)
Expand All @@ -84,7 +86,10 @@ def __call__(self, data: DataSample) -> np.ndarray:


class ParametrizedBackendFunction(ParametrizedFunction):
"""Implements `.ParametrizedFunction` for a specific computational back-end."""
"""Implements `.ParametrizedFunction` for a specific computational back-end.
.. seealso:: :func:`.create_parametrized_function`
"""

def __init__(
self,
Expand Down Expand Up @@ -126,7 +131,18 @@ def update_parameters(


def get_source_code(function: Function) -> str:
"""Get the backend source code used to compile this function."""
"""Get the backend source code used to compile this function.
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expr = x**2 + y**2
>>> func = create_function(expr, backend="jax", use_cse=False)
>>> src = get_source_code(func)
>>> print(src)
def _lambdifygenerated(x, y):
return x**2 + y**2
"""
if isinstance(
function, (PositionalArgumentFunction, ParametrizedBackendFunction)
):
Expand Down
62 changes: 60 additions & 2 deletions src/tensorwaves/function/sympy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,36 @@
def create_function(
expression: "sp.Expr",
backend: str,
max_complexity: Optional[int] = None,
use_cse: bool = True,
max_complexity: Optional[int] = None,
) -> PositionalArgumentFunction:
"""Convert a SymPy expression to a computational function.
Args:
expression: The SymPy expression that you want to
`~sympy.utilities.lambdify.lambdify`. Its
`~sympy.core.basic.Basic.free_symbols` become arguments to the
resulting `.PositionalArgumentFunction`.
backend: The computational backend in which to express the function.
use_cse: Identify common sub-expressions in the function. This usually
makes the function faster and speeds up lambdification.
max_complexity: See :ref:`usage/faster-lambdify:Specifying complexity`
and :doc:`compwa-org:report/002`.
Example:
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_function
>>> x, y = sp.symbols("x y")
>>> expression = x**2 + y**2
>>> function = create_function(expression, backend="jax")
>>> array = np.linspace(0, 3, num=4)
>>> data = {"x": array, "y": array}
>>> function(data)
DeviceArray([ 0., 2., 8., 18.], dtype=float64)
"""
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
lambdified_function = _lambdify_normal_or_fast(
Expand All @@ -61,9 +88,40 @@ def create_parametrized_function(
expression: "sp.Expr",
parameters: Mapping["sp.Symbol", ParameterValue],
backend: str,
max_complexity: Optional[int] = None,
use_cse: bool = True,
max_complexity: Optional[int] = None,
) -> ParametrizedBackendFunction:
"""Convert a SymPy expression to a parametrized function.
This is an extended version of :func:`create_function`, which allows one to
identify certain symbols in the expression as parameters.
Args:
expression: See :func:`create_function`.
parameters: The symbols in the expression that are be identified as
`~.ParametrizedFunction.parameters` in the returned
`.ParametrizedBackendFunction`.
backend: See :func:`create_function`.
use_cse: See :func:`create_function`.
max_complexity: See :func:`create_function`.
Example:
>>> import numpy as np
>>> import sympy as sp
>>> from tensorwaves.function.sympy import create_parametrized_function
>>> a, b, x, y = sp.symbols("a b x y")
>>> expression = a * x**2 + b * y**2
>>> function = create_parametrized_function(
... expression,
... parameters={a: -1, b: 2.5},
... backend="jax",
... )
>>> array = np.linspace(0, 1, num=5)
>>> data = {"x": array, "y": array}
>>> function.update_parameters({"b": 1})
>>> function(data)
DeviceArray([0., 0., 0., 0., 0.], dtype=float64)
"""
free_symbols = _get_free_symbols(expression)
sorted_symbols = sorted(free_symbols, key=lambda s: s.name)
lambdified_function = _lambdify_normal_or_fast(
Expand Down
14 changes: 11 additions & 3 deletions src/tensorwaves/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class Function(ABC, Generic[InputType, OutputType]):
`.OutputType` values (co-domain) for a given set of `.InputType` values
(domain). Examples of `Function` are `ParametrizedFunction`, `Estimator`
and `DataTransformer`.
.. automethod:: __call__
"""

@abstractmethod
Expand All @@ -55,9 +57,11 @@ class ParametrizedFunction(Function[DataSample, np.ndarray]):
A `ParametrizedFunction` identifies certain variables in a mathematical
expression as **parameters**. Remaining variables are considered **domain
variables**. Domain variables are the argument of the evaluation (see
:func:`~Function.__call__`), while the parameters are controlled via
:attr:`parameters` (getter) and :meth:`update_parameters` (setter). This
mechanism is especially important for an `Estimator`.
:func:`~ParametrizedFunction.__call__`), while the parameters are
controlled via :attr:`parameters` (getter) and :meth:`update_parameters`
(setter). This mechanism is especially important for an `Estimator`.
.. automethod:: __call__
"""

@property
Expand Down Expand Up @@ -85,6 +89,8 @@ class Estimator(Function[Mapping[str, ParameterValue], float]):
See the :mod:`.estimator` module for different implementations of this
interface.
.. automethod:: __call__
"""

def __call__(self, parameters: Mapping[str, ParameterValue]) -> float:
Expand Down Expand Up @@ -210,6 +216,8 @@ class RealNumberGenerator(ABC):
"""Abstract class for generating real numbers within a certain range.
Implementations can be found in the `tensorwaves.data` module.
.. automethod:: __call__
"""

@abstractmethod
Expand Down

0 comments on commit ddaf485

Please sign in to comment.