Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add BaseModel; store type in serialization #3335

Merged
merged 17 commits into from
Feb 27, 2024
4 changes: 4 additions & 0 deletions deepmd/dpmodel/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,8 @@ def forward_atomic(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting.serialize(),
Expand All @@ -138,6 +140,8 @@ def serialize(self) -> dict:
@classmethod
def deserialize(cls, data) -> "DPAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
descriptor_obj = BaseDescriptor.deserialize(data["descriptor"])
fitting_obj = BaseFitting.deserialize(data["fitting"])
obj = cls(descriptor_obj, fitting_obj, type_map=data["type_map"])
Expand Down
11 changes: 11 additions & 0 deletions deepmd/dpmodel/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
import sys
from abc import (
abstractmethod,
Expand Down Expand Up @@ -182,12 +183,17 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}

@staticmethod
def deserialize(data) -> List[BaseAtomicModel]:
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
model_names = data["model_name"]
models = [
getattr(sys.modules[__name__], name).deserialize(model)
Expand Down Expand Up @@ -263,6 +269,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand All @@ -271,6 +279,9 @@ def serialize(self) -> dict:

@classmethod
def deserialize(cls, data) -> "DPZBLLinearAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
sw_rmin = data["sw_rmin"]
sw_rmax = data["sw_rmax"]
smin_alpha = data["smin_alpha"]
Expand Down
12 changes: 11 additions & 1 deletion deepmd/dpmodel/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import copy
from typing import (
Dict,
List,
Expand Down Expand Up @@ -105,10 +106,19 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
data = copy.deepcopy(data)
data.pop("@class")
data.pop("type")
rcut = data["rcut"]
sel = data["sel"]
tab = PairTab.deserialize(data["tab"])
Expand Down
6 changes: 3 additions & 3 deletions deepmd/dpmodel/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@

import numpy as np

from deepmd.dpmodel.model.dp_model import (
DPModel,
from deepmd.dpmodel.model.base_model import (
BaseModel,
)
from deepmd.dpmodel.output_def import (
ModelOutputDef,
Expand Down Expand Up @@ -85,7 +85,7 @@ def __init__(
self.model_path = model_file

model_data = load_dp_model(model_file)
self.dp = DPModel.deserialize(model_data["model"])
self.dp = BaseModel.deserialize(model_data["model"])
self.rcut = self.dp.get_rcut()
self.type_map = self.dp.get_type_map()
if isinstance(auto_batch_size, bool):
Expand Down
160 changes: 160 additions & 0 deletions deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from abc import (
ABC,
abstractmethod,
)
from typing import (
Any,
Callable,
List,
Type,
)

from deepmd.common import (
j_get_type,
)
from deepmd.utils.plugin import (
Plugin,
)


class BaseBaseModel(ABC):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
njzjz marked this conversation as resolved.
Show resolved Hide resolved
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is backend-indepedent.

See Also
--------
deepmd.dpmodel.model.base_model.BaseModel
BaseModel class for DPModel backend.
"""

@abstractmethod
def __call__(self, *args: Any, **kwds: Any) -> Any:
"""Inference method.

Parameters
----------
*args : Any
The input data for inference.
**kwds : Any
The input data for inference.

Returns
-------
Any
The output of the inference.
"""
pass

Check warning on line 53 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L53

Added line #L53 was not covered by tests

@abstractmethod
def get_type_map(self) -> List[str]:
"""Get the type map."""

@abstractmethod
def get_rcut(self):
"""Get the cut-off radius."""

@abstractmethod
def get_dim_fparam(self):
"""Get the number (dimension) of frame parameters of this atomic model."""

@abstractmethod
def get_dim_aparam(self):
"""Get the number (dimension) of atomic parameters of this atomic model."""

@abstractmethod
def get_sel_type(self) -> List[int]:
"""Get the selected atom types of this model.

Only atoms with selected atom types have atomic contribution
to the result of the model.
If returning an empty list, all atom types are selected.
"""

@abstractmethod
def is_aparam_nall(self) -> bool:
"""Check whether the shape of atomic parameters is (nframes, nall, ndim).

If False, the shape is (nframes, nloc, ndim).
"""

@abstractmethod
def model_output_type(self) -> str:
"""Get the output type for the model."""


class BaseModel(BaseBaseModel):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
inference interface. If the final model class inherbits some of those methods
from other classes, `BaseModel` should be inherited as the last class to ensure
the correct method resolution order.

This class is for the DPModel backend.

See Also
--------
deepmd.dpmodel.model.base_model.BaseBaseModel
Backend-independent BaseModel class.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BaseModel.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BaseModel:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))

Check warning on line 134 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L134

Added line #L134 was not covered by tests
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, model_type: str) -> Type["BaseModel"]:
if model_type in BaseModel.__plugins.plugins:
return BaseModel.__plugins.plugins[model_type]
else:
raise RuntimeError("Unknown model type: " + model_type)

Check warning on line 142 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L142

Added line #L142 was not covered by tests

@classmethod
def deserialize(cls, data: dict) -> "BaseModel":
"""Deserialize the model.

Parameters
----------
data : dict
The serialized data

Returns
-------
BaseModel
The deserialized model
"""
if cls is BaseModel:
return BaseModel.get_class_by_type(data["type"]).deserialize(data)
raise NotImplementedError("Not implemented in class %s" % cls.__name__)

Check warning on line 160 in deepmd/dpmodel/model/base_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/dpmodel/model/base_model.py#L160

Added line #L160 was not covered by tests
6 changes: 5 additions & 1 deletion deepmd/dpmodel/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,16 @@
from deepmd.dpmodel.atomic_model import (
DPAtomicModel,
)
from deepmd.dpmodel.model.base_model import (
BaseModel,
)

from .make_model import (
make_model,
)


# use "class" to resolve "Variable not allowed in type expression"
class DPModel(make_model(DPAtomicModel)):
@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed
pass
2 changes: 2 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ def mixed_types(self) -> bool:

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "standard",
"type_map": self.type_map,
"descriptor": self.descriptor.serialize(),
"fitting": self.fitting_net.serialize(),
Expand Down
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ def fitting_output_def(self) -> FittingOutputDef:
@staticmethod
def serialize(models) -> dict:
return {
"@class": "Model",
"type": "linear",
"models": [model.serialize() for model in models],
"model_name": [model.__class__.__name__ for model in models],
}
Expand Down Expand Up @@ -301,6 +303,8 @@ def __init__(

def serialize(self) -> dict:
return {
"@class": "Model",
"type": "zbl",
"models": LinearAtomicModel.serialize([self.dp_model, self.zbl_model]),
"sw_rmin": self.sw_rmin,
"sw_rmax": self.sw_rmax,
Expand Down
8 changes: 7 additions & 1 deletion deepmd/pt/model/atomic_model/pairtab_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,13 @@ def mixed_types(self) -> bool:
return True

def serialize(self) -> dict:
return {"tab": self.tab.serialize(), "rcut": self.rcut, "sel": self.sel}
return {
"@class": "Model",
"type": "pairtab",
"tab": self.tab.serialize(),
"rcut": self.rcut,
"sel": self.sel,
}

@classmethod
def deserialize(cls, data) -> "PairTabAtomicModel":
Expand Down
38 changes: 37 additions & 1 deletion deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,45 @@
from deepmd.pt.model.atomic_model import (
DPAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)
from deepmd.pt.model.task.dipole import (
DipoleFittingNet,
)
from deepmd.pt.model.task.ener import (
EnergyFittingNet,
)
from deepmd.pt.model.task.polarizability import (
PolarFittingNet,
)

from .make_model import (
make_model,
)

DPModel = make_model(DPAtomicModel)

@BaseModel.register("standard")
class DPModel(make_model(DPAtomicModel), BaseModel):
Fixed Show fixed Hide fixed

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning

Base classes have conflicting values for attribute 'compute_or_load_stat':
Function compute_or_load_stat
and
Function compute_or_load_stat
.
Base classes have conflicting values for attribute 'compute_or_load_stat': Function compute_or_load_stat and
Function compute_or_load_stat
.
def __new__(cls, descriptor, fitting, *args, **kwargs):
from deepmd.pt.model.model.dipole_model import (
DipoleModel,
)
Comment on lines +26 to +28

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.dipole_model
begins an import cycle.
from deepmd.pt.model.model.ener_model import (
EnergyModel,
)
Comment on lines +29 to +31

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.ener_model
begins an import cycle.
from deepmd.pt.model.model.polar_model import (
PolarModel,
)
Comment on lines +32 to +34

Check notice

Code scanning / CodeQL

Cyclic import Note

Import of module
deepmd.pt.model.model.polar_model
begins an import cycle.

# according to the fitting network to decide the type of the model
if cls is DPModel:
# map fitting to model
if isinstance(fitting, EnergyFittingNet):
cls = EnergyModel
elif isinstance(fitting, DipoleFittingNet):
cls = DipoleModel

Check warning on line 42 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L42

Added line #L42 was not covered by tests
elif isinstance(fitting, PolarFittingNet):
cls = PolarModel

Check warning on line 44 in deepmd/pt/model/model/dp_model.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/model/dp_model.py#L44

Added line #L44 was not covered by tests
# else: unknown fitting type, fall back to DPModel
return super().__new__(cls)
6 changes: 5 additions & 1 deletion deepmd/pt/model/model/dp_zbl_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
from deepmd.pt.model.atomic_model import (
DPZBLLinearAtomicModel,
)
from deepmd.pt.model.model.model import (
BaseModel,
)

from .make_model import (
make_model,
Expand All @@ -17,7 +20,8 @@
DPZBLModel_ = make_model(DPZBLLinearAtomicModel)


class DPZBLModel(DPZBLModel_):
@BaseModel.register("zbl")
class DPZBLModel(DPZBLModel_, BaseModel):

Check warning

Code scanning / CodeQL

Conflicting attributes in base classes Warning

Base classes have conflicting values for attribute 'compute_or_load_stat':
Function compute_or_load_stat
and
Function compute_or_load_stat
.
Base classes have conflicting values for attribute 'compute_or_load_stat': Function compute_or_load_stat and
Function compute_or_load_stat
.
model_type = "ener"

def __init__(
Expand Down
Loading
Loading