Skip to content

Commit

Permalink
Add Delley quadrature
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 622114438
  • Loading branch information
OUnke authored and The e3x Authors committed Apr 5, 2024
1 parent 96df74f commit ab86199
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 57 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,16 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

## [Unreleased]



## [1.0.2] - 2024-04-05

* Added e3x.ops.normalize_and_return_norm
* Added option to include mapping as weighting in mapped functions
* Added option to return vector norms to e3x.nn.basis
* Added e3x.nn.ExponentialBasis (wraps e3x.nn.basis, injecting learnable gamma)
* Added e3x.ops.inverse_softplus helper function
* Added e3x.so3.delley_quadrature for computing Delley quadratures of S2

## [1.0.1] - 2024-01-17

Expand All @@ -37,6 +42,7 @@ To release a new version (e.g. from `1.0.0` -> `2.0.0`):

* Initial release

[Unreleased]: https://github.com/google-research/e3x/compare/v1.0.1...HEAD
[Unreleased]: https://github.com/google-research/e3x/compare/v1.0.2...HEAD
[1.0.2]: https://github.com/google-research/e3x/releases/tag/v1.0.2
[1.0.1]: https://github.com/google-research/e3x/releases/tag/v1.0.1
[1.0.0]: https://github.com/google-research/e3x/releases/tag/v1.0.0
1 change: 1 addition & 0 deletions e3x/so3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .irreps import solid_harmonics
from .irreps import spherical_harmonics
from .irreps import tensor_to_irreps
from .quadrature import delley_quadrature
from .quadrature import lebedev_quadrature
from .rotations import alignment_rotation
from .rotations import euler_angles_from_rotation
Expand Down
Binary file added e3x/so3/_delley_grids.npz
Binary file not shown.
155 changes: 108 additions & 47 deletions e3x/so3/quadrature.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import io
import pkgutil
from typing import Optional
from typing import Literal, Optional
import zipfile
from absl import logging
import jax.numpy as jnp
Expand All @@ -27,60 +27,37 @@
Float = jaxtyping.Float


def lebedev_quadrature(
precision: Optional[int] = None, num: Optional[int] = None
def _load_grid(
kind: Literal['Lebedev', 'Delley'],
precision: Optional[int] = None,
num: Optional[int] = None,
) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]:
r"""Returns a Lebedev quadrature grid.
A Lebedev quadrature is a numerical approximation to the surface integral of
a function :math:`f` over the unit sphere
.. math::
\int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i)
where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights,
respectively. A Lebedev rule of precision :math:`p` can be used to correctly
integrate any polynomial for which the highest degree term :math:`x^iy^jz^k`
satisfies :math:`i+j+k \leq p`.
Args:
precision: The desired minimum precision :math:`p`. The returned Lebedev
rule will be the smallest grid that has the desired precision (if this
value is not ``None``, ``num`` must be set to ``None``).
num: The desired maximum number of points :math:`N`. The returned Lebedev
rule will be the highest-precision grid that has not more than ``num``
points (if this value is not ``None``, ``precision`` must be set to
``None``).
Returns:
A tuple of two arrays representing the grid points :math:`\vec{r}_i` and
weights :math:`w_i`.
Raises:
ValueError:
If both ``precision`` and ``num`` are ``None`` (or both are specified at
the same time).
RuntimeError:
If the Lebedev quadrature rules cannot be loaded from disk.
"""
"""Loads a quadrature grid from disk."""
if kind not in ('Lebedev', 'Delley'):
raise ValueError(
f"Quadrature grid with {kind=} does not exist, choose from ('Lebedev',"
" 'Delley')."
)
if (precision is None) == (num is None):
raise ValueError(
f'Exactly one of {precision=} or {num=} must be specified.'
)

