diff --git a/tests/python/relay/strategy/arm_cpu/test_conv2d.py b/tests/python/relay/strategy/arm_cpu/test_conv2d.py index f4fa250ecfe0..8ef9cb09e648 100644 --- a/tests/python/relay/strategy/arm_cpu/test_conv2d.py +++ b/tests/python/relay/strategy/arm_cpu/test_conv2d.py @@ -120,48 +120,10 @@ class TestConv2d_NCHW_Spatial_Pack(Conv2dTests): schedule_name = parameter("conv2d_nchw_spatial_pack.arm_cpu") -in_dtype = tvm.testing.parameter("float16", "float32") -out_dtype = tvm.testing.parameter("float32") - -batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation = tvm.testing.parameters( - # Pad M, N, K - (1, 1, 1, 1, 1, 1, "SAME", 1), - (1, 1, 3, 15, 1, 1, "SAME", 1), - # Pad M, K - (1, 3, 9, 16, 3, 1, "SAME", 1), - # Pad M, N - (1, 2, 9, 15, 4, 1, "SAME", 1), - # Pad K, N - (1, 7, 4, 15, 3, 1, "SAME", 1), - # Pad M - (1, 2, 9, 16, 4, 1, "SAME", 1), - # Pad K - (1, 7, 4, 16, 3, 1, "SAME", 1), - # Pad N - (1, 2, 4, 15, 4, 1, "SAME", 1), - (1, 2, 4, 20, 1, 1, "SAME", 1), - # Large workloads - (1, 128, 32, 128, 3, 1, "SAME", 1), - (4, 64, 16, 64, 5, 2, "SAME", 1), - (1, 128, 32, 128, 3, 1, "VALID", 1), - (4, 64, 16, 64, 5, 2, "VALID", 1), - (1, 64, 16, 64, 3, 2, (0, 0, 1, 1), 1), - (1, 64, 16, 64, 3, 2, (1, 1, 2, 2), 1), - (1, 64, 16, 64, 5, 2, (3, 3, 2, 2), 1), - (1, 64, 16, 64, 3, 2, (0, 1, 2, 3), 1), - (1, 64, 32, 64, 3, 1, "SAME", 2), - (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2), -) - - -@tvm.testing.fixture() -def ref_data( - in_dtype, out_dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation -): +def ref_data(in_dtype, out_dtype, data_shape, num_filter, kernel_size, stride, padding, dilation): np.random.seed(0) - in_height = in_width = in_size - a_shape = (batch, in_height, in_width, in_channel) - w_shape = (kernel, kernel, in_channel, num_filter) + a_shape = data_shape + w_shape = (kernel_size[0], kernel_size[1], data_shape[3], num_filter) a_np = np.random.uniform(size=a_shape).astype(in_dtype) w_np = np.random.uniform(size=w_shape).astype(in_dtype) @@ -175,9 +137,31 @@ def ref_data( @pytest.mark.skipif( llvm_version_major() < 16, reason="SME is not supported in earlier versions of LLVM" ) +@pytest.mark.parametrize( + "data_shape,kernel_size,num_filter,stride,padding,dilation", + [ + ((1, 1, 1, 1), (3, 3), 1, 1, "SAME", 1), + ((1, 9, 9, 1), (3, 3), 16, 1, "SAME", 1), + ((1, 32, 32, 1), (3, 3), 12, 1, "SAME", 1), + ((1, 32, 10, 3), (3, 3), 16, 1, 0, 1), + ((1, 49, 10, 1), (10, 4), 64, (2, 1), (4, 1, 5, 1), 1), + ((1, 32, 32, 16), (3, 3), 16, 1, (0, 2, 2, 0), 1), + ((1, 32, 32, 16), (3, 4), 16, 1, 0, 1), + ((1, 9, 31, 7), (3, 3), 7, 1, "VALID", 1), + ((1, 32, 32, 16), (5, 5), 16, 1, (0, 2, 2, 0), 2), + ((1, 32, 32, 16), (3, 3), 16, 1, (1, 1, 2, 2), 2), + ((1, 134, 153, 32), (3, 3), 2, (2, 2), "VALID", 1), + ((1, 16, 16, 64), (1, 1), 8, (1, 1), "SAME", 1), + ], +) +@pytest.mark.parametrize("in_dtype,out_dtype", [("float32", "float32"), ("float16", "float32")]) @tvm.testing.requires_aprofile_aem_fvp -def test_conv2d_sme(target, ref_data, in_dtype, out_dtype, stride, padding, dilation): - a_np, w_np, dw_np, b_np = ref_data +def test_conv2d_sme( + target, data_shape, kernel_size, num_filter, stride, padding, dilation, in_dtype, out_dtype +): + a_np, w_np, dw_np, b_np = ref_data( + in_dtype, out_dtype, data_shape, num_filter, kernel_size, stride, padding, dilation + ) kernel_size = get_const_tuple(w_np.shape[:2]) out_channels = w_np.shape[3]