Skip to content

Commit

Permalink
refact: pt: mv all plugin support to base descriptor. (#3340)
Browse files Browse the repository at this point in the history
Thus pt reuses the dp code.

---------

Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 27, 2024
1 parent 4f70073 commit 3e6b507
Show file tree
Hide file tree
Showing 9 changed files with 27 additions and 180 deletions.
6 changes: 3 additions & 3 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@
from deepmd.dpmodel import (
FittingOutputDef,
)
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
Expand Down Expand Up @@ -101,7 +101,7 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = Descriptor.deserialize(data["descriptor"])
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj
Expand Down
2 changes: 0 additions & 2 deletions deepmd/pt/model/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from .descriptor import (
Descriptor,
DescriptorBlock,
make_default_type_embedding,
)
Expand Down Expand Up @@ -29,7 +28,6 @@
)

__all__ = [
"Descriptor",
"DescriptorBlock",
"make_default_type_embedding",
"DescrptBlockSeA",
Expand Down
93 changes: 0 additions & 93 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,10 @@
Dict,
List,
Optional,
Type,
)

import torch

from deepmd.common import (
j_get_type,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
Expand All @@ -36,98 +32,9 @@
DPPath,
)

from .base_descriptor import (
BaseDescriptor,
)

log = logging.getLogger(__name__)


class Descriptor(torch.nn.Module, BaseDescriptor):
    """The descriptor.

    Given the atomic coordinates, atomic types and neighbor list,
    calculate the descriptor.

    The class also keeps a registry of descriptor implementations:
    classes are added with :meth:`register`, and constructing or
    deserializing ``Descriptor`` itself is forwarded to the class
    registered under the requested type key.
    """

    # Registry whose ``plugins`` mapping goes from type key to class.
    __plugins = Plugin()
    local_cluster = False

    @staticmethod
    def register(key: str) -> Callable:
        """Register a descriptor plugin.

        Parameters
        ----------
        key : str
            The key of a descriptor.

        Returns
        -------
        Callable
            The result of ``Plugin.register(key)``; as the example shows,
            it is meant to be applied as a class decorator.

        Examples
        --------
        >>> @Descriptor.register("some_descrpt")
        ... class SomeDescript(Descriptor):
        ...     pass
        """
        return Descriptor.__plugins.register(key)

    @classmethod
    def get_data_process_key(cls, config):
        """Get the keys for the data preprocess.

        Usually need the information of rcut and sel.

        TODO Need to be deprecated when the dataloader has been cleaned up.

        When called on ``Descriptor`` the request is forwarded to the class
        registered under ``config["type"]``; any other class that reaches
        this implementation raises ``NotImplementedError``.
        """
        if cls is not Descriptor:
            raise NotImplementedError("get_data_process_key is not implemented!")
        implementation = Descriptor.__plugins.plugins[config["type"]]
        return implementation.get_data_process_key(config)

    @property
    def data_stat_key(self):
        """Get the keys for the data statistic of the descriptor.

        Return a list of statistic names needed, such as "sumr", "suma"
        or "sumn". The base implementation always raises.
        """
        raise NotImplementedError("data_stat_key is not implemented!")

    def __new__(cls, *args, **kwargs):
        """Create the instance, resolving ``Descriptor`` to a registered class.

        Only a direct ``Descriptor(...)`` call is redirected; the target type
        name is obtained from the keyword arguments through ``j_get_type``.
        """
        # ``cls`` is rebound on purpose: zero-argument ``super()`` reads it.
        if cls is Descriptor:
            cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
        return super().__new__(cls)

    @classmethod
    def get_class_by_type(cls, descrpt_type: str) -> Type["Descriptor"]:
        """Return the class registered under ``descrpt_type``.

        Raises
        ------
        RuntimeError
            If no class is registered under that key.
        """
        registry = Descriptor.__plugins.plugins
        if descrpt_type not in registry:
            raise RuntimeError("Unknown descriptor type: " + descrpt_type)
        return registry[descrpt_type]

    @classmethod
    def deserialize(cls, data: dict) -> "Descriptor":
        """Deserialize the model.

        There is no suffix in a native DP model, but it is important
        for the TF backend.

        Parameters
        ----------
        data : dict
            The serialized data; ``data["type"]`` selects the class.

        Returns
        -------
        Descriptor
            The deserialized descriptor.

        Raises
        ------
        NotImplementedError
            If called on a class other than ``Descriptor`` that did not
            provide its own implementation.
        """
        if cls is not Descriptor:
            raise NotImplementedError("Not implemented in class %s" % cls.__name__)
        implementation = Descriptor.get_class_by_type(data["type"])
        return implementation.deserialize(data)


