Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(dpmodel): fix precision #4343

Merged
merged 15 commits into from
Nov 14, 2024
1 change: 1 addition & 0 deletions .github/workflows/test_python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
env:
NUM_WORKERS: 0
DP_TEST_TF2_ONLY: 1
DP_DTYPE_PROMOTION_STRICT: 1
if: matrix.group == 1
- run: mv .test_durations .test_durations_${{ matrix.group }}
- name: Upload partial durations
Expand Down
13 changes: 7 additions & 6 deletions deepmd/dpmodel/atomic_model/base_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,18 @@ def forward_common_atomic(
ret_dict = self.apply_out_stat(ret_dict, atype)

# nf x nloc
atom_mask = ext_atom_mask[:, :nloc].astype(xp.int32)
atom_mask = ext_atom_mask[:, :nloc]
if self.atom_excl is not None:
atom_mask *= self.atom_excl.build_type_exclude_mask(atype)
atom_mask = xp.logical_and(
atom_mask, self.atom_excl.build_type_exclude_mask(atype)
)

for kk in ret_dict.keys():
out_shape = ret_dict[kk].shape
out_shape2 = math.prod(out_shape[2:])
ret_dict[kk] = (
ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
* atom_mask[:, :, None]
).reshape(out_shape)
tmp_arr = ret_dict[kk].reshape([out_shape[0], out_shape[1], out_shape2])
tmp_arr = xp.where(atom_mask[:, :, None], tmp_arr, xp.zeros_like(tmp_arr))
ret_dict[kk] = xp.reshape(tmp_arr, out_shape)
ret_dict["mask"] = atom_mask

return ret_dict
Expand Down
104 changes: 104 additions & 0 deletions deepmd/dpmodel/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@
ABC,
abstractmethod,
)
from functools import (
wraps,
)
from typing import (
Any,
Callable,
Optional,
overload,
)

import array_api_compat
Expand Down Expand Up @@ -116,6 +121,105 @@
return np.from_dlpack(x)


def cast_precision(func: Callable[..., Any]) -> Callable[..., Any]:
    """Cast array inputs and outputs of an instance method between precisions.

    The decorator should be used on an instance method of an object that
    exposes a ``precision`` attribute.

    The decorator does the following:
    (1) It casts input arrays (positional and keyword) from the global
        precision to the precision defined by the attribute ``precision``.
    (2) It casts output arrays from ``precision`` back to the global
        precision. A ``tuple`` or ``dict`` result is cast element by
        element; any other result is treated as a single value.
    (3) It only casts an input or output when it is an array whose dtype
        matches the global precision or ``precision``, respectively.
        Anything else (e.g. an integer array or ``None``) is passed
        through untouched.

    The decorator supports the array API.

    Parameters
    ----------
    func : Callable
        The instance method to wrap. Its first argument must be ``self``.

    Returns
    -------
    Callable
        The wrapped method that casts and casts back its input and
        output arrays.

    Examples
    --------
    >>> class A:
    ...     def __init__(self):
    ...         self.precision = "float32"
    ...
    ...     @cast_precision
    ...     def f(self, x: Array, y: Array) -> Array:
    ...         return x**2 + y
    """

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        # Only arrays in the global precision are converted; every other
        # argument is forwarded as is.
        returned_tensor = func(
            self,
            *[safe_cast_array(vv, "global", self.precision) for vv in args],
            **{
                kk: safe_cast_array(vv, "global", self.precision)
                for kk, vv in kwargs.items()
            },
        )
        if isinstance(returned_tensor, tuple):
            # The result is rebuilt as a plain tuple, element by element.
            return tuple(
                safe_cast_array(vv, self.precision, "global") for vv in returned_tensor
            )
        elif isinstance(returned_tensor, dict):
            return {
                kk: safe_cast_array(vv, self.precision, "global")
                for kk, vv in returned_tensor.items()
            }
        else:
            return safe_cast_array(returned_tensor, self.precision, "global")

    return wrapper


# Typing-only overload: an array input is annotated as returning an array.
@overload
def safe_cast_array(
    input: np.ndarray, from_precision: str, to_precision: str
) -> np.ndarray: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
# Typing-only overload: a ``None`` input is annotated as returning ``None``.
@overload
def safe_cast_array(
    input: None,
    from_precision: str,
    to_precision: str,
) -> None: ...

Check notice

