Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

apply PluginVariant and make_plugin_registry to classes #3346

Merged
merged 1 commit into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 5 additions & 28 deletions deepmd/backend/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
)

from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)

if TYPE_CHECKING:
Expand All @@ -33,7 +33,7 @@
)


class Backend(PluginVariant):
class Backend(PluginVariant, make_plugin_registry("backend")):
r"""General backend class.

Examples
Expand All @@ -44,24 +44,6 @@
... pass
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a backend plugin.

Parameters
----------
key : str
the key of a backend

Returns
-------
Callable[[object], object]
the decorator to register backend
"""
return Backend.__plugins.register(key.lower())

@staticmethod
def get_backend(key: str) -> Type["Backend"]:
"""Get the backend by key.
Expand All @@ -76,12 +58,7 @@
Backend
the backend
"""
try:
backend = Backend.__plugins.get_plugin(key.lower())
except KeyError:
raise KeyError(f"Backend {key} is not registered.")
assert isinstance(backend, type)
return backend
return Backend.get_class_by_type(key)

Check warning on line 61 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L61

Added line #L61 was not covered by tests

@staticmethod
def get_backends() -> Dict[str, Type["Backend"]]:
Expand All @@ -92,7 +69,7 @@
list
all the registered backends
"""
return Backend.__plugins.plugins
return Backend.get_plugins()

Check warning on line 72 in deepmd/backend/backend.py

View check run for this annotation

Codecov / codecov/patch

deepmd/backend/backend.py#L72

Added line #L72 was not covered by tests

@staticmethod
def get_backends_by_feature(
Expand All @@ -112,7 +89,7 @@
"""
return {
key: backend
for key, backend in Backend.__plugins.plugins.items()
for key, backend in Backend.get_backends().items()
if backend.features & feature
}

Expand Down
38 changes: 3 additions & 35 deletions deepmd/dpmodel/descriptor/make_base_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
abstractmethod,
)
from typing import (
Callable,
List,
Optional,
Type,
)

from deepmd.common import (
Expand All @@ -17,7 +15,8 @@
DPPath,
)
from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)


Expand All @@ -37,45 +36,14 @@ def make_base_descriptor(

"""

class BD(ABC):
class BD(ABC, PluginVariant, make_plugin_registry("descriptor")):
"""Base descriptor provides the interfaces of descriptor."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
Descriptor
the registered descriptor

Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return BD.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BD:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, descrpt_type: str) -> Type["BD"]:
if descrpt_type in BD.__plugins.plugins:
return BD.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

@abstractmethod
def get_rcut(self) -> float:
"""Returns the cut-off radius."""
Expand Down
38 changes: 3 additions & 35 deletions deepmd/dpmodel/fitting/make_base_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
abstractmethod,
)
from typing import (
Callable,
Dict,
Optional,
Type,
)

from deepmd.common import (
Expand All @@ -17,7 +15,8 @@
FittingOutputDef,
)
from deepmd.utils.plugin import (
Plugin,
PluginVariant,
make_plugin_registry,
)


Expand All @@ -37,45 +36,14 @@ def make_base_fitting(

"""

class BF(ABC):
class BF(ABC, PluginVariant, make_plugin_registry("fitting")):
"""Base fitting provides the interfaces of fitting net."""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable[[object], object]:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
callable[[object], object]
the registered descriptor

Examples
--------
>>> @Fitting.register("some_fitting")
class SomeFitting(Fitting):
pass
"""
return BF.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is BF:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
return super().__new__(cls)

@classmethod
def get_class_by_type(cls, fitting_type: str) -> Type["BF"]:
if fitting_type in BF.__plugins.plugins:
return BF.__plugins.plugins[fitting_type]
else:
raise RuntimeError("Unknown fitting type: " + fitting_type)

@abstractmethod
def output_def(self) -> FittingOutputDef:
"""Returns the output def of the fitting net."""
Expand Down
3 changes: 2 additions & 1 deletion deepmd/dpmodel/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,13 @@
)