class DescriptorBlock(torch.nn.Module, ABC):
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
Expand Down
31 changes: 6 additions & 25 deletions deepmd/pt/model/descriptor/dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,24 +6,24 @@

import torch

from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.network.network import (
TypeEmbedNet,
)
from deepmd.utils.path import (
DPPath,
)

from .base_descriptor import (
BaseDescriptor,
)
from .se_atten import (
DescrptBlockSeAtten,
)


@Descriptor.register("dpa1")
@Descriptor.register("se_atten")
class DescrptDPA1(Descriptor):
@BaseDescriptor.register("dpa1")
@BaseDescriptor.register("se_atten")
class DescrptDPA1(BaseDescriptor, torch.nn.Module):
def __init__(
self,
rcut,
Expand Down Expand Up @@ -131,25 +131,6 @@ def dim_emb(self):
def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None):
    """Forward the statistics computation to the ``se_atten`` member.

    Both arguments are passed through unchanged and the member's return
    value is returned as is.
    """
    attention_block = self.se_atten
    return attention_block.compute_input_stats(merged, path)

@classmethod
def get_data_process_key(cls, config):
    """Get the keys for the data preprocess.

    Usually need the information of rcut and sel.

    TODO Need to be deprecated when the dataloader has been cleaned up.

    Parameters
    ----------
    config
        Descriptor configuration; its ``"type"``, ``"sel"`` and ``"rcut"``
        entries are read.

    Returns
    -------
    dict
        The ``"sel"`` and ``"rcut"`` entries of ``config``.
    """
    # Sanity check that the configuration targets this descriptor.
    assert config["type"] in ("dpa1", "se_atten")
    return {name: config[name] for name in ("sel", "rcut")}

@property
def data_stat_key(self):
    """Get the keys for the data statistic of the descriptor.

    Returns
    -------
    list of str
        The statistic names needed, freshly built on every access.
    """
    stat_names = ["sumr", "suma", "sumn", "sumr2", "suma2"]
    return stat_names

def serialize(self) -> dict:
"""Serialize the obj to dict."""
raise NotImplementedError
Expand Down
32 changes: 5 additions & 27 deletions deepmd/pt/model/descriptor/dpa2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@

import torch

from deepmd.pt.model.descriptor import (
Descriptor,
)
from deepmd.pt.model.network.network import (
Identity,
Linear,
Expand All @@ -22,6 +19,9 @@
DPPath,
)

from .base_descriptor import (
BaseDescriptor,
)
from .repformers import (
DescrptBlockRepformers,
)
Expand All @@ -30,8 +30,8 @@
)


@Descriptor.register("dpa2")
class DescrptDPA2(Descriptor):
@BaseDescriptor.register("dpa2")
class DescrptDPA2(torch.nn.Module, BaseDescriptor):
def __init__(
self,
ntypes: int,
Expand Down Expand Up @@ -306,28 +306,6 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
]
descrpt.compute_input_stats(merged_tmp)

@classmethod
def get_data_process_key(cls, config):
    """Get the keys for the data preprocess.

    Usually need the information of rcut and sel.

    TODO Need to be deprecated when the dataloader has been cleaned up.

    Parameters
    ----------
    config
        Descriptor configuration; its ``"type"``, ``"repinit_nsel"``,
        ``"repformer_nsel"``, ``"repinit_rcut"`` and ``"repformer_rcut"``
        entries are read.

    Returns
    -------
    dict
        ``"sel"`` and ``"rcut"`` lists, each holding the repinit value
        followed by the repformer value.
    """
    # Sanity check that the configuration targets this descriptor.
    assert config["type"] in ("dpa2",)
    sel_pair = [config["repinit_nsel"], config["repformer_nsel"]]
    rcut_pair = [config["repinit_rcut"], config["repformer_rcut"]]
    return {"sel": sel_pair, "rcut": rcut_pair}