try:
f = io.BytesIO(pkgutil.get_data(__name__, '_lebedev_grids.npz'))
with np.load(f) as lebedev_grids:
f = io.BytesIO(pkgutil.get_data(__name__, f'_{kind.lower()}_grids.npz'))
with np.load(f) as quadrature_grids:
if precision is not None:
available_precisions = lebedev_grids['precision']
available_precisions = quadrature_grids['precision']
if precision > np.max(available_precisions):
i = available_precisions.size - 1
logging.warning(
(
'A Lebedev rule with precision=%s is not available, returning'
' Lebedev rule with highest available precision=%s instead.'
'A %s rule with precision=%s is not available, returning'
' %s rule with highest available precision=%s instead.'
),
kind,
precision,
kind,
available_precisions[i],
)
else:
Expand All @@ -89,27 +66,111 @@ def lebedev_quadrature(
np.where(precision_difference < 0, np.nan, precision_difference)
)
else: # num is not None.
available_nums = lebedev_grids['num']
available_nums = quadrature_grids['num']
if num < np.min(available_nums):
i = 0
logging.warning(
(
'A Lebedev rule with num=%s points is not available,'
' returning Lebedev rule with the lowest available number of'
'A %s rule with num=%s points is not available,'
' returning %s rule with the lowest available number of'
' points num=%s instead.'
),
kind,
num,
kind,
available_nums[i],
)
else:
num_difference = max(num, 0) - available_nums
i = np.nanargmin(np.where(num_difference < 0, np.nan, num_difference))
r, w = lebedev_grids[f'r{i}'], lebedev_grids[f'w{i}']
return (
jnp.asarray(quadrature_grids[f'r{i}']),
jnp.asarray(quadrature_grids[f'w{i}']),
)
except (zipfile.BadZipFile, OSError, IOError, KeyError, ValueError) as exc:
raise RuntimeError(
f'failed to load Lebedev quadrature grids included with the {__name__}'
f'failed to load {kind} quadrature grids included with the {__name__}'
' package (data may be corrupted), consider re-installing the package'
' to fix this problem'
) from exc

return jnp.asarray(r), jnp.asarray(w)

def lebedev_quadrature(
precision: Optional[int] = None, num: Optional[int] = None
) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]:
r"""Returns a Lebedev quadrature grid.
A Lebedev quadrature is a numerical approximation to the surface integral of
a function :math:`f` over the unit sphere
.. math::
\int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i)
where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights,
respectively. A Lebedev rule of precision :math:`p` can be used to correctly
integrate any polynomial for which the highest degree term :math:`x^iy^jz^k`
satisfies :math:`i+j+k \leq p`.
Args:
precision: The desired minimum precision :math:`p`. The returned Lebedev
rule will be the smallest grid that has the desired precision (if this
value is not ``None``, ``num`` must be set to ``None``).
num: The desired maximum number of points :math:`N`. The returned Lebedev
rule will be the highest-precision grid that has not more than ``num``
points (if this value is not ``None``, ``precision`` must be set to
``None``).
Returns:
A tuple of two arrays representing the grid points :math:`\vec{r}_i` and
weights :math:`w_i`.
Raises:
ValueError:
If both ``precision`` and ``num`` are ``None`` (or both are specified at
the same time).
RuntimeError:
If the Lebedev quadrature rules cannot be loaded from disk.
"""
return _load_grid(kind='Lebedev', precision=precision, num=num)


