Skip to content

Commit

Permalink
fix input in test_network.py
Browse files Browse the repository at this point in the history
Signed-off-by: Jinzhe Zeng <[email protected]>
  • Loading branch information
njzjz committed Nov 13, 2024
1 parent e43134d commit 36b087d
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions source/tests/common/dpmodel/test_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

import numpy as np

from deepmd.dpmodel.common import (
get_xp_precision,
)
from deepmd.dpmodel.utils import (
EmbeddingNet,
FittingNet,
Expand Down Expand Up @@ -46,7 +49,9 @@ def test_serialize_deserize(self):
inp_shap = [ni]
if ashp is not None:
inp_shap = ashp + inp_shap
inp = np.arange(np.prod(inp_shap)).reshape(inp_shap)
inp = np.arange(
np.prod(inp_shap), dtype=get_xp_precision(np, prec)
).reshape(inp_shap)
np.testing.assert_allclose(nl0.call(inp), nl1.call(inp))

def test_shape_error(self):
Expand Down Expand Up @@ -168,7 +173,7 @@ def test_embedding_net(self):
resnet_dt=idt,
)
en1 = EmbeddingNet.deserialize(en0.serialize())
inp = np.ones([ni])
inp = np.ones([ni], dtype=get_xp_precision(np, prec))
np.testing.assert_allclose(en0.call(inp), en1.call(inp))


Expand All @@ -191,7 +196,7 @@ def test_fitting_net(self):
bias_out=bo,
)
en1 = FittingNet.deserialize(en0.serialize())
inp = np.ones([ni])
inp = np.ones([ni], dtype=get_xp_precision(np, prec))
en0.call(inp)
en1.call(inp)
np.testing.assert_allclose(en0.call(inp), en1.call(inp))
Expand Down

0 comments on commit 36b087d

Please sign in to comment.