Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: dp and pt: implement exclude types in descriptor se_a #3280

Merged
merged 7 commits into from
Feb 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions deepmd/dpmodel/descriptor/exclude_mask.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
List,
Tuple,
)

import numpy as np


class ExcludeMask:
"""Computes the atom type exclusion mask."""

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
self.ntypes = ntypes
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
self.exclude_types.add((tt[0], tt[1]))
self.exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
self.type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in self.exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = self.type_mask.reshape([-1])

def build_type_exclude_mask(
self,
nlist: np.ndarray,
atype_ext: np.ndarray,
):
"""Compute type exclusion mask.

Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended aotm types. shape: nf x nall

Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
otherwise being 1.

"""
if len(self.exclude_types) == 0:
# safely return 1 if nothing is excluded.
return np.ones_like(nlist, dtype=np.int32)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = np.concatenate(
[atype_ext, self.ntypes * np.ones([nf, 1], dtype=atype_ext.dtype)], axis=-1
)
type_i = atype_ext[:, :nloc].reshape(nf, nloc) * (self.ntypes + 1)
# nf x nloc x nnei
index = np.where(nlist == -1, nall, nlist).reshape(nf, nloc * nnei)
type_j = np.take_along_axis(ae, index, axis=1).reshape(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = type_ij.reshape(nf, nloc * nnei)
mask = self.type_mask[type_ij].reshape(nf, nloc, nnei)
return mask
9 changes: 7 additions & 2 deletions deepmd/dpmodel/descriptor/se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
from .base_descriptor import (
BaseDescriptor,
)
from .exclude_mask import (
ExcludeMask,
)


class DescrptSeA(NativeOP, BaseDescriptor):
Expand Down Expand Up @@ -140,8 +143,6 @@ def __init__(
## seed, uniform_seed, multi_task, not included.
if not type_one_side:
raise NotImplementedError("type_one_side == False not implemented")
if exclude_types != []:
raise NotImplementedError("exclude_types is not implemented")
if spin is not None:
raise NotImplementedError("spin is not implemented")

Expand All @@ -159,6 +160,7 @@ def __init__(
self.activation_function = activation_function
self.precision = precision
self.spin = spin
self.emask = ExcludeMask(self.ntypes, self.exclude_types)

in_dim = 1 # not considiering type embedding
self.embeddings = NetworkCollection(
Expand Down Expand Up @@ -292,8 +294,11 @@ def call(

ng = self.neuron[-1]
gr = np.zeros([nf, nloc, ng, 4])
exclude_mask = self.emask.build_type_exclude_mask(nlist, atype_ext)
for tt in range(self.ntypes):
mm = exclude_mask[:, :, sec[tt] : sec[tt + 1]]
tr = rr[:, :, sec[tt] : sec[tt + 1], :]
tr = tr * mm[:, :, :, None]
ss = tr[..., 0:1]
gg = self.cal_g(ss, tt)
# nf x nloc x ng x 4
Expand Down
80 changes: 79 additions & 1 deletion deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
Callable,
List,
Optional,
Set,
Tuple,
Union,
)

Expand All @@ -20,6 +22,9 @@
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.pt.utils.utils import (
to_torch_tensor,
)

from .base_descriptor import (
BaseDescriptor,
Expand Down Expand Up @@ -206,6 +211,32 @@
__plugins = Plugin()
local_cluster = False

def __init__(
self,
ntypes: int,
exclude_types: List[Tuple[int, int]] = [],
):
super().__init__()
_exclude_types: Set[Tuple[int, int]] = set()
for tt in exclude_types:
assert len(tt) == 2
_exclude_types.add((tt[0], tt[1]))
_exclude_types.add((tt[1], tt[0]))
# ntypes + 1 for nlist masks
self.type_mask = np.array(
[
[
1 if (tt_i, tt_j) not in _exclude_types else 0
for tt_i in range(ntypes + 1)
]
for tt_j in range(ntypes + 1)
],
dtype=np.int32,
)
# (ntypes+1 x ntypes+1)
self.type_mask = to_torch_tensor(self.type_mask).view([-1])
self.no_exclusion = len(_exclude_types) == 0

@staticmethod
def register(key: str) -> Callable:
"""Register a DescriptorBlock plugin.
Expand Down Expand Up @@ -332,7 +363,54 @@
mapping: Optional[torch.Tensor] = None,
):
"""Calculate DescriptorBlock."""
raise NotImplementedError
pass

Check warning on line 366 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L366

Added line #L366 was not covered by tests

# may have a better place for this method...
def build_type_exclude_mask(
self,
nlist: torch.Tensor,
atype_ext: torch.Tensor,
) -> torch.Tensor:
wanghan-iapcm marked this conversation as resolved.
Show resolved Hide resolved
"""Compute type exclusion mask.

Parameters
----------
nlist
The neighbor list. shape: nf x nloc x nnei
atype_ext
The extended aotm types. shape: nf x nall

Returns
-------
mask
The type exclusion mask of shape: nf x nloc x nnei.
Element [ff,ii,jj] being 0 if type(ii), type(nlist[ff,ii,jj]) is excluded,
otherwise being 1.

"""
if self.no_exclusion:
# safely return 1 if nothing is excluded.
return torch.ones_like(nlist, dtype=torch.int32, device=nlist.device)
nf, nloc, nnei = nlist.shape
nall = atype_ext.shape[1]
# add virtual atom of type ntypes. nf x nall+1
ae = torch.cat(
[
atype_ext,
self.get_ntypes()
* torch.ones([nf, 1], dtype=atype_ext.dtype, device=atype_ext.device),
],
dim=-1,
)
type_i = atype_ext[:, :nloc].view(nf, nloc) * (self.get_ntypes() + 1)
# nf x nloc x nnei
index = torch.where(nlist == -1, nall, nlist).view(nf, nloc * nnei)
type_j = torch.gather(ae, 1, index).view(nf, nloc, nnei)
type_ij = type_i[:, :, None] + type_j
# nf x (nloc x nnei)
type_ij = type_ij.view(nf, nloc * nnei)
mask = self.type_mask[type_ij].view(nf, nloc, nnei)
return mask


def compute_std(sumv2, sumv, sumn, rcut_r):
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/hybrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(
- descriptor_list: list of descriptors.
- descriptor_param: descriptor configs.
"""
super().__init__()
super().__init__(ntypes)
supported_descrpt = ["se_atten", "se_uni"]
descriptor_list = []
for descriptor_param_item in list:
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/repformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __init__(
whether or not add an type embedding to seq_input.
If no seq_input is given, it has no effect.
"""
super().__init__()
super().__init__(ntypes)
del type
self.epsilon = 1e-4 # protection of 1./nnei
self.rcut = rcut
Expand Down
32 changes: 22 additions & 10 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
ClassVar,
List,
Optional,
Tuple,
)

import numpy as np
Expand Down Expand Up @@ -55,6 +56,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
**kwargs,
):
Expand All @@ -63,13 +65,14 @@ def __init__(
rcut,
rcut_smth,
sel,
neuron,
axis_neuron,
set_davg_zero,
activation_function,
precision,
resnet_dt,
old_impl,
neuron=neuron,
axis_neuron=axis_neuron,
set_davg_zero=set_davg_zero,
activation_function=activation_function,
precision=precision,
resnet_dt=resnet_dt,
exclude_types=exclude_types,
old_impl=old_impl,
**kwargs,
)

Expand Down Expand Up @@ -212,14 +215,14 @@ def serialize(self) -> dict:
"precision": RESERVED_PRECISON_DICT[obj.prec],
"embeddings": obj.filter_layers.serialize(),
"env_mat": DPEnvMat(obj.rcut, obj.rcut_smth).serialize(),
"exclude_types": obj.exclude_types,
"@variables": {
"davg": obj["davg"].detach().cpu().numpy(),
"dstd": obj["dstd"].detach().cpu().numpy(),
},
## to be updated when the options are supported.
"trainable": True,
"type_one_side": True,
"exclude_types": [],
"spin": None,
}

Expand Down Expand Up @@ -256,6 +259,7 @@ def __init__(
activation_function: str = "tanh",
precision: str = "float64",
resnet_dt: bool = False,
exclude_types: List[Tuple[int, int]] = [],
old_impl: bool = False,
**kwargs,
):
Expand All @@ -268,7 +272,7 @@ def __init__(
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__()
super().__init__(len(sel), exclude_types=exclude_types)
self.rcut = rcut
self.rcut_smth = rcut_smth
self.neuron = neuron
Expand All @@ -280,8 +284,9 @@ def __init__(
self.prec = PRECISION_DICT[self.precision]
self.resnet_dt = resnet_dt
self.old_impl = old_impl

self.exclude_types = exclude_types
self.ntypes = len(sel)

self.sel = sel
self.sec = torch.tensor(
np.append([0], np.cumsum(self.sel)), dtype=int, device=env.DEVICE
Expand Down Expand Up @@ -522,9 +527,16 @@ def forward(
xyz_scatter = torch.zeros(
[nfnl, 4, self.filter_neuron[-1]], dtype=self.prec, device=env.DEVICE
)
# nfnl x nnei
exclude_mask = self.build_type_exclude_mask(nlist, extended_atype).view(
nfnl, -1
)
for ii, ll in enumerate(self.filter_layers.networks):
# nfnl x nt
mm = exclude_mask[:, self.sec[ii] : self.sec[ii + 1]]
# nfnl x nt x 4
rr = dmatrix[:, self.sec[ii] : self.sec[ii + 1], :]
rr = rr * mm[:, :, None]
ss = rr[:, :, :1]
# nfnl x nt x ng
gg = ll.forward(ss)
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/model/descriptor/se_atten.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def __init__(
- filter_neuron: Number of neurons in each hidden layers of the embedding net.
- axis_neuron: Number of columns of the sub-matrix of the embedding matrix.
"""
super().__init__()
super().__init__(ntypes)
del type
self.rcut = rcut
self.rcut_smth = rcut_smth
Expand Down
3 changes: 2 additions & 1 deletion deepmd/tf/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
self.activation_function_name = activation_function
self.filter_precision = get_precision(precision)
self.filter_np_precision = get_np_precision(precision)
self.orig_exclude_types = exclude_types
self.exclude_types = set()
for tt in exclude_types:
assert len(tt) == 2
Expand Down Expand Up @@ -1425,7 +1426,7 @@ def serialize(self, suffix: str = "") -> dict:
"resnet_dt": self.filter_resnet_dt,
"trainable": self.trainable,
"type_one_side": self.type_one_side,
"exclude_types": list(self.exclude_types),
"exclude_types": list(self.orig_exclude_types),
"set_davg_zero": self.set_davg_zero,
"activation_function": self.activation_function_name,
"precision": self.filter_precision.name,
Expand Down
Loading
Loading