from deepmd.utils.plugin import (
PluginVariant,
make_plugin_registry,
)


def make_base_model() -> Type[object]:
class BaseBaseModel(ABC, make_plugin_registry("model")):
class BaseBaseModel(ABC, PluginVariant, make_plugin_registry("model")):
"""Base class for final exported model that will be directly used for inference.

The class defines some abstractmethods that will be directly called by the
Expand Down
37 changes: 5 additions & 32 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
abstractmethod,
)
from typing import (
Callable,
Dict,
List,
Optional,
Expand All @@ -22,60 +21,34 @@
from deepmd.pt.utils.env_mat_stat import (
EnvMatStatSe,
)
from deepmd.pt.utils.plugin import (
Plugin,
)
from deepmd.utils.env_mat_stat import (
StatItem,
)
from deepmd.utils.path import (
DPPath,
)
from deepmd.utils.plugin import (

Check warning on line 30 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L30

Added line #L30 was not covered by tests
make_plugin_registry,
)

log = logging.getLogger(__name__)


class DescriptorBlock(torch.nn.Module, ABC):
class DescriptorBlock(torch.nn.Module, ABC, make_plugin_registry("DescriptorBlock")):

Check warning on line 37 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L37

Added line #L37 was not covered by tests
"""The building block of descriptor.
Given the input descriptor, provide with the atomic coordinates,
atomic types and neighbor list, calculate the new descriptor.
"""

__plugins = Plugin()
local_cluster = False

@staticmethod
def register(key: str) -> Callable:
"""Register a DescriptorBlock plugin.

Parameters
----------
key : str
the key of a DescriptorBlock

Returns
-------
DescriptorBlock
the registered DescriptorBlock

Examples
--------
>>> @DescriptorBlock.register("some_descrpt")
class SomeDescript(DescriptorBlock):
pass
"""
return DescriptorBlock.__plugins.register(key)

def __new__(cls, *args, **kwargs):
if cls is DescriptorBlock:
try:
descrpt_type = kwargs["type"]
except KeyError:
raise KeyError("the type of DescriptorBlock should be set by `type`")
if descrpt_type in DescriptorBlock.__plugins.plugins:
cls = DescriptorBlock.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown DescriptorBlock type: " + descrpt_type)
cls = cls.get_class_by_type(descrpt_type)

Check warning on line 51 in deepmd/pt/model/descriptor/descriptor.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/model/descriptor/descriptor.py#L51

Added line #L51 was not covered by tests
return super().__new__(cls)

@abstractmethod
Expand Down
38 changes: 4 additions & 34 deletions deepmd/tf/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
)
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Expand All @@ -21,12 +20,14 @@
tf,
)
from deepmd.tf.utils import (
Plugin,
PluginVariant,
)
from deepmd.utils.plugin import (
make_plugin_registry,
)


class Descriptor(PluginVariant):
class Descriptor(PluginVariant, make_plugin_registry("descriptor")):
r"""The abstract class for descriptors. All specific descriptors should
be based on this class.

Expand All @@ -45,37 +46,6 @@ class Descriptor(PluginVariant):
that can be called by other classes.
"""

__plugins = Plugin()

@staticmethod
def register(key: str) -> Callable:
"""Register a descriptor plugin.

Parameters
----------
key : str
the key of a descriptor

Returns
-------
Descriptor
the registered descriptor

Examples
--------
>>> @Descriptor.register("some_descrpt")
class SomeDescript(Descriptor):
pass
"""
return Descriptor.__plugins.register(key)

@classmethod
def get_class_by_type(cls, descrpt_type: str):
if descrpt_type in Descriptor.__plugins.plugins:
return Descriptor.__plugins.plugins[descrpt_type]
else:
raise RuntimeError("Unknown descriptor type: " + descrpt_type)

def __new__(cls, *args, **kwargs):
if cls is Descriptor:
cls = cls.get_class_by_type(j_get_type(kwargs, cls.__name__))
Expand Down
Loading