diff --git a/tests/python/frontend/paddlepaddle/test_forward.py b/tests/python/frontend/paddlepaddle/test_forward.py index 56ec3a4e5469..8b696404e2b0 100644 --- a/tests/python/frontend/paddlepaddle/test_forward.py +++ b/tests/python/frontend/paddlepaddle/test_forward.py @@ -1358,7 +1358,10 @@ def slice4(inputs): @tvm.testing.uses_gpu -def test_forward_math_api(): +def run_math_api(func): + api_name = func.__name__.split("_")[-1] + print("func_name:", api_name) + class MathAPI(nn.Layer): def __init__(self, api_name): super(MathAPI, self).__init__() @@ -1371,52 +1374,198 @@ def __init__(self, api_name): def forward(self, inputs): return self.func(inputs) - api_list = [ - "abs", - "acos", - "asin", - "atan", - "ceil", - "cos", - "cosh", - "elu", - "erf", - "exp", - "floor", - "hardshrink", - "hardtanh", - "log_sigmoid", - "log_softmax", - "log", - "log2", - "log10", - "log1p", - "reciprocal", - "relu", - "relu6", - "round", - "rsqrt", - "selu", - "sigmoid", - "sign", - "sin", - "sinh", - "softplus", - "softsign", - "sqrt", - "square", - "swish", - "tan", - "tanh", - ] input_shapes = [[128], [2, 100], [10, 2, 5], [7, 3, 4, 1]] for input_shape in input_shapes: input_data = paddle.rand(input_shape, dtype="float32") - for api_name in api_list: - if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]: - # avoid illegal input, all elements should be positive - input_data = paddle.uniform(input_shape, min=0.01, max=0.99) - verify_model(MathAPI(api_name), input_data=input_data) + if api_name in ["log", "log2", "log10", "reciprocal", "sqrt", "rsqrt"]: + # avoid illegal input, all elements should be positive + input_data = paddle.uniform(input_shape, min=0.01, max=0.99) + verify_model(MathAPI(api_name), input_data=input_data) + + +@run_math_api +def test_forward_abs(): + pass + + +@run_math_api +def test_forward_acos(): + pass + + +@run_math_api +def test_forward_abs(): + pass + + +@run_math_api +def test_forward_atan(): + pass + + +@run_math_api +def test_forward_ceil(): + pass + + +@run_math_api +def test_forward_cos(): + pass + + +@run_math_api +def test_forward_cosh(): + pass + + +@run_math_api +def test_forward_elu(): + pass + + +@run_math_api +def test_forward_erf(): + pass + + +@run_math_api +def test_forward_exp(): + pass + + +@run_math_api +def test_forward_floor(): + pass + + +@run_math_api +def test_forward_hardshrink(): + pass + + +@run_math_api +def test_forward_hardtanh(): + pass + + +@run_math_api +def test_forward_log_sigmoid(): + pass + + +@run_math_api +def test_forward_log_softmax(): + pass + + +@run_math_api +def test_forward_log(): + pass + + +@run_math_api +def test_forward_log2(): + pass + + +@run_math_api +def test_forward_log10(): + pass + + +@run_math_api +def test_forward_log1p(): + pass + + +@run_math_api +def test_forward_reciprocal(): + pass + + +@run_math_api +def test_forward_relu(): + pass + + +@run_math_api +def test_forward_round(): + pass + + +@run_math_api +def test_forward_rsqrt(): + pass + + +@run_math_api +def test_forward_selu(): + pass + + +@run_math_api +def test_forward_sigmoid(): + pass + + +@run_math_api +def test_forward_sign(): + pass + + +@run_math_api +def test_forward_sin(): + pass + + +@run_math_api +def test_forward_softplus(): + pass + + +@run_math_api +def test_forward_sqrt(): + pass + + +@run_math_api +def test_forward_square(): + pass + + +@run_math_api +def test_forward_sin(): + pass + + +@run_math_api +def test_forward_softsign(): + pass + + +@run_math_api +def test_forward_sqrt(): + pass + + +@run_math_api +def test_forward_square(): + pass + + +@run_math_api +def test_forward_swish(): + pass + + +@run_math_api +def test_forward_tan(): + pass + + +@run_math_api +def test_forward_tanh(): + pass @tvm.testing.uses_gpu