diff --git a/deepmd/common.py b/deepmd/common.py index 691cc262df..d7e485788b 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -313,3 +313,24 @@ def get_hash(obj) -> str: object to hash """ return sha1(json.dumps(obj).encode("utf-8")).hexdigest() + + +def j_get_type(data: dict, class_name: str = "object") -> str: + """Get the type from the data. + + Parameters + ---------- + data : dict + the data + class_name : str, optional + the name of the class for error message, by default "object" + + Returns + ------- + str + the type + """ + try: + return data["type"] + except KeyError as e: + raise KeyError(f"the type of the {class_name} should be set by `type`") from e diff --git a/deepmd/dpmodel/atomic_model/dp_atomic_model.py b/deepmd/dpmodel/atomic_model/dp_atomic_model.py index 220c072765..99fee6e050 100644 --- a/deepmd/dpmodel/atomic_model/dp_atomic_model.py +++ b/deepmd/dpmodel/atomic_model/dp_atomic_model.py @@ -9,8 +9,8 @@ import numpy as np -from deepmd.dpmodel.descriptor import ( # noqa # TODO: should import all descriptors! - DescrptSeA, +from deepmd.dpmodel.descriptor.base_descriptor import ( + BaseDescriptor, ) from deepmd.dpmodel.fitting import ( # noqa # TODO: should import all fittings! EnergyFittingNet, @@ -135,16 +135,13 @@ def serialize(self) -> dict: "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting.serialize(), - "descriptor_name": self.descriptor.__class__.__name__, "fitting_name": self.fitting.__class__.__name__, } @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - descriptor_obj = getattr( - sys.modules[__name__], data["descriptor_name"] - ).deserialize(data["descriptor"]) + descriptor_obj = BaseDescriptor.deserialize(data["descriptor"]) fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( data["fitting"] ) diff --git a/deepmd/dpmodel/descriptor/base_descriptor.py b/deepmd/dpmodel/descriptor/base_descriptor.py index ca403d7f8e..7429d3f213 100644 --- a/deepmd/dpmodel/descriptor/base_descriptor.py +++ b/deepmd/dpmodel/descriptor/base_descriptor.py @@ -1,4 +1,5 @@ # SPDX-License-Identifier: LGPL-3.0-or-later + import numpy as np from .make_base_descriptor import ( diff --git a/deepmd/dpmodel/descriptor/make_base_descriptor.py b/deepmd/dpmodel/descriptor/make_base_descriptor.py index 7bd553db9e..2cdb5abd52 100644 --- a/deepmd/dpmodel/descriptor/make_base_descriptor.py +++ b/deepmd/dpmodel/descriptor/make_base_descriptor.py @@ -1,17 +1,24 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from abc import ( ABC, - abstractclassmethod, abstractmethod, ) from typing import ( + Callable, List, Optional, + Type, ) +from deepmd.common import ( + j_get_type, +) from deepmd.utils.path import ( DPPath, ) +from deepmd.utils.plugin import ( + Plugin, +) def make_base_descriptor( @@ -33,6 +40,42 @@ def make_base_descriptor( class BD(ABC): """Base descriptor provides the interfaces of descriptor.""" + __plugins = Plugin() + + @staticmethod + def register(key: str) -> Callable: + """Register a descriptor plugin. + + Parameters + ---------- + key : str + the key of a descriptor + + Returns + ------- + Descriptor + the registered descriptor + + Examples + -------- + >>> @Descriptor.register("some_descrpt") + class SomeDescript(Descriptor): + pass + """ + return BD.__plugins.register(key) + + def __new__(cls, *args, **kwargs): + if cls is BD: + cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) + return super().__new__(cls) + + @classmethod + def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]: + if descrpt_type in BD.__plugins.plugins: + return BD.__plugins.plugins[descrpt_type] + else: + raise RuntimeError("Unknown descriptor type: " + descrpt_type) + @abstractmethod def get_rcut(self) -> float: """Returns the cut-off radius.""" @@ -95,10 +138,23 @@ def serialize(self) -> dict: """Serialize the obj to dict.""" pass - @abstractclassmethod - def deserialize(cls): - """Deserialize from a dict.""" - pass + @classmethod + def deserialize(cls, data: dict) -> "BD": + """Deserialize the model. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + BD + The deserialized descriptor + """ + if cls is BD: + return BD.get_class_by_type(data["type"]).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) setattr(BD, fwd_method_name, BD.fwd) delattr(BD, "fwd") diff --git a/deepmd/dpmodel/descriptor/se_e2_a.py b/deepmd/dpmodel/descriptor/se_e2_a.py index 1802b5bab6..be2ed12394 100644 --- a/deepmd/dpmodel/descriptor/se_e2_a.py +++ b/deepmd/dpmodel/descriptor/se_e2_a.py @@ -34,6 +34,7 @@ ) +@BaseDescriptor.register("se_e2_a") class DescrptSeA(NativeOP, BaseDescriptor): r"""DeepPot-SE constructed from all information (both angular and radial) of atomic configurations. The embedding takes the distance between atoms as input. @@ -313,6 +314,8 @@ def call( def serialize(self) -> dict: """Serialize the descriptor to dict.""" return { + "@class": "Descriptor", + "type": "se_e2_a", "rcut": self.rcut, "rcut_smth": self.rcut_smth, "sel": self.sel, @@ -339,6 +342,8 @@ def serialize(self) -> dict: def deserialize(cls, data: dict) -> "DescrptSeA": """Deserialize from dict.""" data = copy.deepcopy(data) + data.pop("@class", None) + data.pop("type", None) variables = data.pop("@variables") embeddings = data.pop("embeddings") env_mat = data.pop("env_mat") diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index dafc9e109e..98bf6c0fde 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -13,8 +13,8 @@ from deepmd.dpmodel import ( FittingOutputDef, ) -from deepmd.pt.model.descriptor.se_a import ( # noqa # TODO: should import all descriptors!!! - DescrptSeA, +from deepmd.pt.model.descriptor.descriptor import ( + Descriptor, ) from deepmd.pt.model.task.ener import ( # noqa # TODO: should import all fittings! EnergyFittingNet, @@ -98,16 +98,13 @@ def serialize(self) -> dict: "type_map": self.type_map, "descriptor": self.descriptor.serialize(), "fitting": self.fitting_net.serialize(), - "descriptor_name": self.descriptor.__class__.__name__, "fitting_name": self.fitting_net.__class__.__name__, } @classmethod def deserialize(cls, data) -> "DPAtomicModel": data = copy.deepcopy(data) - descriptor_obj = getattr( - sys.modules[__name__], data["descriptor_name"] - ).deserialize(data["descriptor"]) + descriptor_obj = Descriptor.deserialize(data["descriptor"]) fitting_obj = getattr(sys.modules[__name__], data["fitting_name"]).deserialize( data["fitting"] ) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 16659e444d..091f2b1e20 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -9,10 +9,14 @@ Dict, List, Optional, + Type, ) import torch +from deepmd.common import ( + j_get_type, +) from deepmd.pt.model.network.network import ( TypeEmbedNet, ) @@ -92,16 +96,37 @@ def data_stat_key(self): def __new__(cls, *args, **kwargs): if cls is Descriptor: - try: - descrpt_type = kwargs["type"] - except KeyError: - raise KeyError("the type of descriptor should be set by `type`") - if descrpt_type in Descriptor.__plugins.plugins: - cls = Descriptor.__plugins.plugins[descrpt_type] - else: - raise RuntimeError("Unknown descriptor type: " + descrpt_type) + cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) return super().__new__(cls) + @classmethod + def get_class_by_type(cls, descrpt_type: str) -> Type["Descriptor"]: + if descrpt_type in Descriptor.__plugins.plugins: + return Descriptor.__plugins.plugins[descrpt_type] + else: + raise RuntimeError("Unknown descriptor type: " + descrpt_type) + + @classmethod + def deserialize(cls, data: dict) -> "Descriptor": + """Deserialize the model. + + There is no suffix in a native DP model, but it is important + for the TF backend. + + Parameters + ---------- + data : dict + The serialized data + + Returns + ------- + Descriptor + The deserialized descriptor + """ + if cls is Descriptor: + return Descriptor.get_class_by_type(data["type"]).deserialize(data) + raise NotImplementedError("Not implemented in class %s" % cls.__name__) + class DescriptorBlock(torch.nn.Module, ABC): """The building block of descriptor. diff --git a/deepmd/pt/model/descriptor/se_a.py b/deepmd/pt/model/descriptor/se_a.py index 4134c963da..0550488ecf 100644 --- a/deepmd/pt/model/descriptor/se_a.py +++ b/deepmd/pt/model/descriptor/se_a.py @@ -193,6 +193,8 @@ def set_stat_mean_and_stddev( def serialize(self) -> dict: obj = self.sea return { + "@class": "Descriptor", + "type": "se_e2_a", "rcut": obj.rcut, "rcut_smth": obj.rcut_smth, "sel": obj.sel, @@ -219,6 +221,8 @@ def serialize(self) -> dict: @classmethod def deserialize(cls, data: dict) -> "DescrptSeA": data = data.copy() + data.pop("@class", None) + data.pop("type", None) variables = data.pop("@variables") embeddings = data.pop("embeddings") env_mat = data.pop("env_mat") diff --git a/deepmd/tf/descriptor/descriptor.py b/deepmd/tf/descriptor/descriptor.py index 768e233245..48329ceb48 100644 --- a/deepmd/tf/descriptor/descriptor.py +++ b/deepmd/tf/descriptor/descriptor.py @@ -13,6 +13,9 @@ import numpy as np +from deepmd.common import ( + j_get_type, +) from deepmd.tf.env import ( GLOBAL_TF_FLOAT_PRECISION, tf, @@ -67,11 +70,7 @@ class SomeDescript(Descriptor): return Descriptor.__plugins.register(key) @classmethod - def get_class_by_input(cls, input: dict): - try: - descrpt_type = input["type"] - except KeyError: - raise KeyError("the type of descriptor should be set by `type`") + def get_class_by_type(cls, descrpt_type: str): if descrpt_type in Descriptor.__plugins.plugins: return Descriptor.__plugins.plugins[descrpt_type] else: @@ -79,7 +78,7 @@ def get_class_by_input(cls, input: dict): def __new__(cls, *args, **kwargs): if cls is Descriptor: - cls = cls.get_class_by_input(kwargs) + cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) return super().__new__(cls) @abstractmethod @@ -507,7 +506,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict): The local data refer to the current class """ # call subprocess - cls = cls.get_class_by_input(local_jdata) + cls = cls.get_class_by_type(j_get_type(local_jdata, cls.__name__)) return cls.update_sel(global_jdata, local_jdata) @classmethod @@ -530,7 +529,9 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": The deserialized descriptor """ if cls is Descriptor: - return Descriptor.get_class_by_input(data).deserialize(data, suffix=suffix) + return Descriptor.get_class_by_type( + j_get_type(data, cls.__name__) + ).deserialize(data, suffix=suffix) raise NotImplementedError("Not implemented in class %s" % cls.__name__) def serialize(self, suffix: str = "") -> dict: diff --git a/deepmd/tf/descriptor/se_a.py b/deepmd/tf/descriptor/se_a.py index 0e0cb664a4..e1b7258c63 100644 --- a/deepmd/tf/descriptor/se_a.py +++ b/deepmd/tf/descriptor/se_a.py @@ -1368,6 +1368,8 @@ def deserialize(cls, data: dict, suffix: str = ""): if cls is not DescrptSeA: raise NotImplementedError("Not implemented in class %s" % cls.__name__) data = data.copy() + data.pop("@class", None) + data.pop("type", None) embedding_net_variables = cls.deserialize_network( data.pop("embeddings"), suffix=suffix ) @@ -1418,6 +1420,8 @@ def serialize(self, suffix: str = "") -> dict: # but instead a part of the input data. Maybe the interface should be refactored... return { + "@class": "Descriptor", + "type": "se_e2_a", "rcut": self.rcut_r, "rcut_smth": self.rcut_r_smth, "sel": self.sel_a, diff --git a/deepmd/tf/fit/fitting.py b/deepmd/tf/fit/fitting.py index 458765f7c1..1c0d3d83ac 100644 --- a/deepmd/tf/fit/fitting.py +++ b/deepmd/tf/fit/fitting.py @@ -9,6 +9,9 @@ Type, ) +from deepmd.common import ( + j_get_type, +) from deepmd.dpmodel.utils.network import ( FittingNet, NetworkCollection, @@ -52,23 +55,19 @@ class SomeFitting(Fitting): return Fitting.__plugins.register(key) @classmethod - def get_class_by_input(cls, data: dict) -> Type["Fitting"]: - """Get the fitting class by the input data. + def get_class_by_type(cls, fitting_type: str) -> Type["Fitting"]: + """Get the fitting class by the input type. Parameters ---------- - data : dict - The input data + fitting_type : str + The input type Returns ------- Fitting The fitting class """ - try: - fitting_type = data["type"] - except KeyError: - raise KeyError("the type of fitting should be set by `type`") if fitting_type in Fitting.__plugins.plugins: cls = Fitting.__plugins.plugins[fitting_type] else: @@ -77,7 +76,7 @@ def get_class_by_input(cls, data: dict) -> Type["Fitting"]: def __new__(cls, *args, **kwargs): if cls is Fitting: - cls = cls.get_class_by_input(kwargs) + cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__)) return super().__new__(cls) @property @@ -148,7 +147,9 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Fitting": The deserialized fitting """ if cls is Fitting: - return Fitting.get_class_by_input(data).deserialize(data, suffix=suffix) + return Fitting.get_class_by_type( + j_get_type(data, cls.__name__) + ).deserialize(data, suffix=suffix) raise NotImplementedError("Not implemented in class %s" % cls.__name__) def serialize(self, suffix: str = "") -> dict: diff --git a/deepmd/tf/model/model.py b/deepmd/tf/model/model.py index 65413a87c1..73339a450f 100644 --- a/deepmd/tf/model/model.py +++ b/deepmd/tf/model/model.py @@ -14,6 +14,9 @@ Union, ) +from deepmd.common import ( + j_get_type, +) from deepmd.tf.descriptor.descriptor import ( Descriptor, ) @@ -92,13 +95,13 @@ class Model(ABC): """ @classmethod - def get_class_by_input(cls, input: dict): - """Get the class by input data. + def get_class_by_type(cls, model_type: str): + """Get the class by input type. Parameters ---------- - input : dict - The input data + model_type : str + The input type """ # infer model type by fitting_type from deepmd.tf.model.frozen import ( @@ -117,7 +120,6 @@ def get_class_by_input(cls, input: dict): PairwiseDPRc, ) - model_type = input.get("type", "standard") if model_type == "standard": return StandardModel elif model_type == "multi": @@ -136,7 +138,7 @@ def get_class_by_input(cls, input: dict): def __new__(cls, *args, **kwargs): if cls is Model: # init model - cls = cls.get_class_by_input(kwargs) + cls = cls.get_class_by_type(kwargs.get("type", "standard")) return cls.__new__(cls, *args, **kwargs) return super().__new__(cls) @@ -575,7 +577,7 @@ def update_sel(cls, global_jdata: dict, local_jdata: dict) -> dict: dict The updated local data """ - cls = cls.get_class_by_input(local_jdata) + cls = cls.get_class_by_type(local_jdata.get("type", "standard")) return cls.update_sel(global_jdata, local_jdata) @classmethod @@ -598,7 +600,9 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Model": The deserialized Model """ if cls is Model: - return Model.get_class_by_input(data).deserialize(data) + return Model.get_class_by_type(data.get("type", "standard")).deserialize( + data + ) raise NotImplementedError("Not implemented in class %s" % cls.__name__) def serialize(self, suffix: str = "") -> dict: @@ -646,7 +650,9 @@ def __new__(cls, *args, **kwargs): if cls is StandardModel: if isinstance(kwargs["fitting_net"], dict): - fitting_type = Fitting.get_class_by_input(kwargs["fitting_net"]) + fitting_type = Fitting.get_class_by_type( + j_get_type(kwargs["fitting_net"], cls.__name__) + ) elif isinstance(kwargs["fitting_net"], Fitting): fitting_type = type(kwargs["fitting_net"]) else: @@ -808,9 +814,6 @@ def deserialize(cls, data: dict, suffix: str = "") -> "Descriptor": """ data = copy.deepcopy(data) - data["descriptor"]["type"] = { - "DescrptSeA": "se_e2_a", - }[data.pop("descriptor_name")] data["fitting"]["type"] = { "EnergyFittingNet": "ener", }[data.pop("fitting_name")] @@ -843,7 +846,6 @@ def serialize(self, suffix: str = "") -> dict: "type_map": self.type_map, "descriptor": self.descrpt.serialize(suffix=suffix), "fitting": self.fitting.serialize(suffix=suffix), - "descriptor_name": self.descrpt.__class__.__name__, "fitting_name": {"EnerFitting": "EnergyFittingNet"}[ self.fitting.__class__.__name__ ],