Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
YibinLiu666 committed Dec 1, 2023
1 parent d0c14de commit 28aadd6
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions test/legacy_test/test_take_along_axis_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,22 @@ def test_api_dygraph(self):
paddle.enable_static()

def test_error(self):
paddle.disable_static(self.place[0])
tensorx = paddle.to_tensor([[1, 2, 3], [4, 5, 6]]).astype("float32")
indices = paddle.to_tensor([1]).astype("int32")
# len(arr.shape) != len(indices.shape)
try:
with self.assertRaises(ValueError):
res = paddle.take_along_axis(tensorx, indices, 0, False)
except Exception as error:
self.assertIsInstance(error, ValueError)
indices = paddle.to_tensor([[10]]).astype("int32")
# the element of indices out of range
try:
with self.assertRaises(RuntimeError):
indices = paddle.to_tensor([[100]]).astype("int32")
res = paddle.take_along_axis(tensorx, indices, 0, False)
# the shape of indices doesn't match
with self.assertRaises(RuntimeError):
indices = paddle.to_tensor(
[[1, 0, 0, 0], [1, 0, 0, 0], [1, 0, 0, 0]]
).astype("int32")
res = paddle.take_along_axis(tensorx, indices, 0, False)
except Exception as error:
self.assertIsInstance(error, RuntimeError)


if __name__ == "__main__":
Expand Down

0 comments on commit 28aadd6

Please sign in to comment.