From d81443ba37adab034d7ffa9e0523b141cd96a1c4 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 5 Mar 2024 17:41:53 -0500 Subject: [PATCH] convert exclude_types to sel_type This can give the correct result for `dp test`, ```sh cd examples/water_tensor/dipole dp --pt train input_torch.json dp --pt freeze dp test -m frozen_model.pth -s validation_data/global_system/ ``` Signed-off-by: Jinzhe Zeng --- deepmd/dpmodel/fitting/general_fitting.py | 2 +- deepmd/pt/model/task/fitting.py | 10 +++++++++- .../tests/common/dpmodel/test_fitting_invar_fitting.py | 4 ++++ source/tests/pt/model/test_ener_fitting.py | 1 + 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/deepmd/dpmodel/fitting/general_fitting.py b/deepmd/dpmodel/fitting/general_fitting.py index c004814b60..01bf107c63 100644 --- a/deepmd/dpmodel/fitting/general_fitting.py +++ b/deepmd/dpmodel/fitting/general_fitting.py @@ -179,7 +179,7 @@ def get_sel_type(self) -> List[int]: to the result of the model. If returning an empty list, all atom types are selected. """ - return [] + return [ii for ii in range(self.ntypes) if ii not in self.exclude_types] def __setitem__(self, key, value): if key in ["bias_atom_e"]: diff --git a/deepmd/pt/model/task/fitting.py b/deepmd/pt/model/task/fitting.py index bd38fca14a..22fb409cad 100644 --- a/deepmd/pt/model/task/fitting.py +++ b/deepmd/pt/model/task/fitting.py @@ -388,6 +388,9 @@ def get_dim_aparam(self) -> int: """Get the number (dimension) of atomic parameters of this atomic model.""" return self.numb_aparam + # make jit happy + exclude_types: List[int] + def get_sel_type(self) -> List[int]: """Get the selected atom types of this model. @@ -395,7 +398,12 @@ def get_sel_type(self) -> List[int]: to the result of the model. If returning an empty list, all atom types are selected. """ - return [] + # make jit happy + sel_type: List[int] = [] + for ii in range(self.ntypes): + if ii not in self.exclude_types: + sel_type.append(ii) + return sel_type def __setitem__(self, key, value): if key in ["bias_atom_e"]: diff --git a/source/tests/common/dpmodel/test_fitting_invar_fitting.py b/source/tests/common/dpmodel/test_fitting_invar_fitting.py index a31439d406..87eeb9e06b 100644 --- a/source/tests/common/dpmodel/test_fitting_invar_fitting.py +++ b/source/tests/common/dpmodel/test_fitting_invar_fitting.py @@ -64,6 +64,10 @@ def test_self_consistency( ret0 = ifn0(dd[0], atype, fparam=ifp, aparam=iap) ret1 = ifn1(dd[0], atype, fparam=ifp, aparam=iap) np.testing.assert_allclose(ret0["energy"], ret1["energy"]) + sel_set = set(ifn0.get_sel_type()) + exclude_set = set(et) + self.assertEqual(sel_set | exclude_set, set(range(self.nt))) + self.assertEqual(sel_set & exclude_set, set()) def test_mask(self): nf, nloc, nnei = self.nlist.shape diff --git a/source/tests/pt/model/test_ener_fitting.py b/source/tests/pt/model/test_ener_fitting.py index a41b4d6b9f..69bd4b42a3 100644 --- a/source/tests/pt/model/test_ener_fitting.py +++ b/source/tests/pt/model/test_ener_fitting.py @@ -95,6 +95,7 @@ def test_consistency( to_numpy_array(ret0["foo"]), to_numpy_array(ret2["foo"]), ) + self.assertEqual(ft0.get_sel_type(), ft1.get_sel_type()) def test_new_old( self,