Skip to content

Commit

Permalink
fix xpu test
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Nov 24, 2023
1 parent 6f5789b commit 76880db
Showing 1 changed file with 26 additions and 33 deletions.
59 changes: 26 additions & 33 deletions test/xpu/test_set_value_op_xpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def body(i, x):
i = i + 1
return i, x

i = paddle.zeros(shape=(1,), dtype='int32')
i = paddle.zeros(shape=[], dtype='int32')
i, x = paddle.static.nn.while_loop(cond, body, [i, x])

def _call_setitem_static_api(self, x):
Expand All @@ -243,7 +243,7 @@ def body(i, x):
i = i + 1
return i, x

i = paddle.zeros(shape=(1,), dtype='int32')
i = paddle.zeros(shape=[], dtype='int32')
i, x = paddle.static.nn.while_loop(cond, body, [i, x])
return x

Expand Down Expand Up @@ -504,11 +504,11 @@ def set_dtype(self):
self.dtype = self.in_type

def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
zero = paddle.full([], 0, dtype="int32")
x[zero] = self.value

def _call_setitem_static_api(self, x):
zero = paddle.full([1], 0, dtype="int32")
zero = paddle.full([], 0, dtype="int32")
x = paddle.static.setitem(x, zero, self.value)
return x

Expand All @@ -517,13 +517,13 @@ def _get_answer(self):

class XPUTestSetValueItemTensor2(XPUTestSetValueItemTensor):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x[zero:two] = self.value

def _call_setitem_static_api(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x = paddle.static.setitem(x, slice(zero, two), self.value)
return x

Expand All @@ -532,13 +532,13 @@ def _get_answer(self):

class XPUTestSetValueItemTensor3(XPUTestSetValueItemTensor):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x[zero:-1, 0:two] = self.value

def _call_setitem_static_api(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x = paddle.static.setitem(
x, (slice(zero, -1), slice(0, two)), self.value
)
Expand All @@ -549,13 +549,13 @@ def _get_answer(self):

class XPUTestSetValueItemTensor4(XPUTestSetValueItemTensor):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x[0:-1, zero:2, 0:6:two] = self.value

def _call_setitem_static_api(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x = paddle.static.setitem(
x, (slice(0, -1), slice(zero, 2), slice(0, 6, two)), self.value
)
Expand All @@ -566,13 +566,13 @@ def _get_answer(self):

class XPUTestSetValueItemTensor5(XPUTestSetValueItemTensor):
def _call_setitem(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x[zero:, 1:2:two, :] = self.value

def _call_setitem_static_api(self, x):
zero = paddle.full([1], 0, dtype="int32")
two = paddle.full([1], 2, dtype="int64")
zero = paddle.full([], 0, dtype="int32")
two = paddle.full([], 2, dtype="int64")
x = paddle.static.setitem(
x,
(slice(zero, None), slice(1, 2, two), slice(None, None, None)),
Expand All @@ -588,13 +588,13 @@ def set_shape(self):
self.shape = [3, 4, 5]

def _call_setitem(self, x):
minus1 = paddle.full([1], -1, dtype="int32")
zero = paddle.full([1], 0, dtype="int32")
minus1 = paddle.full([], -1, dtype="int32")
zero = paddle.full([], 0, dtype="int32")
x[2:zero:minus1, 0:2, 10:-6:minus1] = self.value

def _call_setitem_static_api(self, x):
minus1 = paddle.full([1], -1, dtype="int32")
zero = paddle.full([1], 0, dtype="int32")
minus1 = paddle.full([], -1, dtype="int32")
zero = paddle.full([], 0, dtype="int32")
x = paddle.static.setitem(
x,
(slice(2, zero, minus1), slice(0, 2), slice(10, -6, minus1)),
Expand Down Expand Up @@ -1090,13 +1090,6 @@ def _ellipsis_error(self):
x[::one] = self.value

def _bool_list_error(self):
with self.assertRaises(TypeError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
if paddle.in_dynamic_mode():
x[[True, False, 0]] = 0
else:
x = paddle.static.setitem(x, [True, False, 0], 0)

with self.assertRaises(IndexError):
x = paddle.ones(shape=self.shape, dtype=self.dtype)
if paddle.in_dynamic_mode():
Expand Down Expand Up @@ -1642,13 +1635,13 @@ def test_inplace(self):
paddle.seed(100)
a = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b = a[:]
b = a[:] * 1
c = b
b[paddle.zeros([], dtype='int32')] = 1.0

self.assertTrue(id(b) == id(c))
np.testing.assert_array_equal(b.numpy(), c.numpy())
self.assertEqual(b.inplace_version, 0)
self.assertEqual(b.inplace_version, 1)

paddle.enable_static()

Expand Down

0 comments on commit 76880db

Please sign in to comment.