diff --git a/CHANGELOG.md b/CHANGELOG.md index 9e521ca..7f4d421 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -23,11 +23,16 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): ## [Unreleased] + + +## [1.0.2] - 2024-04-05 + * Added e3x.ops.normalize_and_return_norm * Added option to include mapping as weighting in mapped functions * Added option to return vector norms to e3x.nn.basis * Added e3x.nn.ExponentialBasis (wraps e3x.nn.basis, injecting learnable gamma) * Added e3x.ops.inverse_softplus helper function +* Added e3x.so3.delley_quadrature for computing Delley quadratures of S2 ## [1.0.1] - 2024-01-17 @@ -37,6 +42,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`): * Initial release -[Unreleased]: https://github.com/google-research/e3x/compare/v1.0.1...HEAD +[Unreleased]: https://github.com/google-research/e3x/compare/v1.0.2...HEAD +[1.0.2]: https://github.com/google-research/e3x/releases/tag/v1.0.2 [1.0.1]: https://github.com/google-research/e3x/releases/tag/v1.0.1 [1.0.0]: https://github.com/google-research/e3x/releases/tag/v1.0.0 diff --git a/e3x/so3/__init__.py b/e3x/so3/__init__.py index c9f8384..63d336a 100755 --- a/e3x/so3/__init__.py +++ b/e3x/so3/__init__.py @@ -23,6 +23,7 @@ from .irreps import solid_harmonics from .irreps import spherical_harmonics from .irreps import tensor_to_irreps +from .quadrature import delley_quadrature from .quadrature import lebedev_quadrature from .rotations import alignment_rotation from .rotations import euler_angles_from_rotation diff --git a/e3x/so3/_delley_grids.npz b/e3x/so3/_delley_grids.npz new file mode 100644 index 0000000..01c72cb Binary files /dev/null and b/e3x/so3/_delley_grids.npz differ diff --git a/e3x/so3/quadrature.py b/e3x/so3/quadrature.py index fbb9b63..df0ec3a 100644 --- a/e3x/so3/quadrature.py +++ b/e3x/so3/quadrature.py @@ -16,7 +16,7 @@ import io import pkgutil -from typing import Optional +from typing import Literal, Optional import zipfile from absl import logging import jax.numpy as jnp @@ -27,60 +27,37 @@ Float = jaxtyping.Float -def lebedev_quadrature( - precision: Optional[int] = None, num: Optional[int] = None +def _load_grid( + kind: Literal['Lebedev', 'Delley'], + precision: Optional[int] = None, + num: Optional[int] = None, ) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]: - r"""Returns a Lebedev quadrature grid. - - A Lebedev quadrature is a numerical approximation to the surface integral of - a function :math:`f` over the unit sphere - - .. math:: - \int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i) - - where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights, - respectively. A Lebedev rule of precision :math:`p` can be used to correctly - integrate any polynomial for which the highest degree term :math:`x^iy^jz^k` - satisfies :math:`i+j+k \leq p`. - - Args: - precision: The desired minimum precision :math:`p`. The returned Lebedev - rule will be the smallest grid that has the desired precision (if this - value is not ``None``, ``num`` must be set to ``None``). - num: The desired maximum number of points :math:`N`. The returned Lebedev - rule will be the highest-precision grid that has not more than ``num`` - points (if this value is not ``None``, ``precision`` must be set to - ``None``). - - Returns: - A tuple of two arrays representing the grid points :math:`\vec{r}_i` and - weights :math:`w_i`. - - Raises: - ValueError: - If both ``precision`` and ``num`` are ``None`` (or both are specified at - the same time). - RuntimeError: - If the Lebedev quadrature rules cannot be loaded from disk. - """ + """Loads a quadrature grid from disk.""" + if kind not in ('Lebedev', 'Delley'): + raise ValueError( + f"Quadrature grid with {kind=} does not exist, choose from ('Lebedev'," + " 'Delley')." + ) if (precision is None) == (num is None): raise ValueError( f'Exactly one of {precision=} or {num=} must be specified.' ) try: - f = io.BytesIO(pkgutil.get_data(__name__, '_lebedev_grids.npz')) - with np.load(f) as lebedev_grids: + f = io.BytesIO(pkgutil.get_data(__name__, f'_{kind.lower()}_grids.npz')) + with np.load(f) as quadrature_grids: if precision is not None: - available_precisions = lebedev_grids['precision'] + available_precisions = quadrature_grids['precision'] if precision > np.max(available_precisions): i = available_precisions.size - 1 logging.warning( ( - 'A Lebedev rule with precision=%s is not available, returning' - ' Lebedev rule with highest available precision=%s instead.' + 'A %s rule with precision=%s is not available, returning' + ' %s rule with highest available precision=%s instead.' ), + kind, precision, + kind, available_precisions[i], ) else: @@ -89,27 +66,111 @@ def lebedev_quadrature( np.where(precision_difference < 0, np.nan, precision_difference) ) else: # num is not None. - available_nums = lebedev_grids['num'] + available_nums = quadrature_grids['num'] if num < np.min(available_nums): i = 0 logging.warning( ( - 'A Lebedev rule with num=%s points is not available,' - ' returning Lebedev rule with the lowest available number of' + 'A %s rule with num=%s points is not available,' + ' returning %s rule with the lowest available number of' ' points num=%s instead.' ), + kind, num, + kind, available_nums[i], ) else: num_difference = max(num, 0) - available_nums i = np.nanargmin(np.where(num_difference < 0, np.nan, num_difference)) - r, w = lebedev_grids[f'r{i}'], lebedev_grids[f'w{i}'] + return ( + jnp.asarray(quadrature_grids[f'r{i}']), + jnp.asarray(quadrature_grids[f'w{i}']), + ) except (zipfile.BadZipFile, OSError, IOError, KeyError, ValueError) as exc: raise RuntimeError( - f'failed to load Lebedev quadrature grids included with the {__name__}' + f'failed to load {kind} quadrature grids included with the {__name__}' ' package (data may be corrupted), consider re-installing the package' ' to fix this problem' ) from exc - return jnp.asarray(r), jnp.asarray(w) + +def lebedev_quadrature( + precision: Optional[int] = None, num: Optional[int] = None +) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]: + r"""Returns a Lebedev quadrature grid. + + A Lebedev quadrature is a numerical approximation to the surface integral of + a function :math:`f` over the unit sphere + + .. math:: + \int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i) + + where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights, + respectively. A Lebedev rule of precision :math:`p` can be used to correctly + integrate any polynomial for which the highest degree term :math:`x^iy^jz^k` + satisfies :math:`i+j+k \leq p`. + + Args: + precision: The desired minimum precision :math:`p`. The returned Lebedev + rule will be the smallest grid that has the desired precision (if this + value is not ``None``, ``num`` must be set to ``None``). + num: The desired maximum number of points :math:`N`. The returned Lebedev + rule will be the highest-precision grid that has not more than ``num`` + points (if this value is not ``None``, ``precision`` must be set to + ``None``). + + Returns: + A tuple of two arrays representing the grid points :math:`\vec{r}_i` and + weights :math:`w_i`. + + Raises: + ValueError: + If both ``precision`` and ``num`` are ``None`` (or both are specified at + the same time). + RuntimeError: + If the Lebedev quadrature rules cannot be loaded from disk. + """ + return _load_grid(kind='Lebedev', precision=precision, num=num) + + +def delley_quadrature( + precision: Optional[int] = None, num: Optional[int] = None +) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]: + r"""Returns a Delley quadrature grid. + + A Delley quadrature is a numerical approximation to the surface integral of + a function :math:`f` over the unit sphere + + .. math:: + \int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i) + + where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights, + respectively. A Delley rule of precision :math:`p` can be used to correctly + integrate any polynomial for which the highest degree term :math:`x^iy^jz^k` + satisfies :math:`i+j+k \leq p`. The Delley rules are an optimized version of + the Lebedev rules with improved numerical precision, for details, see + Delley, Bernard. "High order integration schemes on the unit sphere." + Journal of Computational Chemistry 17.9 (1996): 1152-1155. + + Args: + precision: The desired minimum precision :math:`p`. The returned Delley rule + will be the smallest grid that has the desired precision (if this value is + not ``None``, ``num`` must be set to ``None``). + num: The desired maximum number of points :math:`N`. The returned Delley + rule will be the highest-precision grid that has not more than ``num`` + points (if this value is not ``None``, ``precision`` must be set to + ``None``). + + Returns: + A tuple of two arrays representing the grid points :math:`\vec{r}_i` and + weights :math:`w_i`. + + Raises: + ValueError: + If both ``precision`` and ``num`` are ``None`` (or both are specified at + the same time). + RuntimeError: + If the Delley quadrature rules cannot be loaded from disk. + """ + return _load_grid(kind='Delley', precision=precision, num=num) diff --git a/e3x/version.py b/e3x/version.py index 82d3d0b..ec17dc2 100644 --- a/e3x/version.py +++ b/e3x/version.py @@ -15,5 +15,5 @@ """e3x version.""" -__version__ = '1.0.1' +__version__ = '1.0.2' diff --git a/tests/so3/quadrature_test.py b/tests/so3/quadrature_test.py index 26d8e3a..a668d4e 100644 --- a/tests/so3/quadrature_test.py +++ b/tests/so3/quadrature_test.py @@ -12,11 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Callable, Optional import e3x import jax.numpy as jnp +import jaxtyping import pytest +Array = jaxtyping.Array +Float = jaxtyping.Float + # For reference of available Lebedev rules, see: # https://people.sc.fsu.edu/~jburkardt/datasets/sphere_lebedev_rule/sphere_lebedev_rule.html @@ -49,19 +53,35 @@ (None, 10000, 5810), ], ) -def test_lebedev_quadrature_loading( +def test__load_grid( precision: Optional[int], num: Optional[int], expected_size: int ) -> None: - _, w = e3x.so3.lebedev_quadrature(precision=precision, num=num) + _, w = e3x.so3.quadrature._load_grid( + kind='Lebedev', precision=precision, num=num + ) assert w.size == expected_size +def test__load_grid_raises_with_invalid_kind() -> None: + with pytest.raises(ValueError, match="kind='Foo' does not exist"): + e3x.so3.quadrature._load_grid(kind='Foo') # type: ignore + + +@pytest.mark.parametrize( + 'quadrature', [e3x.so3.lebedev_quadrature, e3x.so3.delley_quadrature] +) @pytest.mark.parametrize('precision', list(range(0, 132))) -def test_lebedev_quadrature(precision: int) -> None: +def test_quadrature( + quadrature: Callable[ + [Optional[int], Optional[int]], + tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']], + ], + precision: int, +) -> None: # The spherical harmonics are orthonormalized, so we can test whether the # quadrature rules work as expected by checking their orthonormalization. - r, w = e3x.so3.lebedev_quadrature(precision=precision) + r, w = quadrature(precision, None) ylm = e3x.so3.spherical_harmonics( r, max_degree=min(precision // 2, 15), @@ -72,6 +92,9 @@ def test_lebedev_quadrature(precision: int) -> None: assert jnp.allclose(eye, jnp.eye(ylm.shape[-1]), atol=1e-5) +@pytest.mark.parametrize( + 'quadrature', [e3x.so3.lebedev_quadrature, e3x.so3.delley_quadrature] +) @pytest.mark.parametrize( 'precision, num, message', [ @@ -87,8 +110,14 @@ def test_lebedev_quadrature(precision: int) -> None: ), ], ) -def test_lebedev_quadrature_raises_with_invalid_inputs( - precision: Optional[int], num: Optional[int], message: str +def test_quadrature_raises_with_invalid_inputs( + quadrature: Callable[ + [Optional[int], Optional[int]], + tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']], + ], + precision: Optional[int], + num: Optional[int], + message: str, ) -> None: with pytest.raises(ValueError, match=message): - e3x.so3.lebedev_quadrature(precision=precision, num=num) + quadrature(precision, num)