Skip to content

Commit

Permalink
feat(pt): support DeepEval.eval_descriptor
Browse files Browse the repository at this point in the history
Fix #4112.

Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Oct 14, 2024
1 parent a1f8672 commit fabe092
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 2 deletions.
47 changes: 47 additions & 0 deletions deepmd/pt/infer/deep_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,50 @@ def eval_typeebd(self) -> np.ndarray:
def get_model_def_script(self) -> str:
"""Get model defination script."""
return self.model_def_script

def eval_descriptor(
self,
coords: np.ndarray,
cells: Optional[np.ndarray],
atom_types: np.ndarray,
fparam: Optional[np.ndarray] = None,
aparam: Optional[np.ndarray] = None,
**kwargs: Any,
) -> np.ndarray:
"""Evaluate descriptors by using this DP.
Parameters
----------
coords
The coordinates of atoms.
The array should be of size nframes x natoms x 3
cells
The cell of the region.
If None then non-PBC is assumed, otherwise using PBC.
The array should be of size nframes x 9
atom_types
The atom types
The list should contain natoms ints
fparam
The frame parameter.
The array can be of size :
- nframes x dim_fparam.
- dim_fparam. Then all frames are assumed to be provided with the same fparam.
aparam
The atomic parameter
The array can be of size :
- nframes x natoms x dim_aparam.
- natoms x dim_aparam. Then all frames are assumed to be provided with the same aparam.
- dim_aparam. Then all frames and atoms are provided with the same aparam.
Returns
-------
descriptor
Descriptors.
"""
model = self.dp.model["Default"]
model.set_eval_descriptor_hook(True)
self.eval(coords, cells, atom_types, fparam, aparam, **kwargs)
descriptor = model.eval_descriptor()
model.set_eval_descriptor_hook(False)
return to_numpy_array(descriptor)
15 changes: 15 additions & 0 deletions deepmd/pt/model/atomic_model/dp_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def __init__(
self.sel = self.descriptor.get_sel()
self.fitting_net = fitting
super().init_out_stat()
self.enable_eval_descriptor_hook = False
self.eval_descriptor_list = []

eval_descriptor_list: list[torch.Tensor]

def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.enable_eval_descriptor_hook = enable
self.eval_descriptor_list = []

def eval_descriptor(self) -> torch.Tensor:
"""Evaluate the descriptor."""
return torch.concat(self.eval_descriptor_list)

@torch.jit.export
def fitting_output_def(self) -> FittingOutputDef:
Expand Down Expand Up @@ -192,6 +205,8 @@ def forward_atomic(
comm_dict=comm_dict,
)
assert descriptor is not None
if self.enable_eval_descriptor_hook:
self.eval_descriptor_list.append(descriptor)
# energy, force
fit_ret = self.fitting_net(
descriptor,
Expand Down
12 changes: 12 additions & 0 deletions deepmd/pt/model/model/dp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
Optional,
)

import torch

from deepmd.pt.model.descriptor.base_descriptor import (
BaseDescriptor,
)
Expand Down Expand Up @@ -52,3 +54,13 @@ def get_fitting_net(self):
def get_descriptor(self):
"""Get the descriptor."""
return self.atomic_model.descriptor

@torch.jit.export
def set_eval_descriptor_hook(self, enable: bool) -> None:
"""Set the hook for evaluating descriptor and clear the cache for descriptor list."""
self.atomic_model.set_eval_descriptor_hook(enable)

@torch.jit.export
def eval_descriptor(self) -> torch.Tensor:
"""Evaluate the descriptor."""
return self.atomic_model.eval_descriptor()
2 changes: 0 additions & 2 deletions source/tests/infer/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ def test_1frame_atm(self):

def test_descriptor(self):
_, extension = self.param
if extension == ".pth":
self.skipTest("eval_descriptor not supported for PyTorch models")
for ii, result in enumerate(self.case.results):
if result.descriptor is None:
continue
Expand Down

0 comments on commit fabe092

Please sign in to comment.