Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support array API #3922

Merged
merged 3 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions deepmd/dpmodel/array_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Utilities for the array API."""


def support_array_api(version: str) -> callable:
"""Mark a function as supporting the specific version of the array API.

Parameters
----------
version : str
The version of the array API

Returns
-------
callable
The decorated function

Examples
--------
>>> @support_array_api(version="2022.12")
... def f(x):
... pass
"""

def set_version(func: callable) -> callable:
func.array_api_version = version
return func

return set_version
12 changes: 10 additions & 2 deletions deepmd/dpmodel/utils/env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
Union,
)

import array_api_compat
import numpy as np

from deepmd.dpmodel import (
NativeOP,
)
from deepmd.dpmodel.array_api import (
support_array_api,
)


@support_array_api(version="2022.12")
def compute_smooth_weight(
distance: np.ndarray,
rmin: float,
Expand All @@ -19,12 +24,15 @@ def compute_smooth_weight(
"""Compute smooth weight for descriptor elements."""
if rmin >= rmax:
raise ValueError("rmin should be less than rmax.")
xp = array_api_compat.array_namespace(distance)
min_mask = distance <= rmin
max_mask = distance >= rmax
mid_mask = np.logical_not(np.logical_or(min_mask, max_mask))
mid_mask = xp.logical_not(xp.logical_or(min_mask, max_mask))
uu = (distance - rmin) / (rmax - rmin)
vv = uu * uu * uu * (-6.0 * uu * uu + 15.0 * uu - 10.0) + 1.0
return vv * mid_mask + min_mask
return vv * xp.astype(mid_mask, distance.dtype) + xp.astype(
min_mask, distance.dtype
)


def _make_env_mat(
Expand Down
2 changes: 2 additions & 0 deletions doc/backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ As a reference backend, it is not aimed at the best performance, but only the co
The DP backend uses [HDF5](https://docs.h5py.org/) to store model serialization data, which is backend-independent.
Only Python inference interface can load this format.

NumPy 1.21 or above is required.

## Switch the backend

### Training
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ dependencies = [
'packaging',
'ml_dtypes',
'mendeleev',
'array-api-compat',
]
requires-python = ">=3.8"
keywords = ["deepmd"]
Expand Down Expand Up @@ -79,6 +80,7 @@ test = [
"pytest-sugar",
"pytest-split",
"dpgui",
'array-api-strict>=2;python_version>="3.9"',
]
docs = [
"sphinx>=3.1.1",
Expand Down
2 changes: 2 additions & 0 deletions source/tests/common/dpmodel/array_api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
"""Test array API compatibility to be completely sure their usage of the array API is portable."""
30 changes: 30 additions & 0 deletions source/tests/common/dpmodel/array_api/test_env_mat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import unittest

if sys.version_info >= (3, 9):
import array_api_strict as xp
else:
raise unittest.SkipTest("array_api_strict doesn't support Python<=3.8")

from deepmd.dpmodel.utils.env_mat import (
compute_smooth_weight,
)

from .utils import (
ArrayAPITest,
)


class TestEnvMat(unittest.TestCase, ArrayAPITest):
def test_compute_smooth_weight(self):
self.set_array_api_version(compute_smooth_weight)
d = xp.arange(10, dtype=xp.float64)
w = compute_smooth_weight(
d,
4.0,
6.0,
)
self.assert_namespace_equal(w, d)
self.assert_device_equal(w, d)
self.assert_dtype_equal(w, d)
27 changes: 27 additions & 0 deletions source/tests/common/dpmodel/array_api/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import array_api_compat
from array_api_strict import (
set_array_api_strict_flags,
)


class ArrayAPITest:
"""Utils for array API tests."""

def set_array_api_version(self, func):
"""Set the array API version for a function."""
set_array_api_strict_flags(api_version=func.array_api_version)

def assert_namespace_equal(self, a, b):
"""Assert two array has the same namespace."""
self.assertEqual(
array_api_compat.array_namespace(a), array_api_compat.array_namespace(b)
)

def assert_dtype_equal(self, a, b):
"""Assert two array has the same dtype."""
self.assertEqual(a.dtype, b.dtype)

def assert_device_equal(self, a, b):
"""Assert two array has the same device."""
self.assertEqual(array_api_compat.device(a), array_api_compat.device(b))