Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed May 21, 2024
1 parent 15db5df commit 126295b
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 0 deletions.
8 changes: 8 additions & 0 deletions deepmd/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -787,3 +787,11 @@ def __getitem__(self, key: str):
if key not in self.dict:
raise KeyError(key)
return self.dict[key]

def __eq__(self, __value: object) -> bool:
if not isinstance(__value, DataRequirementItem):
return False
return self.dict == __value.dict

def __repr__(self) -> str:
return f"DataRequirementItem({self.dict})"

Check warning on line 797 in deepmd/utils/data.py

View check run for this annotation

Codecov / codecov/patch

deepmd/utils/data.py#L797

Added line #L797 was not covered by tests
59 changes: 59 additions & 0 deletions source/tests/tf/test_loss_gf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
from deepmd.tf.loss import (
EnerStdLoss,
)
from deepmd.utils.data import (
DataRequirementItem,
)


class TestLossGf(tf.test.TestCase):
Expand All @@ -26,6 +29,62 @@ def setUp(self):
numb_generalized_coord=2,
)

def test_label_requirements(self):
"""Test label_requirements are expected."""
self.assertCountEqual(
self.loss.label_requirement,
[
DataRequirementItem(
"energy",
1,
atomic=False,
must=False,
high_prec=True,
repeat=1,
),
DataRequirementItem(
"force",
3,
atomic=True,
must=False,
high_prec=False,
repeat=1,
),
DataRequirementItem(
"virial",
9,
atomic=False,
must=False,
high_prec=False,
repeat=1,
),
DataRequirementItem(
"atom_pref",
1,
atomic=True,
must=False,
high_prec=False,
repeat=3,
),
DataRequirementItem(
"atom_ener",
1,
atomic=True,
must=False,
high_prec=False,
repeat=1,
),
DataRequirementItem(
"drdq",
2 * 3,
atomic=True,
must=False,
high_prec=False,
repeat=1,
),
],
)

def test_build_loss(self):
natoms = tf.constant([6, 6])
model_dict = {
Expand Down
3 changes: 3 additions & 0 deletions source/tests/tf/test_model_se_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,9 @@ def test_model(self):
np.testing.assert_almost_equal(f, reff, places)
np.testing.assert_almost_equal(v, refv, places)

# test input requirement for the model
self.assertCountEqual(model.input_requirement, [])

def test_model_atom_ener_type_embedding(self):
"""Test atom ener with type embedding."""
jfile = "water_se_a.json"
Expand Down
9 changes: 9 additions & 0 deletions source/tests/tf/test_model_se_a_aparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from deepmd.tf.model import (
EnerModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .common import (
DataSystem,
Expand Down Expand Up @@ -165,3 +168,9 @@ def test_model(self):
np.testing.assert_almost_equal(e, refe, places)
np.testing.assert_almost_equal(f, reff, places)
np.testing.assert_almost_equal(v, refv, places)

# test input requirement for the model
self.assertCountEqual(
model.input_requirement,
[DataRequirementItem("aparam", 2, atomic=True, must=True, high_prec=False)],
)
13 changes: 13 additions & 0 deletions source/tests/tf/test_model_se_a_fparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
from deepmd.tf.model import (
EnerModel,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .common import (
DataSystem,
Expand Down Expand Up @@ -166,3 +169,13 @@ def test_model(self):
np.testing.assert_almost_equal(e, refe, places)
np.testing.assert_almost_equal(f, reff, places)
np.testing.assert_almost_equal(v, refv, places)

# test input requirement for the model
self.assertCountEqual(
model.input_requirement,
[
DataRequirementItem(
"fparam", 2, atomic=False, must=True, high_prec=False
)
],
)
9 changes: 9 additions & 0 deletions source/tests/tf/test_pairwise_dprc.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
from deepmd.tf.utils.sess import (
run_sess,
)
from deepmd.utils.data import (
DataRequirementItem,
)

from .common import (
run_dp,
Expand Down Expand Up @@ -523,6 +526,12 @@ def test_model_ener(self):
self.assertAllClose(e[0], 0.189075, 1e-6)
self.assertAllClose(f[0, 0], 0.060047, 1e-6)

# test input requirement for the model
self.assertCountEqual(
model.input_requirement,
[DataRequirementItem("aparam", 1, atomic=True, must=True, high_prec=False)],
)

def test_nloc(self):
jfile = tests_path / "pairwise_dprc.json"
jdata = j_loader(jfile)
Expand Down

0 comments on commit 126295b

Please sign in to comment.