diff --git a/deepmd/pt/infer/deep_eval.py b/deepmd/pt/infer/deep_eval.py index 538dc65371..0a77a38135 100644 --- a/deepmd/pt/infer/deep_eval.py +++ b/deepmd/pt/infer/deep_eval.py @@ -598,3 +598,58 @@ def eval_typeebd(self) -> np.ndarray: def get_model_def_script(self) -> str: """Get model defination script.""" return self.model_def_script + + def eval_descriptor( + self, + coords: np.ndarray, + cells: Optional[np.ndarray], + atom_types: np.ndarray, + fparam: Optional[np.ndarray] = None, + aparam: Optional[np.ndarray] = None, + **kwargs: Any, + ) -> np.ndarray: + """Evaluate descriptors by using this DP. + + Parameters + ---------- + coords + The coordinates of atoms. + The array should be of size nframes x natoms x 3 + cells + The cell of the region. + If None then non-PBC is assumed, otherwise using PBC. + The array should be of size nframes x 9 + atom_types + The atom types + The list should contain natoms ints + fparam + The frame parameter. + The array can be of size : + - nframes x dim_fparam. + - dim_fparam. Then all frames are assumed to be provided with the same fparam. + aparam + The atomic parameter + The array can be of size : + - nframes x natoms x dim_aparam. + - natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam. + - dim_aparam. Then all frames and atoms are provided with the same aparam. + + Returns + ------- + descriptor + Descriptors. + """ + model = self.dp.model["Default"] + model.set_eval_descriptor_hook(True) + self.eval( + coords, + cells, + atom_types, + atomic=False, + fparam=fparam, + aparam=aparam, + **kwargs, + ) + descriptor = model.eval_descriptor() + model.set_eval_descriptor_hook(False) + return to_numpy_array(descriptor) diff --git a/deepmd/pt/model/atomic_model/dp_atomic_model.py b/deepmd/pt/model/atomic_model/dp_atomic_model.py index 936a1fead3..edb1253234 100644 --- a/deepmd/pt/model/atomic_model/dp_atomic_model.py +++ b/deepmd/pt/model/atomic_model/dp_atomic_model.py @@ -62,6 +62,19 @@ def __init__( self.sel = self.descriptor.get_sel() self.fitting_net = fitting super().init_out_stat() + self.enable_eval_descriptor_hook = False + self.eval_descriptor_list = [] + + eval_descriptor_list: list[torch.Tensor] + + def set_eval_descriptor_hook(self, enable: bool) -> None: + """Set the hook for evaluating descriptor and clear the cache for descriptor list.""" + self.enable_eval_descriptor_hook = enable + self.eval_descriptor_list = [] + + def eval_descriptor(self) -> torch.Tensor: + """Evaluate the descriptor.""" + return torch.concat(self.eval_descriptor_list) @torch.jit.export def fitting_output_def(self) -> FittingOutputDef: @@ -192,6 +205,8 @@ def forward_atomic( comm_dict=comm_dict, ) assert descriptor is not None + if self.enable_eval_descriptor_hook: + self.eval_descriptor_list.append(descriptor) # energy, force fit_ret = self.fitting_net( descriptor, diff --git a/deepmd/pt/model/model/dp_model.py b/deepmd/pt/model/model/dp_model.py index 8659526c49..bd278ed787 100644 --- a/deepmd/pt/model/model/dp_model.py +++ b/deepmd/pt/model/model/dp_model.py @@ -3,6 +3,8 @@ Optional, ) +import torch + from deepmd.pt.model.descriptor.base_descriptor import ( BaseDescriptor, ) @@ -52,3 +54,13 @@ def get_fitting_net(self): def get_descriptor(self): """Get the descriptor.""" return self.atomic_model.descriptor + + @torch.jit.export + def set_eval_descriptor_hook(self, enable: bool) -> None: + """Set the hook for evaluating descriptor and clear the cache for descriptor list.""" + self.atomic_model.set_eval_descriptor_hook(enable) + + @torch.jit.export + def eval_descriptor(self) -> torch.Tensor: + """Evaluate the descriptor.""" + return self.atomic_model.eval_descriptor() diff --git a/source/tests/infer/test_models.py b/source/tests/infer/test_models.py index 6b62e994aa..2b0f292046 100644 --- a/source/tests/infer/test_models.py +++ b/source/tests/infer/test_models.py @@ -153,8 +153,6 @@ def test_1frame_atm(self): def test_descriptor(self): _, extension = self.param - if extension == ".pth": - self.skipTest("eval_descriptor not supported for PyTorch models") for ii, result in enumerate(self.case.results): if result.descriptor is None: continue