Skip to content

Commit

Permalink
Add Op UnitTest for expand_dims (PaddlePaddle#1384)
Browse files Browse the repository at this point in the history
* fix

* fix

* fix

* fix

* fix

* fix
  • Loading branch information
enkilee authored and jiahy0825 committed May 25, 2023
1 parent 3a1af2e commit dc7841f
Showing 1 changed file with 98 additions and 44 deletions.
142 changes: 98 additions & 44 deletions python/tests/ops/test_expand_dims.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np
from op_test import OpTest, OpTestTool
from op_test_helper import TestCaseHelper
import paddle
import cinn
from cinn.frontend import *
from cinn.common import *

Expand All @@ -27,20 +25,16 @@
"x86 test will be skipped due to timeout.")
class TestExpandDimsOp(OpTest):
def setUp(self):
print(f"\nRunning {self.__class__.__name__}: {self.case}")
self.init_case()

def init_case(self):
self.inputs = {
"x": np.random.random([
32,
64,
]).astype("float32")
}
self.axes = [0]
self.x_np = self.random(
shape=self.case["x_shape"], dtype=self.case["x_dtype"])

def build_paddle_program(self, target):
x = paddle.to_tensor(self.inputs["x"], stop_gradient=True)
out = paddle.unsqueeze(x, self.axes)
x = paddle.to_tensor(self.x_np, stop_gradient=True)
out = paddle.unsqueeze(x, self.case["axes_shape"])

self.paddle_outputs = [out]

Expand All @@ -49,43 +43,103 @@ def build_paddle_program(self, target):
def build_cinn_program(self, target):
builder = NetBuilder("expand_dims")
x = builder.create_input(
self.nptype2cinntype(self.inputs["x"].dtype),
self.inputs["x"].shape, "x")
out = builder.expand_dims(x, self.axes)
self.nptype2cinntype(self.case["x_dtype"]), self.case["x_shape"],
"x")
out = builder.expand_dims(x, self.case["axes_shape"])

prog = builder.build()
res = self.get_cinn_output(prog, target, [x], [self.inputs["x"]],
[out])
res = self.get_cinn_output(prog, target, [x], [self.x_np], [out])

self.cinn_outputs = res
self.cinn_outputs = [res[0]]

def test_check_results(self):
self.check_outputs_and_grads()


class TestExpandDimsCase1(TestExpandDimsOp):
def init_case(self):
self.inputs = {"x": np.random.random([2, 3, 4]).astype("float32")}
self.axes = [0, 2, 4]


class TestExpandDimsCase2(TestExpandDimsOp):
def init_case(self):
self.inputs = {"x": np.random.random([2, 3, 4]).astype("float32")}
self.axes = [3, 4, 5]


class TestExpandDimsCase3(TestExpandDimsOp):
def init_case(self):
self.inputs = {"x": np.random.random([2, 3, 4]).astype("float16")}
self.axes = [3, 4, 5]


class TestExpandDimsCase4(TestExpandDimsOp):
def init_case(self):
self.inputs = {"x": np.random.random([2, 3, 4]).astype("int64")}
self.axes = [3, 4, 5]
max_relative_error = self.case[
"max_relative_error"] if "max_relative_error" in self.case else 1e-5
self.check_outputs_and_grads(max_relative_error=max_relative_error)


class TestExpandDimsAll(TestCaseHelper):
def init_attrs(self):
self.class_name = "TestExpandDimsOpCase"
self.cls = TestExpandDimsOp
self.inputs = [
{
"x_shape": [1],
"axes_shape": [0],
},
{
"x_shape": [1024],
"axes_shape": [0, 1],
},
{
"x_shape": [32, 64],
"axes_shape": [0, 2],
},
{
"x_shape": [32, 64],
"axes_shape": [0, 1, 2],
},
{
"x_shape": [32, 64, 128],
"axes_shape": [0, 1, 2],
},
{
"x_shape": [32, 64, 128],
"axes_shape": [1, 2, 3],
},
{
"x_shape": [128, 64, 32, 16],
"axes_shape": [0, 1],
},
{
"x_shape": [128, 64, 32, 16],
"axes_shape": [3, 4],
},
{
"x_shape": [16, 8, 4, 2, 1],
"axes_shape": [2],
},
{
"x_shape": [16, 8, 4, 2, 1],
"axes_shape": [5],
},
]
self.dtypes = [
#{
# "x_dtype": "bool",
# "axes_dtype": "int32",
#},
#{
# "x_dtype": "int8",
# "axes_dtype": "int32",
#},
#{
# "x_dtype": "int16",
# "axes_dtype": "int32",
#},
#{
# "x_dtype": "int32",
# "axes_dtype": "int32",
#},
#{
# "x_dtype": "int64",
# "axes_dtype": "int32",
#},
#{
# "x_dtype": "float16",
# "max_relative_error": 1e-3,
# "axes_dtype": "int32",
#},
{
"x_dtype": "float32",
},
#{
# "x_dtype": "float64",
# "axes_dtype": "int32",
#},
]
self.attrs = []


if __name__ == "__main__":
unittest.main()
TestExpandDimsAll().run()

0 comments on commit dc7841f

Please sign in to comment.