feat: dp and pt: implement exclude types in descriptor se_a (#3280)
Co-authored-by: Han Wang <[email protected]>
wanghan-iapcm and Han Wang authored Feb 16, 2024
1 parent 15f8d25 commit 8f91aea
Showing 12 changed files with 275 additions and 20 deletions; diffs for 8 of them are reproduced below.
78 changes: 78 additions & 0 deletions deepmd/dpmodel/descriptor/exclude_mask.py
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Tuple,
)

import numpy as np


class ExcludeMask:
"""Computes the atom type exclusion mask."""

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.ntypes = ntypes
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
self.exclude_types.add((tt[0], tt[1]))
self.exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
self.type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in self.exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = self.type_mask.reshape([-1])

def build_type_exclude_mask(
self,
nlist: np.ndarray,
atype_ext: np.ndarray,
):
"""Compute type exclusion mask.
Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended atom types. shape: nf x nall
Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
Element [ff,ii,jj] is 0 if the pair (type(ii), type(nlist[ff,ii,jj])) is excluded,
otherwise 1.
"""
if len(self.exclude_types) == 0:
# safely return 1 if nothing is excluded.
return np.ones_like(nlist, dtype=np.int32)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = np.concatenate(
[atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
)
type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1)
# nf x nloc x nnei
index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = type_ij.reshape(nf, nloc * nnei)
mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
return mask
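
For illustration, a minimal self-contained NumPy sketch (toy values, not part of the commit) of the lookup ExcludeMask performs: the (ntypes+1) x (ntypes+1) pair table is flattened, padded (-1) neighbor entries are mapped to a virtual atom of type ntypes, and the mask is gathered with a single take_along_axis call.

import numpy as np

ntypes = 2
excluded = {(0, 1), (1, 0)}  # symmetrized exclusion set, as in ExcludeMask.__init__
# flattened (ntypes+1) x (ntypes+1) lookup table; the extra type stands for the
# virtual atom that padded (-1) neighbor-list entries are mapped to
type_mask = np.array(
    [
        [0 if (ti, tj) in excluded else 1 for ti in range(ntypes + 1)]
        for tj in range(ntypes + 1)
    ],
    dtype=np.int32,
).reshape(-1)

atype_ext = np.array([[0, 1, 0]])      # nf x nall: two local atoms plus one ghost
nlist = np.array([[[1, 2], [0, -1]]])  # nf x nloc x nnei, -1 marks padding
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]

ae = np.concatenate([atype_ext, np.full([nf, 1], ntypes)], axis=-1)
index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
type_i = atype_ext[:, :nloc] * (ntypes + 1)
type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
mask = type_mask[(type_i[:, :, None] + type_j).reshape(nf, -1)].reshape(nf, nloc, nnei)
assert (mask == np.array([[[0, 1], [0, 1]]])).all()  # 0-1 pairs excluded, padding kept
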
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
@@ -27,6 +27,9 @@
from .base_descriptor import (
BaseDescriptor,
)
from .exclude_mask import (
ExcludeMask,
)


class DescrptSeA(NativeOP, BaseDescriptor):
@@ -140,8 +143,6 @@ def __init__(
## seed, uniform_seed, multi_task, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
if exclude_types != []:
raise NotImplementedError("exclude_types is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")

@@ -159,6 +160,7 @@ def __init__(
self.activation_function = activation_function
self.precision = precision
self.spin = spin
self.emask = ExcludeMask(self.ntypes, self.exclude_types)

in_dim = 1 # not considering type embedding
self.embeddings = NetworkCollection(
@@ -292,8 +294,11 @@ def call(

ng = self.neuron[-1]
gr = np.zeros([nf, nloc, ng, 4])
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
ss = tr[..., 0:1]
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
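
The loop above zeroes whole rows of each per-type slice of the environment matrix before its first column is fed to the embedding net. A toy NumPy sketch of that broadcast (illustrative shapes only, not the deepmd API):

import numpy as np

nf, nloc, nt = 1, 2, 3
rr = np.arange(nf * nloc * nt * 4, dtype=float).reshape(nf, nloc, nt, 4)
mm = np.array([[[1, 0, 1], [0, 1, 1]]])  # exclusion mask slice, nf x nloc x nt
out = rr * mm[:, :, :, None]             # excluded neighbors contribute zero rows
assert (out[0, 0, 1] == 0.0).all() and (out[0, 1, 0] == 0.0).all()
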
80 changes: 79 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
@@ -8,6 +8,8 @@
Callable,
List,
Optional,
Set,
Tuple,
Union,
)

@@ -20,6 +22,9 @@
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)

from .base_descriptor import (
BaseDescriptor,
@@ -206,6 +211,32 @@ class DescriptorBlock(torch.nn.Module, ABC):
__plugins = Plugin()
local_cluster = False

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
_exclude_types: Set[Tuple[int, int]] = set()
for tt in exclude_types:
assert len(tt) == 2
_exclude_types.add((tt[0], tt[1]))
_exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
self.type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in _exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = to_torch_tensor(self.type_mask).view([-1])
self.no_exclusion = len(_exclude_types) == 0

@staticmethod
def register(key: str) -> Callable:
"""Register a DescriptorBlock plugin.
@@ -332,7 +363,54 @@ def forward(
mapping: Optional[torch.Tensor] = None,
):
"""Calculate DescriptorBlock."""
raise NotImplementedError
pass

# may have a better place for this method...
def build_type_exclude_mask(
self,
nlist: torch.Tensor,
atype_ext: torch.Tensor,
) -> torch.Tensor:
"""Compute type exclusion mask.
Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended atom types. shape: nf x nall
Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
Element [ff,ii,jj] is 0 if the pair (type(ii), type(nlist[ff,ii,jj])) is excluded,
otherwise 1.
"""
if self.no_exclusion:
# safely return 1 if nothing is excluded.
return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = torch.cat(
[
atype_ext,
self.get_ntypes()
* torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device),
],
dim=-1,
)
type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.get_ntypes() + 1)
# nf x nloc x nnei
index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei)
type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = type_ij.view(nf, nloc * nnei)
mask = self.type_mask[type_ij].view(nf, nloc, nnei)
return mask


def compute_std(sumv2, sumv, sumn, rcut_r):
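
The PyTorch DescriptorBlock.build_type_exclude_mask above mirrors the NumPy version, with torch.gather taking the place of np.take_along_axis. A small standalone torch sketch with toy values (not the deepmd API itself) reproduces the same result:

import torch

ntypes = 2
excluded = {(0, 1), (1, 0)}
type_mask = torch.tensor(
    [
        [0 if (ti, tj) in excluded else 1 for ti in range(ntypes + 1)]
        for tj in range(ntypes + 1)
    ],
    dtype=torch.int32,
).view(-1)

atype_ext = torch.tensor([[0, 1, 0]])      # nf x nall
nlist = torch.tensor([[[1, 2], [0, -1]]])  # nf x nloc x nnei, -1 marks padding
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]

ae = torch.cat([atype_ext, torch.full([nf, 1], ntypes, dtype=atype_ext.dtype)], dim=-1)
index = torch.where(nlist == -1, torch.full_like(nlist, nall), nlist).view(nf, nloc * nnei)
type_i = atype_ext[:, :nloc] * (ntypes + 1)
type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei)
mask = type_mask[(type_i[:, :, None] + type_j).view(nf, -1)].view(nf, nloc, nnei)
assert torch.equal(mask, torch.tensor([[[0, 1], [0, 1]]], dtype=torch.int32))
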
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/hybrid.py
@@ -32,7 +32,7 @@ def __init__(
- descriptor_list: list of descriptors.
- descriptor_param: descriptor configs.
"""
super().__init__()
super().__init__(ntypes)
supported_descrpt = ["se_atten", "se_uni"]
descriptor_list = []
for descriptor_param_item in list:
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
@@ -89,7 +89,7 @@ def __init__(
whether or not add an type embedding to seq_input.
If no seq_input is given, it has no effect.
"""
super().__init__()
super().__init__(ntypes)
del type
self.epsilon = 1e-4 # protection of 1./nnei
self.rcut = rcut
32 changes: 22 additions & 10 deletions deepmd/pt/model/descriptor/se_a.py
@@ -3,6 +3,7 @@
ClassVar,
List,
Optional,
Tuple,
)

import numpy as np
@@ -55,6 +56,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
**kwargs,
):
@@ -63,13 +65,14 @@
rcut,
rcut_smth,
sel,
neuron,
axis_neuron,
set_davg_zero,
activation_function,
precision,
resnet_dt,
old_impl,
neuron=neuron,
axis_neuron=axis_neuron,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
resnet_dt=resnet_dt,
exclude_types=exclude_types,
old_impl=old_impl,
**kwargs,
)

@@ -212,14 +215,14 @@ def serialize(self) -> dict:
"precision": RESERVED_PRECISON_DICT[obj.prec],
"embeddings": obj.filter_layers.serialize(),
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
"exclude_types": obj.exclude_types,
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
},
## to be updated when the options are supported.
"trainable": True,
"type_one_side": True,
"exclude_types": [],
"spin": None,
}

@@ -256,6 +259,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
**kwargs,
):
@@ -268,7 +272,7 @@
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__()
super().__init__(len(sel), exclude_types=exclude_types)
self.rcut = rcut
self.rcut_smth = rcut_smth
self.neuron = neuron
@@ -280,8 +284,9 @@
self.prec = PRECISION_DICT[self.precision]
self.resnet_dt = resnet_dt
self.old_impl = old_impl

self.exclude_types = exclude_types
self.ntypes = len(sel)

self.sel = sel
self.sec = torch.tensor(
np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE
@@ -522,9 +527,16 @@ def forward(
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
)
# nfnl x nnei
exclude_mask = self.build_type_exclude_mask(nlist, extended_atype).view(
nfnl, -1
)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 4
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
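
In the forward pass above, the exclusion mask is reshaped to the flattened (nf * nloc) x nnei layout, sliced per neighbor type, and broadcast over the last axis of the environment-matrix slice. A toy torch sketch of that step (shapes are illustrative only):

import torch

nfnl, nt = 3, 4                    # nf * nloc rows, nt neighbors of one type
rr = torch.ones(nfnl, nt, 4)       # per-type slice of the environment matrix
mm = torch.tensor([[1, 1, 0, 1],
                   [0, 1, 1, 1],
                   [1, 1, 1, 1]])  # matching slice of the exclusion mask
rr = rr * mm[:, :, None]           # excluded neighbors are zeroed out
ss = rr[:, :, :1]                  # only the s(r) column enters the embedding net
assert (rr[0, 2] == 0).all() and (rr[1, 0] == 0).all()
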
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_atten.py
@@ -64,7 +64,7 @@ def __init__(
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__()
super().__init__(ntypes)
del type
self.rcut = rcut
self.rcut_smth = rcut_smth
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/se_a.py
@@ -201,6 +201,7 @@ def __init__(
self.activation_function_name = activation_function
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.orig_exclude_types = exclude_types
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
@@ -1425,7 +1426,7 @@ def serialize(self, suffix: str = "") -> dict:
"resnet_dt": self.filter_resnet_dt,
"trainable": self.trainable,
"type_one_side": self.type_one_side,
"exclude_types": list(self.exclude_types),
"exclude_types": list(self.orig_exclude_types),
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function_name,
"precision": self.filter_precision.name,
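
The tf descriptor now serializes the user-supplied pair list instead of the symmetrized internal set, since a round trip through the set would change the value. A small sketch with assumed toy input shows the difference:

orig_exclude_types = [(0, 1)]          # what the user passed in
exclude_types = set()
for a, b in orig_exclude_types:
    exclude_types.update({(a, b), (b, a)})
print(sorted(exclude_types))           # [(0, 1), (1, 0)], not what the user wrote
print(orig_exclude_types)              # [(0, 1)], what serialize() now stores
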