diff --git a/openmmml/mlpotential.py b/openmmml/mlpotential.py index 04e1292..5155b24 100644 --- a/openmmml/mlpotential.py +++ b/openmmml/mlpotential.py @@ -6,7 +6,7 @@ Biological Structures at Stanford, funded under the NIH Roadmap for Medical Research, grant U54 GM072970. See https://simtk.org. -Portions copyright (c) 2021 Stanford University and the Authors. +Portions copyright (c) 2021-2024 Stanford University and the Authors. Authors: Peter Eastman Contributors: @@ -34,6 +34,11 @@ import openmm.unit as unit from copy import deepcopy from typing import Dict, Iterable, Optional +import sys +if sys.version_info < (3, 10): + from importlib_metadata import entry_points +else: + from importlib.metadata import entry_points class MLPotentialImplFactory(object): @@ -41,7 +46,11 @@ class MLPotentialImplFactory(object): If you are defining a new potential function, you need to create subclasses of MLPotentialImpl and MLPotentialImplFactory, and register an instance of - the factory by calling MLPotential.registerImplFactory(). + the factory by calling MLPotential.registerImplFactory(). Alternatively, + if a Python package creates an entry point in the group "openmmml.potentials", + the potential will be registered automatically. The entry point name is the + name of the potential function, and the value should be the name of the + MLPotentialImplFactory subclass. """ def createImpl(self, name: str, **args) -> "MLPotentialImpl": @@ -417,3 +426,9 @@ def registerImplFactory(name: str, factory: MLPotentialImplFactory): a factory object that will be used to create MLPotentialImpl objects """ MLPotential._implFactories[name] = factory + + +# Register any potential functions defined by entry points. + +for potential in entry_points(group='openmmml.potentials'): + MLPotential.registerImplFactory(potential.name, potential.load()()) \ No newline at end of file diff --git a/openmmml/models/anipotential.py b/openmmml/models/anipotential.py index 4fe4a60..9e6f2d4 100644 --- a/openmmml/models/anipotential.py +++ b/openmmml/models/anipotential.py @@ -142,6 +142,3 @@ def forward(self, positions, boxvectors: Optional[torch.Tensor] = None): force.setForceGroup(forceGroup) force.setUsesPeriodicBoundaryConditions(is_periodic) system.addForce(force) - -MLPotential.registerImplFactory('ani1ccx', ANIPotentialImplFactory()) -MLPotential.registerImplFactory('ani2x', ANIPotentialImplFactory()) diff --git a/openmmml/models/macepotential.py b/openmmml/models/macepotential.py index 7301766..0cf1397 100644 --- a/openmmml/models/macepotential.py +++ b/openmmml/models/macepotential.py @@ -405,9 +405,3 @@ def forward( force.setForceGroup(forceGroup) force.setUsesPeriodicBoundaryConditions(isPeriodic) system.addForce(force) - - -MLPotential.registerImplFactory("mace", MACEPotentialImplFactory()) -MLPotential.registerImplFactory("mace-off23-small", MACEPotentialImplFactory()) -MLPotential.registerImplFactory("mace-off23-medium", MACEPotentialImplFactory()) -MLPotential.registerImplFactory("mace-off23-large", MACEPotentialImplFactory()) diff --git a/setup.py b/setup.py index c8ff0a6..f315ca5 100644 --- a/setup.py +++ b/setup.py @@ -38,5 +38,16 @@ classifiers=CLASSIFIERS.splitlines(), packages=find_packages(), zip_safe=False, - install_requires=['numpy', 'openmm >= 7.5']) + install_requires=['numpy', 'openmm >= 7.5'], + entry_points={ + 'openmmml.potentials': [ + 'ani1ccx = openmmml.models.anipotential:ANIPotentialImplFactory', + 'ani2x = openmmml.models.anipotential:ANIPotentialImplFactory', + 'mace = openmmml.models.macepotential:MACEPotentialImplFactory', + 'mace-off23-small = openmmml.models.macepotential:MACEPotentialImplFactory', + 'mace-off23-medium = openmmml.models.macepotential:MACEPotentialImplFactory', + 'mace-off23-large = openmmml.models.macepotential:MACEPotentialImplFactory' + ] + } +)