def delley_quadrature(
precision: Optional[int] = None, num: Optional[int] = None
) -> tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']]:
r"""Returns a Delley quadrature grid.
A Delley quadrature is a numerical approximation to the surface integral of
a function :math:`f` over the unit sphere
.. math::
\int_{\Omega} f(\Omega) d\Omega \approx 4\pi \sum_{i=1}^N w_i f(\vec{r}_i)
where :math:`\vec{r}_i` and :math:`w_i` are grid points and weights,
respectively. A Delley rule of precision :math:`p` can be used to correctly
integrate any polynomial for which the highest degree term :math:`x^iy^jz^k`
satisfies :math:`i+j+k \leq p`. The Delley rules are an optimized version of
the Lebedev rules with improved numerical precision, for details, see
Delley, Bernard. "High order integration schemes on the unit sphere."
Journal of Computational Chemistry 17.9 (1996): 1152-1155.
Args:
precision: The desired minimum precision :math:`p`. The returned Delley rule
will be the smallest grid that has the desired precision (if this value is
not ``None``, ``num`` must be set to ``None``).
num: The desired maximum number of points :math:`N`. The returned Delley
rule will be the highest-precision grid that has not more than ``num``
points (if this value is not ``None``, ``precision`` must be set to
``None``).
Returns:
A tuple of two arrays representing the grid points :math:`\vec{r}_i` and
weights :math:`w_i`.
Raises:
ValueError:
If both ``precision`` and ``num`` are ``None`` (or both are specified at
the same time).
RuntimeError:
If the Delley quadrature rules cannot be loaded from disk.
"""
return _load_grid(kind='Delley', precision=precision, num=num)
2 changes: 1 addition & 1 deletion e3x/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,5 @@
"""e3x version."""


__version__ = '1.0.1'
__version__ = '1.0.2'

45 changes: 37 additions & 8 deletions tests/so3/quadrature_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Callable, Optional
import e3x
import jax.numpy as jnp
import jaxtyping
import pytest

Array = jaxtyping.Array
Float = jaxtyping.Float


# For reference of available Lebedev rules, see:
# https://people.sc.fsu.edu/~jburkardt/datasets/sphere_lebedev_rule/sphere_lebedev_rule.html
Expand Down Expand Up @@ -49,19 +53,35 @@
(None, 10000, 5810),
],
)
def test_lebedev_quadrature_loading(
def test__load_grid(
precision: Optional[int], num: Optional[int], expected_size: int
) -> None:
_, w = e3x.so3.lebedev_quadrature(precision=precision, num=num)
_, w = e3x.so3.quadrature._load_grid(
kind='Lebedev', precision=precision, num=num
)
assert w.size == expected_size


def test__load_grid_raises_with_invalid_kind() -> None:
with pytest.raises(ValueError, match="kind='Foo' does not exist"):
e3x.so3.quadrature._load_grid(kind='Foo') # type: ignore


@pytest.mark.parametrize(
'quadrature', [e3x.so3.lebedev_quadrature, e3x.so3.delley_quadrature]
)
@pytest.mark.parametrize('precision', list(range(0, 132)))
def test_lebedev_quadrature(precision: int) -> None:
def test_quadrature(
quadrature: Callable[
[Optional[int], Optional[int]],
tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']],
],
precision: int,
) -> None:
# The spherical harmonics are orthonormalized, so we can test whether the
# quadrature rules work as expected by checking their orthonormalization.

r, w = e3x.so3.lebedev_quadrature(precision=precision)
r, w = quadrature(precision, None)
ylm = e3x.so3.spherical_harmonics(
r,
max_degree=min(precision // 2, 15),
Expand All @@ -72,6 +92,9 @@ def test_lebedev_quadrature(precision: int) -> None:
assert jnp.allclose(eye, jnp.eye(ylm.shape[-1]), atol=1e-5)


@pytest.mark.parametrize(
'quadrature', [e3x.so3.lebedev_quadrature, e3x.so3.delley_quadrature]
)
@pytest.mark.parametrize(
'precision, num, message',
[
Expand All @@ -87,8 +110,14 @@ def test_lebedev_quadrature(precision: int) -> None:
),
],
)
def test_lebedev_quadrature_raises_with_invalid_inputs(
precision: Optional[int], num: Optional[int], message: str
def test_quadrature_raises_with_invalid_inputs(
quadrature: Callable[
[Optional[int], Optional[int]],
tuple[Float[Array, 'num_points 3'], Float[Array, 'num_points']],
],
precision: Optional[int],
num: Optional[int],
message: str,
) -> None:
with pytest.raises(ValueError, match=message):
e3x.so3.lebedev_quadrature(precision=precision, num=num)
quadrature(precision, num)

0 comments on commit ab86199

Please sign in to comment.