From 126295b49a29368ae2ff0c836c7b7f2cd665e52a Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Tue, 21 May 2024 18:53:05 -0400 Subject: [PATCH] add tests Signed-off-by: Jinzhe Zeng --- deepmd/utils/data.py | 8 +++ source/tests/tf/test_loss_gf.py | 59 +++++++++++++++++++++++ source/tests/tf/test_model_se_a.py | 3 ++ source/tests/tf/test_model_se_a_aparam.py | 9 ++++ source/tests/tf/test_model_se_a_fparam.py | 13 +++++ source/tests/tf/test_pairwise_dprc.py | 9 ++++ 6 files changed, 101 insertions(+) diff --git a/deepmd/utils/data.py b/deepmd/utils/data.py index cd0e414b5f..91782d898f 100644 --- a/deepmd/utils/data.py +++ b/deepmd/utils/data.py @@ -787,3 +787,11 @@ def __getitem__(self, key: str): if key not in self.dict: raise KeyError(key) return self.dict[key] + + def __eq__(self, __value: object) -> bool: + if not isinstance(__value, DataRequirementItem): + return False + return self.dict == __value.dict + + def __repr__(self) -> str: + return f"DataRequirementItem({self.dict})" diff --git a/source/tests/tf/test_loss_gf.py b/source/tests/tf/test_loss_gf.py index 78e5404e03..116b98b649 100644 --- a/source/tests/tf/test_loss_gf.py +++ b/source/tests/tf/test_loss_gf.py @@ -5,6 +5,9 @@ from deepmd.tf.loss import ( EnerStdLoss, ) +from deepmd.utils.data import ( + DataRequirementItem, +) class TestLossGf(tf.test.TestCase): @@ -26,6 +29,62 @@ def setUp(self): numb_generalized_coord=2, ) + def test_label_requirements(self): + """Test label_requirements are expected.""" + self.assertCountEqual( + self.loss.label_requirement, + [ + DataRequirementItem( + "energy", + 1, + atomic=False, + must=False, + high_prec=True, + repeat=1, + ), + DataRequirementItem( + "force", + 3, + atomic=True, + must=False, + high_prec=False, + repeat=1, + ), + DataRequirementItem( + "virial", + 9, + atomic=False, + must=False, + high_prec=False, + repeat=1, + ), + DataRequirementItem( + "atom_pref", + 1, + atomic=True, + must=False, + high_prec=False, + repeat=3, + ), + DataRequirementItem( + "atom_ener", + 1, + atomic=True, + must=False, + high_prec=False, + repeat=1, + ), + DataRequirementItem( + "drdq", + 2 * 3, + atomic=True, + must=False, + high_prec=False, + repeat=1, + ), + ], + ) + def test_build_loss(self): natoms = tf.constant([6, 6]) model_dict = { diff --git a/source/tests/tf/test_model_se_a.py b/source/tests/tf/test_model_se_a.py index ad2c1b7ced..039ead3a09 100644 --- a/source/tests/tf/test_model_se_a.py +++ b/source/tests/tf/test_model_se_a.py @@ -265,6 +265,9 @@ def test_model(self): np.testing.assert_almost_equal(f, reff, places) np.testing.assert_almost_equal(v, refv, places) + # test input requirement for the model + self.assertCountEqual(model.input_requirement, []) + def test_model_atom_ener_type_embedding(self): """Test atom ener with type embedding.""" jfile = "water_se_a.json" diff --git a/source/tests/tf/test_model_se_a_aparam.py b/source/tests/tf/test_model_se_a_aparam.py index e44e1c8c9f..2485d1e674 100644 --- a/source/tests/tf/test_model_se_a_aparam.py +++ b/source/tests/tf/test_model_se_a_aparam.py @@ -16,6 +16,9 @@ from deepmd.tf.model import ( EnerModel, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .common import ( DataSystem, @@ -165,3 +168,9 @@ def test_model(self): np.testing.assert_almost_equal(e, refe, places) np.testing.assert_almost_equal(f, reff, places) np.testing.assert_almost_equal(v, refv, places) + + # test input requirement for the model + self.assertCountEqual( + model.input_requirement, + [DataRequirementItem("aparam", 2, atomic=True, must=True, high_prec=False)], + ) diff --git a/source/tests/tf/test_model_se_a_fparam.py b/source/tests/tf/test_model_se_a_fparam.py index ce31f94488..efcd3f44c8 100644 --- a/source/tests/tf/test_model_se_a_fparam.py +++ b/source/tests/tf/test_model_se_a_fparam.py @@ -16,6 +16,9 @@ from deepmd.tf.model import ( EnerModel, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .common import ( DataSystem, @@ -166,3 +169,13 @@ def test_model(self): np.testing.assert_almost_equal(e, refe, places) np.testing.assert_almost_equal(f, reff, places) np.testing.assert_almost_equal(v, refv, places) + + # test input requirement for the model + self.assertCountEqual( + model.input_requirement, + [ + DataRequirementItem( + "fparam", 2, atomic=False, must=True, high_prec=False + ) + ], + ) diff --git a/source/tests/tf/test_pairwise_dprc.py b/source/tests/tf/test_pairwise_dprc.py index 38b8d8b775..7a0f28b092 100644 --- a/source/tests/tf/test_pairwise_dprc.py +++ b/source/tests/tf/test_pairwise_dprc.py @@ -34,6 +34,9 @@ from deepmd.tf.utils.sess import ( run_sess, ) +from deepmd.utils.data import ( + DataRequirementItem, +) from .common import ( run_dp, @@ -523,6 +526,12 @@ def test_model_ener(self): self.assertAllClose(e[0], 0.189075, 1e-6) self.assertAllClose(f[0, 0], 0.060047, 1e-6) + # test input requirement for the model + self.assertCountEqual( + model.input_requirement, + [DataRequirementItem("aparam", 1, atomic=True, must=True, high_prec=False)], + ) + def test_nloc(self): jfile = tests_path / "pairwise_dprc.json" jdata = j_loader(jfile)