Skip to content

Commit

Permalink
support zero_dim for some prim ops (#54892)
Browse files Browse the repository at this point in the history
  • Loading branch information
Charles-hit authored Jun 27, 2023
1 parent abc1c3d commit f8d0214
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 5 deletions.
2 changes: 2 additions & 0 deletions python/paddle/incubate/autograd/composite_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,8 @@ def squeeze2_composite(x, axis):
axis can only be list, not int
"""
rank = len(x.shape)
if rank == 0:
return [assign(x), None]
if len(axis) == 0:
dims = set(range(rank))
else:
Expand Down
7 changes: 6 additions & 1 deletion test/legacy_test/test_activation_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1385,6 +1385,11 @@ def init_dtype(self):
self.dtype = np.float32


class TestSqrtComp_ZeroDim(TestSqrtComp):
def init_shape(self):
self.shape = []


class TestRsqrt(TestActivation):
def setUp(self):
self.op_type = "rsqrt"
Expand Down Expand Up @@ -2029,7 +2034,7 @@ def init_shape(self):
self.shape = []

def if_enable_cinn(self):
self.enable_cinn = False
pass


class TestLeakyReluAPI(unittest.TestCase):
Expand Down
28 changes: 24 additions & 4 deletions test/legacy_test/test_reduce_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -584,10 +584,7 @@ def setUp(self):
self.public_python_api = raw_reduce_prod
self.op_type = "reduce_prod"
self.prim_op_type = "prim"
self.inputs = {'X': np.random.random([]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}

self.init_inputs_and_outputs()
# 0-D tensor doesn't support in cinn
self.enable_cinn = False

Expand All @@ -603,6 +600,29 @@ def test_check_grad(self):
self.check_grad(['X'], 'Out', check_prim=True)


class TestProdOp_ZeroDim1(TestProdOp):
def setUp(self):
self.python_api = paddle.prod
self.public_python_api = paddle.prod
self.op_type = "reduce_prod"
self.prim_op_type = "prim"
self.init_inputs_and_outputs()
# 0-D tensor doesn't support in cinn
self.enable_cinn = False

def init_inputs_and_outputs(self):
self.inputs = {'X': np.random.random([100]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}


class TestProdOp_ZeroDim2(TestProdOp_ZeroDim1):
def init_inputs_and_outputs(self):
self.inputs = {'X': np.random.random([5, 6, 10]).astype("float64")}
self.outputs = {'Out': self.inputs['X'].prod()}
self.attrs = {'dim': [], 'reduce_all': True}


class TestProd6DOp(OpTest):
def setUp(self):
self.op_type = "reduce_prod"
Expand Down
14 changes: 14 additions & 0 deletions test/legacy_test/test_squeeze2_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,20 @@ def init_dtype(self):
self.dtype = np.uint16


class TestSqueezeOp_ZeroDim1(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = ()
self.axes = (0,)
self.new_shape = ()


class TestSqueezeOp_ZeroDim2(TestSqueezeOp):
def init_test_case(self):
self.ori_shape = (1, 1, 1)
self.axes = (0, 1, 2)
self.new_shape = ()


# Correct: No axes input.
class TestSqueezeOp2(TestSqueezeOp):
def setUp(self):
Expand Down

0 comments on commit f8d0214

Please sign in to comment.