Skip to content

Commit

Permalink
store type in fitting serialization data (#3331)
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz authored Feb 24, 2024
1 parent 15df69b commit 91049df
Show file tree
Hide file tree
Showing 17 changed files with 116 additions and 62 deletions.
11 changes: 3 additions & 8 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from typing import (
Dict,
List,
Expand All @@ -12,9 +11,8 @@
from deepmd.dpmodel.descriptor.base_descriptor import (
BaseDescriptor,
)
from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
Expand Down Expand Up @@ -135,16 +133,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
"fitting_name": self.fitting.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/fitting/dipole_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand All @@ -22,6 +25,7 @@
)


@BaseFitting.register("dipole")
@fitting_check_output
class DipoleFitting(GeneralFitting):
r"""Fitting rotationally equivariant diploe of the system.
Expand Down Expand Up @@ -142,6 +146,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "dipole"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
Expand Down
8 changes: 8 additions & 0 deletions deepmd/dpmodel/fitting/ener_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
)


@InvarFitting.register("ener")
class EnergyFittingNet(InvarFitting):
def __init__(
self,
Expand Down Expand Up @@ -70,3 +71,10 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
data.pop("var_name")
data.pop("dim_out")
return super().deserialize(data)

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
**super().serialize(),
"type": "ener",
}
3 changes: 3 additions & 0 deletions deepmd/dpmodel/fitting/general_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ def __getitem__(self, key):
def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
"@class": "Fitting",
"var_name": self.var_name,
"ntypes": self.ntypes,
"dim_descrpt": self.dim_descrpt,
Expand Down Expand Up @@ -240,6 +241,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data: dict) -> "GeneralFitting":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
variables = data.pop("@variables")
nets = data.pop("nets")
obj = cls(**data)
Expand Down
2 changes: 2 additions & 0 deletions deepmd/dpmodel/fitting/invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
)


@GeneralFitting.register("invar")
@fitting_check_output
class InvarFitting(GeneralFitting):
r"""Fitting the energy (or a rotationally invariant porperty of `dim_out`) of the system. The force and the virial can also be trained.
Expand Down Expand Up @@ -162,6 +163,7 @@ def __init__(

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "invar"
data["dim_out"] = self.dim_out
data["atom_ener"] = self.atom_ener
return data
Expand Down
66 changes: 61 additions & 5 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractclassmethod,
abstractmethod,
)
from typing import (
Callable,
Dict,
Optional,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
)
from deepmd.utils.plugin import (
Plugin,
)


def make_base_fitting(
Expand All @@ -33,6 +40,42 @@ def make_base_fitting(
class BF(ABC):
"""Base fitting provides the interfaces of fitting net."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.
Parameters
----------
key : str
the key of a descriptor
Returns
-------
callable[[object], object]
the registered descriptor
Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BF.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BF:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, fitting_type: str) -> Type["BF"]:
if fitting_type in BF.__plugins.plugins:
return BF.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown fitting type: " + fitting_type)

@abstractmethod
def output_def(self) -> FittingOutputDef:
"""Returns the output def of the fitting net."""
Expand Down Expand Up @@ -65,10 +108,23 @@ def serialize(self) -> dict:
"""Serialize the obj to dict."""
pass

@abstractclassmethod
def deserialize(cls):
"""Deserialize from a dict."""
pass
@classmethod
def deserialize(cls, data: dict) -> "BF":
"""Deserialize the fitting.
Parameters
----------
data : dict
The serialized data
Returns
-------
BF
The deserialized fitting
"""
if cls is BF:
return BF.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

setattr(BF, fwd_method_name, BF.fwd)
delattr(BF, "fwd")
Expand Down
5 changes: 5 additions & 0 deletions deepmd/dpmodel/fitting/polarizability_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@
from deepmd.dpmodel import (
DEFAULT_PRECISION,
)
from deepmd.dpmodel.fitting.base_fitting import (
BaseFitting,
)
from deepmd.dpmodel.output_def import (
FittingOutputDef,
OutputVariableDef,
Expand All @@ -25,6 +28,7 @@
)


@BaseFitting.register("polar")
@fitting_check_output
class PolarFitting(GeneralFitting):
r"""Fitting rotationally equivariant polarizability of the system.
Expand Down Expand Up @@ -166,6 +170,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "polar"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["fit_diag"] = self.fit_diag
Expand Down
11 changes: 3 additions & 8 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import logging
import sys
from typing import (
Dict,
List,
Expand All @@ -16,9 +15,8 @@
from deepmd.pt.model.descriptor.descriptor import (
Descriptor,
)
from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings!
EnergyFittingNet,
InvarFitting,
from deepmd.pt.model.task.base_fitting import (
BaseFitting,
)
from deepmd.pt.utils.utils import (
dict_to_device,
Expand Down Expand Up @@ -98,16 +96,13 @@ def serialize(self) -> dict:
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
"fitting_name": self.fitting_net.__class__.__name__,
}

@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
descriptor_obj = Descriptor.deserialize(data["descriptor"])
fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize(
data["fitting"]
)
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
return obj

Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
from .fitting import (
Fitting,
)
from .polarizability import (
PolarFittingNet,
)
from .type_predict import (
TypePredictNet,
)
Expand All @@ -31,4 +34,5 @@
"Fitting",
"BaseFitting",
"TypePredictNet",
"PolarFittingNet",
]
2 changes: 2 additions & 0 deletions deepmd/pt/model/task/dipole.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
log = logging.getLogger(__name__)


@GeneralFitting.register("dipole")
class DipoleFittingNet(GeneralFitting):
"""Construct a dipole fitting net.
Expand Down Expand Up @@ -111,6 +112,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "dipole"
data["embedding_width"] = self.embedding_width
data["old_impl"] = self.old_impl
data["r_differentiable"] = self.r_differentiable
Expand Down
9 changes: 9 additions & 0 deletions deepmd/pt/model/task/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
log = logging.getLogger(__name__)


@GeneralFitting.register("invar")
@fitting_check_output
class InvarFitting(GeneralFitting):
"""Construct a fitting net for energy.
Expand Down Expand Up @@ -129,6 +130,7 @@ def _net_out_dim(self):

def serialize(self) -> dict:
data = super().serialize()
data["type"] = "invar"
data["dim_out"] = self.dim_out
data["atom_ener"] = self.atom_ener
return data
Expand Down Expand Up @@ -238,6 +240,13 @@ def deserialize(cls, data: dict) -> "GeneralFitting":
data.pop("dim_out")
return super().deserialize(data)

def serialize(self) -> dict:
"""Serialize the fitting to dict."""
return {
**super().serialize(),
"type": "ener",
}


@Fitting.register("direct_force")
@Fitting.register("direct_force_ener")
Expand Down
Loading

0 comments on commit 91049df

Please sign in to comment.