Skip to content

Commit

Permalink
split test_forward_math_api function (#11537)
Browse files Browse the repository at this point in the history
  • Loading branch information
heliqi authored Jun 6, 2022
1 parent 283542f commit bf4b8f5
Showing 1 changed file with 193 additions and 44 deletions.
237 changes: 193 additions & 44 deletions tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,10 @@ def slice4(inputs):


@tvm.testing.uses_gpu
def test_forward_math_api():
def run_math_api(func):
api_name = func.__name__.split("_")[-1]
print("func_name:", api_name)

class MathAPI(nn.Layer):
def __init__(self, api_name):
super(MathAPI, self).__init__()
Expand All @@ -1371,52 +1374,198 @@ def __init__(self, api_name):
def forward(self, inputs):
return self.func(inputs)

api_list = [
"abs",
"acos",
"asin",
"atan",
"ceil",
"cos",
"cosh",
"elu",
"erf",
"exp",
"floor",
"hardshrink",
"hardtanh",
"log_sigmoid",
"log_softmax",
"log",
"log2",
"log10",
"log1p",
"reciprocal",
"relu",
"relu6",
"round",
"rsqrt",
"selu",
"sigmoid",
"sign",
"sin",
"sinh",
"softplus",
"softsign",
"sqrt",
"square",
"swish",
"tan",
"tanh",
]
input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
for api_name in api_list:
if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
# avoid illegal input, all elements should be positive
input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
verify_model(MathAPI(api_name), input_data=input_data)
if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
# avoid illegal input, all elements should be positive
input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
verify_model(MathAPI(api_name), input_data=input_data)


@run_math_api
def test_forward_abs():
pass


@run_math_api
def test_forward_acos():
pass


@run_math_api
def test_forward_abs():
pass


@run_math_api
def test_forward_atan():
pass


@run_math_api
def test_forward_ceil():
pass


@run_math_api
def test_forward_cos():
pass


@run_math_api
def test_forward_cosh():
pass


@run_math_api
def test_forward_elu():
pass


@run_math_api
def test_forward_erf():
pass


@run_math_api
def test_forward_exp():
pass


@run_math_api
def test_forward_floor():
pass


@run_math_api
def test_forward_hardshrink():
pass


@run_math_api
def test_forward_hardtanh():
pass


@run_math_api
def test_forward_log_sigmoid():
pass


@run_math_api
def test_forward_log_softmax():
pass


@run_math_api
def test_forward_log():
pass


@run_math_api
def test_forward_log2():
pass


@run_math_api
def test_forward_log10():
pass


@run_math_api
def test_forward_log1p():
pass


@run_math_api
def test_forward_reciprocal():
pass


@run_math_api
def test_forward_relu():
pass


@run_math_api
def test_forward_round():
pass


@run_math_api
def test_forward_rsqrt():
pass


@run_math_api
def test_forward_selu():
pass


@run_math_api
def test_forward_sigmoid():
pass


@run_math_api
def test_forward_sign():
pass


@run_math_api
def test_forward_sin():
pass


@run_math_api
def test_forward_softplus():
pass


@run_math_api
def test_forward_sqrt():
pass


@run_math_api
def test_forward_square():
pass


@run_math_api
def test_forward_sin():
pass


@run_math_api
def test_forward_softsign():
pass


@run_math_api
def test_forward_sqrt():
pass


@run_math_api
def test_forward_square():
pass


@run_math_api
def test_forward_swish():
pass


@run_math_api
def test_forward_tan():
pass


@run_math_api
def test_forward_tanh():
pass


@tvm.testing.uses_gpu
Expand Down

0 comments on commit bf4b8f5

Please sign in to comment.