Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend][PaddlePaddle]split test_forward_math_api function #11537

Merged
merged 1 commit into from
Jun 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 193 additions & 44 deletions tests/python/frontend/paddlepaddle/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -1358,7 +1358,10 @@ def slice4(inputs):


@tvm.testing.uses_gpu
def test_forward_math_api():
def run_math_api(func):
api_name = func.__name__.split("_")[-1]
print("func_name:", api_name)

class MathAPI(nn.Layer):
def __init__(self, api_name):
super(MathAPI, self).__init__()
Expand All @@ -1371,52 +1374,198 @@ def __init__(self, api_name):
def forward(self, inputs):
return self.func(inputs)

api_list = [
"abs",
"acos",
"asin",
"atan",
"ceil",
"cos",
"cosh",
"elu",
"erf",
"exp",
"floor",
"hardshrink",
"hardtanh",
"log_sigmoid",
"log_softmax",
"log",
"log2",
"log10",
"log1p",
"reciprocal",
"relu",
"relu6",
"round",
"rsqrt",
"selu",
"sigmoid",
"sign",
"sin",
"sinh",
"softplus",
"softsign",
"sqrt",
"square",
"swish",
"tan",
"tanh",
]
input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]]
for input_shape in input_shapes:
input_data = paddle.rand(input_shape, dtype="float32")
for api_name in api_list:
if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
# avoid illegal input, all elements should be positive
input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
verify_model(MathAPI(api_name), input_data=input_data)
if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]:
# avoid illegal input, all elements should be positive
input_data = paddle.uniform(input_shape, min=0.01, max=0.99)
verify_model(MathAPI(api_name), input_data=input_data)


@run_math_api
def test_forward_abs():
pass


@run_math_api
def test_forward_acos():
pass


@run_math_api
def test_forward_abs():
pass


@run_math_api
def test_forward_atan():
pass


@run_math_api
def test_forward_ceil():
pass


@run_math_api
def test_forward_cos():
pass


@run_math_api
def test_forward_cosh():
pass


@run_math_api
def test_forward_elu():
pass


@run_math_api
def test_forward_erf():
pass


@run_math_api
def test_forward_exp():
pass


@run_math_api
def test_forward_floor():
pass


@run_math_api
def test_forward_hardshrink():
pass


@run_math_api
def test_forward_hardtanh():
pass


@run_math_api
def test_forward_log_sigmoid():
pass


@run_math_api
def test_forward_log_softmax():
pass


@run_math_api
def test_forward_log():
pass


@run_math_api
def test_forward_log2():
pass


@run_math_api
def test_forward_log10():
pass


@run_math_api
def test_forward_log1p():
pass


@run_math_api
def test_forward_reciprocal():
pass


@run_math_api
def test_forward_relu():
pass


@run_math_api
def test_forward_round():
pass


@run_math_api
def test_forward_rsqrt():
pass


@run_math_api
def test_forward_selu():
pass


@run_math_api
def test_forward_sigmoid():
pass


@run_math_api
def test_forward_sign():
pass


@run_math_api
def test_forward_sin():
pass


@run_math_api
def test_forward_softplus():
pass


@run_math_api
def test_forward_sqrt():
pass


@run_math_api
def test_forward_square():
pass


@run_math_api
def test_forward_sin():
pass


@run_math_api
def test_forward_softsign():
pass


@run_math_api
def test_forward_sqrt():
pass


@run_math_api
def test_forward_square():
pass


@run_math_api
def test_forward_swish():
pass


@run_math_api
def test_forward_tan():
pass


@run_math_api
def test_forward_tanh():
pass


@tvm.testing.uses_gpu
Expand Down