Code scanning / CodeQL

Statement has no effect Note

This statement has no effect.
def safe_cast_array(
    input: Optional[np.ndarray], from_precision: str, to_precision: str
) -> Optional[np.ndarray]:
    """Cast an array from one precision to another, if it qualifies.

    The input is returned unchanged when it is not an array API object, or
    when its dtype is not the one named by ``from_precision``.

    Array API is supported.

    Parameters
    ----------
    input : np.ndarray or None
        Input array
    from_precision : str
        Name of the precision the array is cast from
    to_precision : str
        Name of the precision the array is cast to

    Returns
    -------
    np.ndarray or None
        The cast array, or the original input when no cast applies
    """
    # Non-array values (e.g. None) are passed through untouched.
    if not array_api_compat.is_array_api_obj(input):
        return input
    xp = array_api_compat.array_namespace(input)
    source_dtype = get_xp_precision(xp, from_precision)
    if input.dtype == source_dtype:
        target_dtype = get_xp_precision(xp, to_precision)
        return xp.astype(input, target_dtype)
    # The dtype does not match the source precision: leave the array as is.
    return input
njzjz marked this conversation as resolved.
Show resolved Hide resolved


__all__ = [
"GLOBAL_NP_FLOAT_PRECISION",
"GLOBAL_ENER_FLOAT_PRECISION",
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -329,6 +330,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -448,6 +450,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down
3 changes: 3 additions & 0 deletions deepmd/dpmodel/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -594,6 +595,7 @@ def init_subclass_params(sub_data, sub_class):
self.rcut = self.repinit.get_rcut()
self.ntypes = ntypes
self.sel = self.repinit.sel
self.precision = precision
njzjz marked this conversation as resolved.
Show resolved Hide resolved

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -757,6 +759,7 @@ def get_stat_mean_and_stddev(self) -> tuple[list[np.ndarray], list[np.ndarray]]:
stddev_list.append(self.repinit_three_body.stddev)
return mean_list, stddev_list

@cast_precision
def call(
self,
coord_ext: np.ndarray,
Expand Down
14 changes: 5 additions & 9 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand All @@ -29,9 +30,6 @@
from deepmd.dpmodel.utils.update_sel import (
UpdateSel,
)
from deepmd.env import (
GLOBAL_NP_FLOAT_PRECISION,
)
from deepmd.utils.data_system import (
DeepmdDataSystem,
)
Expand Down Expand Up @@ -340,6 +338,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -415,9 +414,7 @@ def call(
# nf x nloc x ng x ng1
grrg = np.einsum("flid,fljd->flij", gr, gr1)
# nf x nloc x (ng x ng1)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron).astype(
GLOBAL_NP_FLOAT_PRECISION
)
grrg = grrg.reshape(nf, nloc, ng * self.axis_neuron)
return grrg, gr[..., 1:], None, None, ww

def serialize(self) -> dict:
Expand Down Expand Up @@ -506,6 +503,7 @@ def update_sel(


class DescrptSeAArrayAPI(DescrptSeA):
@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -585,7 +583,5 @@ def call(
# grrg = xp.einsum("flid,fljd->flij", gr, gr1)
grrg = xp.sum(gr[:, :, :, None, :] * gr1[:, :, None, :, :], axis=4)
# nf x nloc x (ng x ng1)
grrg = xp.astype(
xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron)), input_dtype
)
grrg = xp.reshape(grrg, (nf, nloc, ng * self.axis_neuron))
return grrg, gr[..., 1:], None, None, ww
3 changes: 2 additions & 1 deletion deepmd/dpmodel/descriptor/se_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -289,6 +290,7 @@ def cal_g(
gg = self.embeddings[(ll,)].call(ss)
return gg

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -352,7 +354,6 @@ def call(
res_rescale = 1.0 / 5.0
res = xyz_scatter * res_rescale
res = xp.reshape(res, (nf, nloc, ng))
res = xp.astype(res, get_xp_precision(xp, "global"))
return res, None, None, None, ww

def serialize(self) -> dict:
Expand Down
4 changes: 2 additions & 2 deletions deepmd/dpmodel/descriptor/se_t.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
NativeOP,
)
from deepmd.dpmodel.common import (
cast_precision,
get_xp_precision,
to_numpy_array,
)
Expand Down Expand Up @@ -264,6 +265,7 @@ def reinit_exclude(
self.exclude_types = exclude_types
self.emask = PairExcludeMask(self.ntypes, exclude_types=exclude_types)

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -317,7 +319,6 @@ def call(
# we don't require atype is the same in all frames
exclude_mask = xp.reshape(exclude_mask, (nf * nloc, nnei))
rr = xp.reshape(rr, (nf * nloc, nnei, 4))
rr = xp.astype(rr, get_xp_precision(xp, self.precision))

for embedding_idx in itertools.product(
range(self.ntypes), repeat=self.embeddings.ndim
Expand Down Expand Up @@ -349,7 +350,6 @@ def call(
result += res_ij
# nf x nloc x ng
result = xp.reshape(result, (nf, nloc, ng))
result = xp.astype(result, get_xp_precision(xp, "global"))
return result, None, None, None, ww

def serialize(self) -> dict:
Expand Down
5 changes: 3 additions & 2 deletions deepmd/dpmodel/descriptor/se_t_tebd.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
xp_take_along_axis,
)
from deepmd.dpmodel.common import (
get_xp_precision,
cast_precision,
to_numpy_array,
)
from deepmd.dpmodel.utils import (
Expand Down Expand Up @@ -168,6 +168,7 @@ def __init__(
self.tebd_dim = tebd_dim
self.concat_output_tebd = concat_output_tebd
self.trainable = trainable
self.precision = precision

def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down Expand Up @@ -287,6 +288,7 @@ def change_type_map(
obj["davg"] = obj["davg"][remap_index]
obj["dstd"] = obj["dstd"][remap_index]

@cast_precision
def call(
self,
coord_ext,
Expand Down Expand Up @@ -741,7 +743,6 @@ def call(
res_ij = res_ij * (1.0 / float(self.nnei) / float(self.nnei))
# nf x nl x ng
result = xp.reshape(res_ij, (nf, nloc, self.filter_neuron[-1]))
result = xp.astype(result, get_xp_precision(xp, "global"))
return (
result,
None,
Expand Down
22 changes: 15 additions & 7 deletions deepmd/dpmodel/fitting/general_fitting.py
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,11 @@ def _call_common(

"""
xp = array_api_compat.array_namespace(descriptor, atype)
descriptor = xp.astype(descriptor, get_xp_precision(xp, self.precision))
njzjz marked this conversation as resolved.
Show resolved Hide resolved
if fparam is not None:
fparam = xp.astype(fparam, get_xp_precision(xp, self.precision))
if aparam is not None:
aparam = xp.astype(aparam, get_xp_precision(xp, self.precision))
nf, nloc, nd = descriptor.shape
net_dim_out = self._net_out_dim()
# check input dim
Expand Down Expand Up @@ -439,18 +444,21 @@ def _call_common(
):
assert xx_zeros is not None
atom_property -= self.nets[(type_i,)](xx_zeros)
atom_property = atom_property + self.bias_atom_e[type_i, ...]
atom_property = atom_property * xp.astype(mask, atom_property.dtype)
atom_property = xp.where(
mask, atom_property, xp.zeros_like(atom_property)
)
outs = outs + atom_property # Shape is [nframes, natoms[0], 1]
else:
outs = self.nets[()](xx) + xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
outs = self.nets[()](xx)
if xx_zeros is not None:
outs -= self.nets[()](xx_zeros)
outs = xp.astype(outs, get_xp_precision(xp, "global"))
outs += xp.reshape(
xp.take(self.bias_atom_e, xp.reshape(atype, [-1]), axis=0),
[nf, nloc, net_dim_out],
)
# nf x nloc
exclude_mask = self.emask.build_type_exclude_mask(atype)
# nf x nloc x nod
outs = outs * xp.astype(exclude_mask[:, :, None], outs.dtype)
outs = xp.where(exclude_mask[:, :, None], outs, xp.zeros_like(outs))
return {self.var_name: xp.astype(outs, get_xp_precision(xp, "global"))}
3 changes: 3 additions & 0 deletions deepmd/jax/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
jax.config.update("jax_enable_x64", True)
# jax.config.update("jax_debug_nans", True)

if os.environ.get("DP_DTYPE_PROMOTION_STRICT") == "1":
jax.config.update("jax_numpy_dtype_promotion", "strict")

__all__ = [
"jax",
"jnp",
Expand Down
Loading