@property
def data_stat_key(self):
    """Get the keys for the data statistic of the descriptor.

    Returns
    -------
    list of str
        The statistic names needed, freshly built on every access.
    """
    stat_names = ["sumr", "suma", "sumn", "sumr2", "suma2"]
    return stat_names

def serialize(self) -> dict:
"""Serialize the obj to dict."""
raise NotImplementedError
Expand Down
6 changes: 3 additions & 3 deletions deepmd/pt/model/descriptor/gaussian_lcc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
import torch
import torch.nn as nn

from deepmd.pt.model.descriptor import (
Descriptor,
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.network.network import (
Evoformer3bEncoder,
Expand All @@ -23,7 +23,7 @@
)


class DescrptGaussianLcc(Descriptor):
class DescrptGaussianLcc(torch.nn.Module, BaseDescriptor):
def __init__(
self,
rcut,
Expand Down
28 changes: 6 additions & 22 deletions deepmd/pt/model/descriptor/se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import torch

from deepmd.pt.model.descriptor import (
Descriptor,
DescriptorBlock,
prod_env_mat_se_a,
)
Expand Down Expand Up @@ -52,9 +51,13 @@
PairExcludeMask,
)

from .base_descriptor import (
BaseDescriptor,
)


@Descriptor.register("se_e2_a")
class DescrptSeA(Descriptor):
@BaseDescriptor.register("se_e2_a")
class DescrptSeA(BaseDescriptor, torch.nn.Module):
def __init__(
self,
rcut,
Expand Down Expand Up @@ -127,25 +130,6 @@ def compute_input_stats(self, merged: List[dict], path: Optional[DPPath] = None)
"""Update mean and stddev for descriptor elements."""
return self.sea.compute_input_stats(merged, path)

@classmethod
def get_data_process_key(cls, config):
    """Get the keys for the data preprocess.

    Usually need the information of rcut and sel.

    TODO Need to be deprecated when the dataloader has been cleaned up.

    Parameters
    ----------
    config
        Descriptor configuration; its ``"type"``, ``"sel"`` and ``"rcut"``
        entries are read.

    Returns
    -------
    dict
        The ``"sel"`` and ``"rcut"`` entries of ``config``.
    """
    # Sanity check that the configuration targets this descriptor.
    assert config["type"] in ("se_e2_a",)
    return {name: config[name] for name in ("sel", "rcut")}

@property
def data_stat_key(self):
    """Get the keys for the data statistic of the descriptor.

    Returns
    -------
    list of str
        The statistic names needed, freshly built on every access.
    """
    stat_names = ["sumr", "suma", "sumn", "sumr2", "suma2"]
    return stat_names

def forward(
self,
coord_ext: torch.Tensor,
Expand Down
8 changes: 4 additions & 4 deletions deepmd/pt/model/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
DPAtomicModel,
PairTabAtomicModel,
)
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.pt.model.task import (
Fitting,
Expand Down Expand Up @@ -48,7 +48,7 @@ def get_zbl_model(model_params):
ntypes = len(model_params["type_map"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
descriptor = Descriptor(**model_params["descriptor"])
descriptor = BaseDescriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", None)
fitting_net["type"] = fitting_net.get("type", "ener")
Expand Down Expand Up @@ -84,7 +84,7 @@ def get_model(model_params):
ntypes = len(model_params["type_map"])
# descriptor
model_params["descriptor"]["ntypes"] = ntypes
descriptor = Descriptor(**model_params["descriptor"])
descriptor = BaseDescriptor(**model_params["descriptor"])
# fitting
fitting_net = model_params.get("fitting_net", None)
fitting_net["type"] = fitting_net.get("type", "ener")
Expand Down
1 change: 0 additions & 1 deletion source/tests/pt/model/test_polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ def test_rot(self):

ret0 = ft0(rd0, extended_atype, gr0, fparam=ifp, aparam=iap)
res.append(ret0["foo"])
print(res[1].shape)
np.testing.assert_allclose(
to_numpy_array(res[1]),
to_numpy_array(
Expand Down

0 comments on commit 3e6b507

Please sign in to comment.