Skip to content

Commit

Permalink
test the case of permuted index in different frames. (#3295)
Browse files Browse the repository at this point in the history
Co-authored-by: Han Wang <[email protected]>
  • Loading branch information
wanghan-iapcm and Han Wang authored Feb 19, 2024
1 parent 63bec22 commit 235ff24
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 6 deletions.
20 changes: 18 additions & 2 deletions source/tests/common/dpmodel/case_single_frame_with_nlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 1, 2
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
Expand All @@ -16,7 +16,7 @@ def setUp(self):
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall * 3])
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
Expand All @@ -30,3 +30,19 @@ def setUp(self):
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)
3 changes: 3 additions & 0 deletions source/tests/common/dpmodel/test_exclusion_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def test_build_type_exclude_mask(self):
[1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 0, 1],
[0, 0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 0, 1],
]
).reshape(self.nf, self.nloc, sum(self.sel))
des = PairExcludeMask(self.nt, exclude_types=exclude_types)
Expand Down
9 changes: 7 additions & 2 deletions source/tests/common/dpmodel/test_fitting_invar_fitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,13 @@ def test_mask(self):
# atom index 2 is of type 1 that is excluded
zero_idx = 2
np.testing.assert_allclose(
ret0["energy"][:, zero_idx, :],
np.zeros_like(ret0["energy"][:, zero_idx, :]),
ret0["energy"][0, zero_idx, :],
np.zeros_like(ret0["energy"][0, zero_idx, :]),
)
zero_idx = 0
np.testing.assert_allclose(
ret0["energy"][1, zero_idx, :],
np.zeros_like(ret0["energy"][1, zero_idx, :]),
)

def test_self_exception(
Expand Down
21 changes: 19 additions & 2 deletions source/tests/pt/model/test_env_mat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def setUp(self):
# nloc == 3, nall == 4
self.nloc = 3
self.nall = 4
self.nf, self.nt = 1, 2
self.nf, self.nt = 2, 2
self.coord_ext = np.array(
[
[0, 0, 0],
Expand All @@ -31,7 +31,7 @@ def setUp(self):
[0, -2, 0],
],
dtype=np.float64,
).reshape([1, self.nall * 3])
).reshape([1, self.nall, 3])
self.atype_ext = np.array([0, 0, 1, 0], dtype=int).reshape([1, self.nall])
# sel = [5, 2]
self.sel = [5, 2]
Expand All @@ -45,6 +45,22 @@ def setUp(self):
).reshape([1, self.nloc, sum(self.sel)])
self.rcut = 0.4
self.rcut_smth = 2.2
# permutations
self.perm = np.array([2, 0, 1, 3], dtype=np.int32)
inv_perm = np.array([1, 2, 0, 3], dtype=np.int32)
# permute the coord and atype
self.coord_ext = np.concatenate(
[self.coord_ext, self.coord_ext[:, self.perm, :]], axis=0
).reshape(self.nf, self.nall * 3)
self.atype_ext = np.concatenate(
[self.atype_ext, self.atype_ext[:, self.perm]], axis=0
)
# permute the nlist
nlist1 = self.nlist[:, self.perm[: self.nloc], :]
mask = nlist1 == -1
nlist1 = inv_perm[nlist1]
nlist1 = np.where(mask, -1, nlist1)
self.nlist = np.concatenate([self.nlist, nlist1], axis=0)


class TestCaseSingleFrameWithoutNlist:
Expand Down Expand Up @@ -94,3 +110,4 @@ def test_consistency(
)
np.testing.assert_allclose(mm0, mm1)
np.testing.assert_allclose(ww0, ww1)
np.testing.assert_allclose(mm0[0][self.perm[: self.nloc]], mm0[1])
3 changes: 3 additions & 0 deletions source/tests/pt/model/test_exclusion_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def test_build_type_exclude_mask(self):
[1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 0, 1],
[0, 0, 1, 1, 1, 1, 1],
[0, 0, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 0, 1],
[1, 1, 1, 1, 1, 0, 1],
]
).reshape(self.nf, self.nloc, sum(self.sel))
des = PairExcludeMask(self.nt, exclude_types=exclude_types).to(env.DEVICE)
Expand Down
7 changes: 7 additions & 0 deletions source/tests/pt/model/test_se_e2_a.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ def test_consistency(
atol=atol,
err_msg=err_msg,
)
np.testing.assert_allclose(
rd0.detach().cpu().numpy()[0][self.perm[: self.nloc]],
rd0.detach().cpu().numpy()[1],
rtol=rtol,
atol=atol,
err_msg=err_msg,
)
# dp impl
dd2 = DPDescrptSeA.deserialize(dd0.serialize())
rd2, gr2, _, _, sw2 = dd2.call(
Expand Down

0 comments on commit 235ff24

Please sign in to comment.