From f9dd3896bb1338b4c7456962afae358d1574f927 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Thu, 9 Jun 2022 07:31:55 +0300 Subject: [PATCH] [OpenCL] Implement conv2d_winograd algorithm for Adreno (#11543) * Implement conv2d_winograd algorithm for Adreno * Implement gtest for OpenCL texture pool * Implement conv2d_nhwc_winograd for Adreno * Minor refactoring * Fix lint * Apply comments * Apply comments * Fix lint --- CMakeLists.txt | 16 + cmake/modules/LibInfo.cmake | 1 + cmake/modules/OpenCL.cmake | 6 + python/tvm/relay/op/strategy/adreno.py | 99 +++- python/tvm/topi/adreno/__init__.py | 2 + python/tvm/topi/adreno/conv2d_alter_op.py | 218 +++++++- .../tvm/topi/adreno/conv2d_nchw_winograd.py | 128 +++++ .../tvm/topi/adreno/conv2d_nhwc_winograd.py | 128 +++++ .../tvm/topi/adreno/conv2d_winograd_common.py | 512 ++++++++++++++++++ python/tvm/topi/adreno/utils.py | 28 + src/runtime/opencl/texture_pool.cc | 191 ++++--- src/runtime/texture.h | 22 +- src/support/libinfo.cc | 5 + .../opencl/opencl_texture_pool_test.cc | 151 ++++++ tests/cpp-runtime/opencl/run_gtests.cc | 60 ++ tests/python/contrib/test_opencl/conftest.py | 29 + .../contrib/test_opencl/test_run_gtests.py | 55 ++ .../python/relay/test_conv2d_nchw_texture.py | 3 +- .../python/relay/test_conv2d_nhwc_texture.py | 43 ++ tests/python/relay/utils/adreno_utils.py | 1 + 20 files changed, 1597 insertions(+), 101 deletions(-) create mode 100644 python/tvm/topi/adreno/conv2d_nchw_winograd.py create mode 100644 python/tvm/topi/adreno/conv2d_nhwc_winograd.py create mode 100644 python/tvm/topi/adreno/conv2d_winograd_common.py create mode 100644 tests/cpp-runtime/opencl/opencl_texture_pool_test.cc create mode 100644 tests/cpp-runtime/opencl/run_gtests.cc create mode 100644 tests/python/contrib/test_opencl/conftest.py create mode 100644 tests/python/contrib/test_opencl/test_run_gtests.py diff --git a/CMakeLists.txt b/CMakeLists.txt index 5352eddd25987..b4d6e18aad630 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ 
-26,6 +26,7 @@ endif() # Alernatively, use cmake -DOPTION=VALUE through command-line. tvm_option(USE_CUDA "Build with CUDA" OFF) tvm_option(USE_OPENCL "Build with OpenCL" OFF) +tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest) tvm_option(USE_VULKAN "Build with Vulkan" OFF) @@ -609,6 +610,18 @@ if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTES include_directories("${USE_HEXAGON_GTEST}/include") endif() +if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + include(FetchContent) + FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + target_link_libraries(tvm_runtime PUBLIC gtest) + target_link_libraries(tvm PUBLIC gtest) + include_directories("${USE_OPENCL_GTEST}/include") + include_directories("${USE_OPENCL_GTEST}/googletest/include") + message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") +endif() + # Set flags for clang include(cmake/modules/ClangFlags.cmake) set(CRC16_INCLUDE_PATH "3rdparty/libcrc/include") @@ -668,6 +681,9 @@ install(TARGETS tvm_runtime EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_S if(BUILD_FOR_HEXAGON AND DEFINED USE_HEXAGON_GTEST AND EXISTS ${USE_HEXAGON_GTEST}) install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) endif() +if(USE_OPENCL AND DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) +endif() if (INSTALL_DEV) install( diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index 76ddbede8ac06..3e6b3c787f656 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -89,6 +89,7 @@ function(add_lib_info src_file) TVM_INFO_USE_MSVC_MT="${USE_MSVC_MT}" TVM_INFO_USE_NNPACK="${USE_NNPACK}" TVM_INFO_USE_OPENCL="${USE_OPENCL}" + 
TVM_INFO_USE_OPENCL_GTEST="${USE_OPENCL_GTEST}" TVM_INFO_USE_OPENMP="${USE_OPENMP}" TVM_INFO_USE_PAPI="${USE_PAPI}" TVM_INFO_USE_PROFILER="${USE_PROFILER}" diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index 648e83f575d18..430af7e8722c8 100644 --- a/cmake/modules/OpenCL.cmake +++ b/cmake/modules/OpenCL.cmake @@ -55,6 +55,12 @@ if(USE_OPENCL) message(STATUS "Build with OpenCL support") tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES}) + + if(DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + file_glob_append(RUNTIME_OPENCL_SRCS + "${CMAKE_SOURCE_DIR}/tests/cpp-runtime/opencl/*.cc" + ) + endif() list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) else() list(APPEND COMPILER_SRCS src/target/opt/build_opencl_off.cc) diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py index 5f34341e135bf..cb43bd1990748 100644 --- a/python/tvm/relay/op/strategy/adreno.py +++ b/python/tvm/relay/op/strategy/adreno.py @@ -28,6 +28,7 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy = _op.OpStrategy() data, kernel = inputs dilation_h, dilation_w = attrs.get_int_tuple("dilation") + stride_h, stride_w = attrs.get_int_tuple("strides") groups = attrs.groups data_layout = attrs.data_layout kernel_layout = attrs.kernel_layout @@ -38,6 +39,28 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): if (data_layout == "NCHW" and kernel_layout == "OIHW") or ( data_layout == "NCHW4c" and kernel_layout == "OIHW4o" ): + if len(kernel.shape) == 4: + _, _, kh, kw = get_const_tuple(kernel.shape) + else: + _, _, kh, kw, _ = get_const_tuple(kernel.shape) + if ( + (2 < kh < 8 and 2 < kw < 8 and kh == kw) + and (stride_h == 1 and stride_w == 1) + and (dilation_h == 1 and dilation_w == 1) + ): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd), + 
wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd), + name="conv2d_nchw_winograd.image2d", + plevel=25, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_acc32), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nchw_winograd_acc32), + name="conv2d_nchw_winograd_acc32.image2d", + plevel=30, + ) if out_type.dtype == "float16": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc), @@ -48,12 +71,34 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32), wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32), - name="conv2d_nchwc_tpack.image2d", + name="conv2d_nchwc_acc32.image2d", plevel=20, ) elif (data_layout == "NHWC" and kernel_layout == "HWIO") or ( data_layout == "NHWC4c" and kernel_layout == "HWIO4o" ): + if len(kernel.shape) == 4: + kh, kw, _, _ = get_const_tuple(kernel.shape) + else: + kh, kw, _, _, _ = get_const_tuple(kernel.shape) + if ( + (2 < kh < 8 and 2 < kw < 8 and kh == kw) + and (stride_h == 1 and stride_w == 1) + and (dilation_h == 1 and dilation_w == 1) + ): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd), + name="conv2d_nhwc_winograd.image2d", + plevel=25, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_acc32), + wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_winograd_acc32), + name="conv2d_nhwc_winograd_acc32.image2d", + plevel=30, + ) if out_type.dtype == "float16": strategy.add_implementation( wrap_compute_conv2d(topi.adreno.conv2d_nhwc), @@ -153,6 +198,58 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, target): return strategy +@conv2d_winograd_without_weight_transfrom_strategy.register("adreno") +def conv2d_winograd_without_weight_transfrom_strategy_adreno(attrs, inputs, out_type, target): 
+ """conv2d_winograd_without_weight_transfrom adreno strategy""" + dilation = attrs.get_int_tuple("dilation") + groups = attrs.get_int("groups") + layout = attrs.data_layout + assert dilation == (1, 1), "Do not support dilate now" + assert groups == 1, "Do not supoort arbitrary group number" + strategy = _op.OpStrategy() + if layout in ("NCHW", "NCHW4c"): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform + ), + name="conv2d_nchw_winograd_without_weight_transform.image2d", + plevel=35, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nchw_winograd_without_weight_transform_acc32), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nchw_winograd_without_weight_transform_acc32 + ), + name="conv2d_nchw_winograd_without_weight_transform_acc32.image2d", + plevel=40, + ) + elif layout in ("NHWC", "NHWC4c"): + if out_type.dtype == "float16": + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform + ), + name="conv2d_nhwc_winograd_without_weight_transform.image2d", + plevel=35, + ) + strategy.add_implementation( + wrap_compute_conv2d(topi.adreno.conv2d_nhwc_winograd_without_weight_transform_acc32), + wrap_topi_schedule( + topi.adreno.schedule_conv2d_nhwc_winograd_without_weight_transform_acc32 + ), + name="conv2d_nhwc_winograd_without_weight_transform_acc32.image2d", + plevel=40, + ) + else: + raise RuntimeError( + "Unsupported conv2d_winograd_without_weight_transfrom layout {}".format(layout) + ) + return strategy + + @schedule_pool.register("adreno") def schedule_pool_adreno(attrs, outs, target): """schedule pooling ops for adreno""" diff --git a/python/tvm/topi/adreno/__init__.py 
b/python/tvm/topi/adreno/__init__.py index ec11f6b57cb4a..33bb2e0dfa1ee 100644 --- a/python/tvm/topi/adreno/__init__.py +++ b/python/tvm/topi/adreno/__init__.py @@ -24,3 +24,5 @@ from .pooling import * from .conv2d_alter_op import * from .injective import schedule_injective +from .conv2d_nchw_winograd import * +from .conv2d_nhwc_winograd import * diff --git a/python/tvm/topi/adreno/conv2d_alter_op.py b/python/tvm/topi/adreno/conv2d_alter_op.py index e8944093c0f54..16573991e09c5 100644 --- a/python/tvm/topi/adreno/conv2d_alter_op.py +++ b/python/tvm/topi/adreno/conv2d_alter_op.py @@ -25,6 +25,7 @@ from tvm import relay from tvm import autotvm from ..utils import get_const_tuple +from .utils import infer_tile_size from ..nn import conv2d_alter_layout logger = logging.getLogger("topi") @@ -58,7 +59,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): kernel_layout = attrs["kernel_layout"] data_tensor, kernel_tensor = tinfos data_dtype = data_tensor.dtype - kernel_dtype = kernel_tensor.dtype out_dtype = out_type.dtype if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest): @@ -70,12 +70,228 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type): ) workload = autotvm.task.get_workload(outs) if workload is None: + if impl.name.find("winograd") != -1: + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + assert (data_layout == "NCHW" and kernel_layout == "OIHW") or ( + data_layout == "NHWC" and kernel_layout == "HWIO" + ) + if data_layout == "NCHW": + N, CI, H, W = get_const_tuple(data_tensor.shape) + CO, _, KH, KW = get_const_tuple(kernel_tensor.shape) + weight = inputs[1] + else: + N, H, W, CI = get_const_tuple(data_tensor.shape) + KH, KW, _, CO = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "HWIO", "OIHW") + + # Pre-compute weight transformation in winograd + tile_size = infer_tile_size(data_tensor, data_layout) + + # alpha, alpha, CO, CI + 
weight = relay.nn.contrib_conv2d_winograd_weight_transform( + weight, tile_size=tile_size + ) + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) return None cfg = dispatch_ctx.query(target, workload) topi_tmpl = workload[0] + if "conv2d_nchw_winograd" in topi_tmpl: + suffix = "_acc32" if "acc32" in topi_tmpl else "" + wkl_name = "conv2d_nchw_winograd_without_weight_transform" + suffix + ".image2d" + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + tile_size = infer_tile_size(data_tensor, data_layout) + if len(data_tensor.shape) == 5: + assert data_layout == "NCHW4c" and kernel_layout == "OIHW4o" + N, CI, H, W, CB = get_const_tuple(data_tensor.shape) + CO, _, KH, KW, COB = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "OIHW4o", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.layout_transform(weight, "HWOI", "HWIO4o") + + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO * COB + + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI * CB, CO, COB), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + assert data_layout == "NCHW" and kernel_layout == "OIHW" + N, CI, H, W = get_const_tuple(data_tensor.shape) + CO, _, KH, KW = get_const_tuple(kernel_tensor.shape) + + # pre-compute weight transformation in winograd + weight = relay.nn.contrib_conv2d_winograd_weight_transform(inputs[1], tile_size=tile_size) + weight = relay.transpose(weight, axes=[2, 3, 0, 
1]) # HWOI -> OIHW + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + + # Store the same config for the altered operator (workload) + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel_tensor.dtype + ) + in_channel_block = CI % 4 + if in_channel_block == 0: + in_channel_block = 4 + num_filter_block = CO % 4 + if num_filter_block == 0: + num_filter_block = 4 + + if in_channel_block != 4 or num_filter_block != 4: + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + new_attrs["data_layout"] = "NCHW%dc" % in_channel_block + # (oc, ic, h, w) -> (h, w, ic, oc // 4, oc % 4) + new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block + new_attrs["out_layout"] = "NCHW%dc" % num_filter_block + # Store altered operator's config + new_data = te.placeholder( + (N, CI // in_channel_block, H, W, in_channel_block), dtype=data_dtype + ) + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO // num_filter_block, num_filter_block), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [ + new_data, + new_weight, + strides, + padding, + dilation, + out_dtype, + ], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + if "conv2d_nhwc_winograd" in topi_tmpl: + suffix = "_acc32" if "acc32" in topi_tmpl else "" + wkl_name = "conv2d_nhwc_winograd_without_weight_transform" + suffix + ".image2d" + if dilation != (1, 1): + logger.warning("Does not support weight pre-transform for dilated convolution.") + return None + + tile_size = infer_tile_size(data_tensor, data_layout) + if len(data_tensor.shape) == 5: + 
assert data_layout == "NHWC4c" and kernel_layout == "HWIO4o" + N, CI, H, W, CB = get_const_tuple(data_tensor.shape) + KH, KW, _, CO, COB = get_const_tuple(kernel_tensor.shape) + weight = relay.layout_transform(inputs[1], "HWIO4o", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.layout_transform(weight, "HWOI", "HWIO4o") + + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO * COB + + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI * CB, CO, COB), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + assert data_layout == "NHWC" and kernel_layout == "HWIO" + N, H, W, CI = get_const_tuple(data_tensor.shape) + KH, KW, _, CO = get_const_tuple(kernel_tensor.shape) + + # pre-compute weight transformation in winograd + weight = relay.layout_transform(inputs[1], "HWIO", "OIHW") + weight = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=tile_size) + weight = relay.transpose(weight, axes=[0, 1, 3, 2]) # HWOI -> HWIO + new_attrs["tile_size"] = tile_size + new_attrs["channels"] = CO + + # Store the same config for the altered operator (workload) + new_data = data_tensor + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO), dtype=kernel_tensor.dtype + ) + in_channel_block = CI % 4 + if in_channel_block == 0: + in_channel_block = 4 + num_filter_block = CO % 4 + if num_filter_block == 0: + num_filter_block = 4 + + if in_channel_block != 4 or num_filter_block != 4: + new_workload = autotvm.task.args_to_workload( + [new_data, new_weight, strides, padding, dilation, out_dtype], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + 
return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + + new_attrs["data_layout"] = "NHWC%dc" % in_channel_block + # (oc, ic, h, w) -> (h, w, ic, oc // 4, oc % 4) + new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block + new_attrs["out_layout"] = "NHWC%dc" % num_filter_block + # Store altered operator's config + new_data = te.placeholder( + (N, H, W, CI // in_channel_block, in_channel_block), dtype=data_dtype + ) + new_weight = te.placeholder( + (KH + tile_size - 1, KW + tile_size - 1, CI, CO // num_filter_block, num_filter_block), + dtype=kernel_tensor.dtype, + ) + new_workload = autotvm.task.args_to_workload( + [ + new_data, + new_weight, + strides, + padding, + dilation, + out_dtype, + ], + wkl_name, + ) + dispatch_ctx.update(target, new_workload, cfg) + return relay.nn.contrib_conv2d_winograd_without_weight_transform( + inputs[0], weight, **new_attrs + ) + if "conv2d_nchwc" in topi_tmpl: # covers both conv2d_nchwc and depthwise_conv2d_nchwc if data_layout == "NCHW" and kernel_layout == "OIHW": batch, in_channels, in_height, in_width = data_tensor.shape diff --git a/python/tvm/topi/adreno/conv2d_nchw_winograd.py b/python/tvm/topi/adreno/conv2d_nchw_winograd.py new file mode 100644 index 0000000000000..16f7cb8b19d95 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_nchw_winograd.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name,unused-variable,unused-argument +"""Winograd NCHW template for Adreno backend""" + +import logging +from tvm import autotvm +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl + + +logger = logging.getLogger("conv2d_nchw_winograd") + + +@autotvm.register_topi_compute("conv2d_nchw_winograd.image2d") +def conv2d_nchw_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_acc32.image2d") +def conv2d_nchw_winograd_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd.image2d") +def schedule_conv2d_nchw_winograd(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd_acc32.image2d") +def schedule_conv2d_nchw_winograd_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.image2d") +def conv2d_nchw_winograd_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype +): 
+ args = {"shared": False, "accumulator": "float16"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform_acc32.image2d") +def conv2d_nchw_winograd_without_weight_transform_acc32( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.image2d") +def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform_acc32.image2d") +def schedule_conv2d_nchw_winograd_without_weight_transform_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) + + +def conv2d_nchw_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. + + args: dict + Dictionary with additional arguments, e.g. 
accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + return conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NCHW" + ) diff --git a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py new file mode 100644 index 0000000000000..bfe385f210a49 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument +"""Winograd NHWC template for Adreno backend""" + +import logging +from tvm import autotvm +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl + + +logger = logging.getLogger("conv2d_nhwc_winograd") + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd.image2d") +def conv2d_nhwc_winograd(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_acc32.image2d") +def conv2d_nhwc_winograd_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype): + args = {"shared": False, "accumulator": "float32"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=False + ) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd.image2d") +def schedule_conv2d_nhwc_winograd(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_acc32.image2d") +def schedule_conv2d_nhwc_winograd_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform.image2d") +def conv2d_nhwc_winograd_without_weight_transform( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = {"shared": False, "accumulator": "float16"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d") +def conv2d_nhwc_winograd_without_weight_transform_acc32( + cfg, data, kernel, strides, padding, dilation, out_dtype +): + args = 
{"shared": False, "accumulator": "float32"} + return conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args=args, pre_computed=True + ) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform.image2d") +def schedule_conv2d_nhwc_winograd_without_weight_transform(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + + +@autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d") +def schedule_conv2d_nhwc_winograd_without_weight_transform_acc32(cfg, outs): + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) + + +def conv2d_nhwc_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. + + args: dict + Dictionary with additional arguments, e.g. 
accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + return conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NHWC" + ) diff --git a/python/tvm/topi/adreno/conv2d_winograd_common.py b/python/tvm/topi/adreno/conv2d_winograd_common.py new file mode 100644 index 0000000000000..494b691a7f076 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_winograd_common.py @@ -0,0 +1,512 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument +"""Common Winograd implementation for Adreno backend""" + +import tvm +from tvm import te +from tvm import autotvm + +from tvm.topi import nn +from tvm.topi.utils import get_const_int, get_const_tuple, traverse_inline +from ..nn.winograd_util import winograd_transform_matrices +from .utils import ( + split_to_chunks, + pack_input, + pack_filter, + bind_data_copy, + get_texture_storage, + infer_tile_size, +) + + +def conv2d_winograd_comp( + cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, layout +): + """Compute declaration for winograd + + Parameters + ---------- + cfg: ConfigEntity + The config for this template + + data: tvm.te.Tensor + 4-D or 5-D Data tensor with shape NCHW or NCHW4c + + kernel: tvm.te.Tensor + 4-D or 5-D tensor with shape OIHW or OIHW4o + + strides: int or a list/tuple of two ints + stride size, or [stride_height, stride_width] + + padding: int or a list/tuple of 2 or 4 ints + padding size, or + [pad_height, pad_width] for 2 ints, or + [pad_top, pad_left, pad_bottom, pad_right] for 4 ints + + dilation: int or a list/tuple of two ints + dilation size, or [dilation_height, dilation_width] + + out_dtype: str + The output type. This is used for mixed precision. + + args: dict + Dictionary with additional arguments, e.g. 
accumulator type + + pre_computed: bool + Flag if weights were pre computed if true or the weights should be + computed in runtime + + layout: str + NHWC or NCHW values are accepted + + Returns + ------- + output: tvm.te.Tensor + 4-D or 5-D with shape NCHW or NCHW4c + """ + assert layout in ("NCHW", "NHWC") + tile_size = infer_tile_size(data, layout) + + if isinstance(dilation, int): + dilation_h = dilation_w = dilation + else: + dilation_h, dilation_w = dilation + HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides + + convert_from4d = False + if len(data.shape) == 4: + if layout == "NCHW": + N, DCI, H, W = get_const_tuple(data.shape) + else: + N, H, W, DCI = get_const_tuple(data.shape) + if not pre_computed: + if layout == "NCHW": + out_channels, CI, KH, KW = get_const_tuple(kernel.shape) + else: + KH, KW, CI, out_channels = get_const_tuple(kernel.shape) + else: + alpha, _, CI, out_channels = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + + in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4) + out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) + if autotvm.GLOBAL_SCOPE.in_tuning is True: + if layout == "NCHW": + dshape = (N, in_channel_chunks, H, W, in_channel_block) + else: + dshape = (N, H, W, in_channel_chunks, in_channel_block) + if not pre_computed: # kernel tensor is raw tensor, do strict check + if layout == "NCHW": + kshape = (out_channel_chunks, CI, KH, KW, out_channel_block) + else: + kshape = (KH, KW, CI, out_channel_chunks, out_channel_block) + else: + kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block) + data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder") + kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder") + else: + convert_from4d = True + data = pack_input( + data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W + ) + kernel_layout = "OIHW" if layout == "NCHW" 
else "HWIO" + if not pre_computed: # kernel tensor is raw tensor, do strict check + kernel = pack_filter( + kernel, + kernel_layout, + out_channel_chunks, + out_channel_block, + out_channel_tail, + CI, + in_channel_chunks, + in_channel_block, + in_channel_tail, + KH, + KW, + ) + else: + kernel = pack_filter( + kernel, + "HWIO", + out_channel_chunks, + out_channel_block, + out_channel_tail, + CI, + in_channel_chunks, + in_channel_block, + in_channel_tail, + alpha, + alpha, + ) + if layout == "NCHW": + N, DCI, H, W, CB = get_const_tuple(data.shape) + else: + N, H, W, DCI, CB = get_const_tuple(data.shape) + if not pre_computed: # kernel tensor is raw tensor, do strict check + if layout == "NCHW": + CO, CI, KH, KW, COB = get_const_tuple(kernel.shape) + else: + KH, KW, CI, CO, COB = get_const_tuple(kernel.shape) + alpha = KW + tile_size - 1 + assert HSTR == 1 and WSTR == 1 and KH == KW + else: + alpha, _, CI, CO, COB = get_const_tuple(kernel.shape) + KH = KW = alpha + 1 - tile_size + assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 + + if isinstance(N, tvm.tir.Any): + N = tvm.te.size_var("n") + + if not isinstance(H, int) or not isinstance(W, int): + raise RuntimeError( + "adreno winograd conv2d doesn't support dynamic input\ + height or width." 
+ ) + + pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) + if layout == "NCHW": + data_pad = nn.pad(data, (0, 0, pt, pl, 0), (0, 0, pb, pr, 0), name="data_pad") + else: + data_pad = nn.pad(data, (0, pt, pl, 0, 0), (0, pb, pr, 0, 0), name="data_pad") + + r = KW + m = tile_size + A, B, G = winograd_transform_matrices(m, r, out_dtype) + + H = (H + pt + pb - KH) // HSTR + 1 + W = (W + pl + pr - KW) // WSTR + 1 + nH, nW = (H + m - 1) // m, (W + m - 1) // m + + P = N * nH * nW if isinstance(N, int) else nH * nW + + # transform kernel + if not pre_computed: + r_kh = te.reduce_axis((0, KH), name="r_kh") + r_kw = te.reduce_axis((0, KW), name="r_kw") + if layout == "NCHW": + kernel_pack = te.compute( + (alpha, alpha, CI, CO, COB), + lambda eps, nu, ci, co, cob: te.sum( + kernel[co][ci][r_kh][r_kw][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] + ), + name="kernel_pack", + ) + else: + kernel_pack = te.compute( + (alpha, alpha, CI, CO, COB), + lambda eps, nu, ci, co, cob: te.sum( + kernel[r_kh][r_kw][ci][co][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] + ), + name="kernel_pack", + ) + else: + kernel_pack = kernel + + idxdiv = tvm.tir.indexdiv + idxmod = tvm.tir.indexmod + if layout == "NCHW": + N, CI, H, W, CB = get_const_tuple(data.shape) + else: + N, H, W, CI, CB = get_const_tuple(data.shape) + + # pack input tile + if layout == "NCHW": + input_tile = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][c][ + idxmod(idxdiv(p, nW), nH) * m + eps + ][idxmod(p, nW) * m + nu][cb], + name="d", + ) + else: + input_tile = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][ + idxmod(idxdiv(p, nW), nH) * m + eps + ][idxmod(p, nW) * m + nu][c][cb], + name="d", + ) + + # transform data + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_a") + data_pack = te.compute( + (P, CI, alpha, alpha, CB), + lambda p, ci, eps, nu, cb: te.sum( + 
input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] + ), + name="data_pack", + ) + + # repack transformed data + data_pack_trans = te.compute( + (alpha, alpha, CI, P, CB), + lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb], + name="data_pack_trans", + ) + + # do batch gemm + ci = te.reduce_axis((0, CI), name="ci") + cb = te.reduce_axis((0, CB), name="cb") + bgemm = te.compute( + (alpha, alpha, CO, P, COB), + lambda eps, nu, co, p, cob: te.sum( + ( + kernel_pack[eps][nu][ci * CB + cb][co][cob] * data_pack_trans[eps][nu][ci][p][cb] + ).astype(args["accumulator"]), + axis=[ci, cb], + ), + name="bgemm", + ) + + # inverse transform + r_a = te.reduce_axis((0, alpha), "r_a") + r_b = te.reduce_axis((0, alpha), "r_a") + inverse = te.compute( + (CO, P, m, m, COB), + lambda co, p, vh, vw, cob: te.sum( + bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]), + axis=[r_a, r_b], + ), + name="inverse", + ) + + # output + if layout == "NCHW": + if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: + output = te.compute( + (N, out_channels, H, W), + lambda n, c, h, w: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ + idxmod(h, m) + ][idxmod(w, m)][c % CB].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + output = te.compute( + (N, CO, H, W, COB), + lambda n, co, h, w, cob: inverse[co][ + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m) + ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: + output = te.compute( + (N, H, W, out_channels), + lambda n, h, w, c: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ + idxmod(h, m) + ][idxmod(w, m)][c % CB].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + else: + output = te.compute( + (N, H, W, CO, COB), + 
lambda n, h, w, co, cob: inverse[co][ + n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m) + ][idxmod(h, m)][idxmod(w, m)][cob].astype(out_dtype), + name="output", + tag="cast_from_acc" + args["accumulator"][-2:], + ) + + if isinstance(N, int): + cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW) + + return output + + +def schedule_conv2d_winograd_impl(cfg, outs, tag, pre_computed=False): + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + s = te.create_schedule([x.op for x in outs]) + + def _callback(op): + if op.tag == tag: + schedule_conv2d_winograd(cfg, s, op.output(0), pre_computed=pre_computed) + + traverse_inline(s, outs[0].op, _callback) + return s + + +def schedule_conv2d_winograd(cfg, s, output, pre_computed): + """Schedule winograd template""" + inverse = s[output].op.input_tensors[0] + bgemm, A = s[inverse].op.input_tensors + kernel_pack, data_pack_trans = s[bgemm].op.input_tensors + data_pack = s[data_pack_trans].op.input_tensors[0] + input_tile, B = s[data_pack].op.input_tensors + pad_data = s[input_tile].op.input_tensors[0] + + # data transform + s[B].compute_inline() + s[A].compute_inline() + + # probably will improve real topology execution + if autotvm.GLOBAL_SCOPE.in_tuning: + # Padding to texture + AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile]) + bind_data_copy(s[AA]) + + s[input_tile].compute_inline() + + OL = s.cache_write(data_pack, "local") + c, p, eps, nu, cb = s[data_pack].op.axis + fused = s[data_pack].fuse(c, p, eps, nu) + bx, tx = s[data_pack].split(fused, 128) + s[data_pack].vectorize(cb) + s[data_pack].bind(bx, te.thread_axis("blockIdx.x")) + s[data_pack].bind(tx, te.thread_axis("threadIdx.x")) + + _, _, eps, nu, cb = s[OL].op.axis + r_a, r_b = s[OL].op.reduce_axis + s[OL].unroll(eps) + s[OL].unroll(nu) + s[OL].unroll(r_a) + s[OL].unroll(r_b) + s[OL].vectorize(cb) + s[OL].compute_at(s[data_pack], tx) + s[data_pack].set_scope(get_texture_storage(data_pack.shape)) + + 
s[data_pack_trans].compute_inline() + + # transform kernel + if not pre_computed: + kernel, G = s[kernel_pack].op.input_tensors + eps, nu, ci, co, cob = s[kernel_pack].op.axis + if autotvm.GLOBAL_SCOPE.in_tuning: + # skip this part during tuning to make records accurate + # this part will be pre-computed during pre-compute optimization pass + s[G].pragma(s[G].op.axis[0], "debug_skip_region") + s[kernel_pack].pragma(eps, "debug_skip_region") + else: + s[G].compute_inline() + r_a, r_b = s[kernel_pack].op.reduce_axis + for axis in [eps, nu, r_a, r_b]: + s[kernel_pack].unroll(axis) + + fused = s[kernel_pack].fuse(ci, co) + bb, tt = s[kernel_pack].split(fused, 128) + s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob) + s[kernel_pack].vectorize(cob) + s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) + s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) + else: + kernel = kernel_pack + + if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag: + # manage scheduling of datacopy + pack_data = pad_data.op.input_tensors[0] + bind_data_copy(s[pack_data]) + bind_data_copy(s[kernel]) + elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: + s[kernel].compute_inline() + s[pad_data].compute_inline() + + ##### space definition begin ##### + cfg.define_knob("auto_unroll_max_step", [0, 4, 16]) + b1, b2, y, x, cb = s[bgemm].op.axis + rcc = s[bgemm].op.reduce_axis[0] + alpha = get_const_int(b1.dom.extent) + + cfg.define_split( + "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 8 + ) + cfg.define_split( + "tile_x", + x, + num_outputs=3, + filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= 4 and entry.size[1] <= 8, + ) + cfg.define_split("tile_rc", rcc, num_outputs=2) + # TODO: Uncomment the following lines when multi_filter will be introduced + # cfg.multi_filter( + # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024) + # ) + ##### space
definition end ##### + + # batch gemm + OL = s.cache_write(bgemm, "local") + if ( + autotvm.GLOBAL_SCOPE.in_tuning + or isinstance(kernel.op, tvm.te.ComputeOp) + and "filter_pack" in kernel.op.tag + ): + BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL]) + bind_data_copy(s[BB]) + + by = s[bgemm].fuse(b1, b2, y) + + # tile and bind spatial axes + bgemm_scope, by = s[bgemm].split(by, nparts=1) + by, vy, ty = cfg["tile_y"].apply(s, bgemm, by) + bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x) + s[bgemm].bind(by, te.thread_axis("blockIdx.y")) + s[bgemm].bind(bx, te.thread_axis("blockIdx.x")) + s[bgemm].bind(vy, te.thread_axis("vthread")) + s[bgemm].bind(vx, te.thread_axis("vthread")) + s[bgemm].bind(ty, te.thread_axis("threadIdx.y")) + s[bgemm].bind(tx, te.thread_axis("threadIdx.x")) + s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb) + s[bgemm].vectorize(cb) + s[bgemm].set_scope(get_texture_storage(bgemm.shape)) + + # tile reduction axes + s[OL].compute_at(s[bgemm], tx) + b1, b2, y, x, cb = s[OL].op.axis + (rcc, rcb) = s[OL].op.reduce_axis + b = s[OL].fuse(b1, b2) + s[OL].reorder(b, y, x, rcc, rcb, cb) + # s[OL].unroll(rcb) + s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) + s[OL].pragma(rcb, "unroll_explicit", True) + s[OL].vectorize(cb) + + # schedule inverse, output and fusion + if output.op in s.outputs: + OL = None + else: + OL = output + s[OL].set_scope("local") + output = s.outputs[0] + + m = alpha - 3 + 1 + if len(s[output].op.axis) == 4: + n, co, h, w = s[output].op.axis + else: + n, co, h, w, _ = s[output].op.axis + ho, wo, hi, wi = s[output].tile(h, w, m, m) + inverse_scope, n = s[output].split(n, nparts=1) + + fused = s[output].fuse(n, co, ho, wo) + bb, tt = s[output].split(fused, 128) + + s[output].bind(bb, te.thread_axis("blockIdx.x")) + s[output].bind(tt, te.thread_axis("threadIdx.x")) + + if OL is not None: + s[OL].compute_at(s[output], tt) + + co, p, vh, vw, cb = s[inverse].op.axis + r_a, r_b = 
s[inverse].op.reduce_axis + for axis in [vh, vw, r_a, r_b]: + s[inverse].unroll(axis) + s[inverse].vectorize(cb) + s[inverse].compute_at(s[output], tt) + + return s diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py index d56632b49f51f..591d6b5a209a3 100644 --- a/python/tvm/topi/adreno/utils.py +++ b/python/tvm/topi/adreno/utils.py @@ -558,3 +558,31 @@ def get_texture_storage(shape): return "global.texture-nhwc" else: return "global.texture-weight" + + +def infer_tile_size(data, layout): + """Compute the tile size for Winograd algorithm + + Parameters + ---------- + data: tvm.te.Tensor + Data tensor + + layout: string + Layout of data tensor + NCHW, NCHW4c, NHWC or NHWC4c are acceptable + + Returns + ------- + tile_size : int + Calculated tile size + """ + assert layout in ("NCHW", "NCHW4c", "NHWC", "NHWC4c"), "Incompatible layout" + if layout in ("NCHW", "NCHW4c"): + H = get_const_tuple(data.shape)[2] + else: + H = get_const_tuple(data.shape)[1] + + if H % 8 == 0: + return 4 + return 2 diff --git a/src/runtime/opencl/texture_pool.cc b/src/runtime/opencl/texture_pool.cc index e7f6655c41142..0b9477f2d4ea3 100644 --- a/src/runtime/opencl/texture_pool.cc +++ b/src/runtime/opencl/texture_pool.cc @@ -29,113 +29,112 @@ namespace tvm { namespace runtime { -class TexturePool::Pool { - public: - Pool() = default; - void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint) { - Entry e; - e.data = nullptr; - if (free_list_.size() != 0) { - Entry new_mem; - int64_t min_added_size_x = std::numeric_limits::max(); - int64_t min_added_size_y = std::numeric_limits::max(); - int64_t min_wasted_size_x = std::numeric_limits::max(); - int64_t min_wasted_size_y = std::numeric_limits::max(); - std::vector::iterator best_mem; - for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { - if (it->type.code != type_hint.code) { - continue; - } - new_mem.x = std::max(it->x, width); - new_mem.y = std::max(it->y, height);
- int64_t added_size_x = new_mem.x - it->x; - int64_t added_size_y = new_mem.y - it->y; - int64_t wasted_size_x = new_mem.x - width; - int64_t wasted_size_y = new_mem.y - height; - // Minimize added size first and wasted size thereafter - if ((min_added_size_x > 0 && added_size_x < min_added_size_x) || - (min_added_size_y > 0 && added_size_y < min_added_size_y) || - (min_added_size_x == added_size_x && wasted_size_x < min_wasted_size_x) || - (min_added_size_y == added_size_y && wasted_size_y < min_wasted_size_y)) { - min_added_size_x = added_size_x; - min_added_size_y = added_size_y; - min_wasted_size_x = wasted_size_x; - min_wasted_size_y = wasted_size_y; - best_mem = it; - } +void* Pool2D::Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, + DLDataType type_hint) { + Entry e; + Entry new_mem; + // Processed several experiments and found that when we are trying to fit + // small texture to too big texture then it may lead to the performance + // degradation. + // Coefficient at 5 looks like robust variant for reusing textures. 
+ const int64_t max_ratio = 5; + e.data = nullptr; + std::vector::iterator best_mem; + if (free_list_.size() != 0) { + int64_t min_added_size_x = std::numeric_limits::max(); + int64_t min_added_size_y = std::numeric_limits::max(); + int64_t min_wasted_size_x = std::numeric_limits::max(); + int64_t min_wasted_size_y = std::numeric_limits::max(); + for (auto it = free_list_.begin(); it != free_list_.end(); ++it) { + if (it->type.code != type_hint.code) { + continue; } - - if (min_added_size_x == 0 && min_added_size_y == 0) { - // use existing block - e = *best_mem; - free_list_.erase(best_mem); - } else if (static_cast(min_added_size_x) <= width || - static_cast(min_added_size_y) <= height) { - // if added size is less or equal to - // what is needed by alloc, then grow entry - device->FreeDataSpace(dev, best_mem->data); - free_list_.erase(best_mem); - new_mem.type = type_hint; - std::vector shape{int64_t(new_mem.y), int64_t(new_mem.x), 4}; - new_mem.data = device->AllocDataSpace(dev, shape.size(), shape.data(), new_mem.type, - Optional("global.texture")); - e = new_mem; + // avoid reusing too small and too big textures + if (width / it->x > max_ratio || it->x / width > max_ratio || height / it->y > max_ratio || + it->y / height > max_ratio) { + continue; + } + int64_t new_width = std::max(it->x, width); + int64_t new_height = std::max(it->y, height); + int64_t added_size_x = new_width - it->x; + int64_t added_size_y = new_height - it->y; + int64_t wasted_size_x = new_width - width; + int64_t wasted_size_y = new_height - height; + // Minimize added size first and wasted size thereafter + if ((min_added_size_x > 0 && added_size_x < min_added_size_x) || + (min_added_size_y > 0 && added_size_y < min_added_size_y) || + (min_added_size_x == added_size_x && wasted_size_x < min_wasted_size_x) || + (min_added_size_y == added_size_y && wasted_size_y < min_wasted_size_y)) { + min_added_size_x = added_size_x; + min_added_size_y = added_size_y; + min_wasted_size_x = 
wasted_size_x; + min_wasted_size_y = wasted_size_y; + best_mem = it; + new_mem.x = new_width; + new_mem.y = new_height; } } - if (e.data == nullptr) { - // create new block - std::vector shape{int64_t(height), int64_t(width), 4}; - e.data = device->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, - Optional("global.texture")); - e.x = width; - e.y = height; - e.type = type_hint; + if (min_added_size_x == 0 && min_added_size_y == 0) { + // use existing block + e = *best_mem; + free_list_.erase(best_mem); + } else if (static_cast(min_added_size_x) <= width || + static_cast(min_added_size_y) <= height) { + // if added size is less or equal to + // what is needed by alloc, then grow entry + device->FreeDataSpace(dev, best_mem->data); + free_list_.erase(best_mem); + new_mem.type = type_hint; + std::vector shape{int64_t(new_mem.y), int64_t(new_mem.x), 4}; + new_mem.data = device->AllocDataSpace(dev, shape.size(), shape.data(), new_mem.type, + Optional("global.texture")); + e = new_mem; } - - allocated_.push_back(e); - return e.data; } - void Free(void* data) { - Entry e; - if (allocated_.back().data == data) { - // quick path, last allocated. 
- e = allocated_.back(); - allocated_.pop_back(); - } else { - int index = static_cast(allocated_.size()) - 2; - for (; index >= 0 && allocated_[index].data != data; --index) { - } - ICHECK_GE(index, 0) << "Attempt to free texture that has not been allocated"; - e = allocated_[index]; - allocated_.erase(allocated_.begin() + index); - } - free_list_.push_back(e); + if (e.data == nullptr) { + // create new block + std::vector shape{int64_t(height), int64_t(width), 4}; + e.data = device->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, + Optional("global.texture")); + e.x = width; + e.y = height; + e.type = type_hint; } - // Release all resources immediately - void Release(Device dev, DeviceAPI* device) { - for (auto& e : allocated_) { - device->FreeDataSpace(dev, e.data); - } - for (auto& e : free_list_) { - device->FreeDataSpace(dev, e.data); + allocated_.push_back(e); + return e.data; +} + +void Pool2D::Free(void* data) { + Entry e; + if (allocated_.back().data == data) { + // quick path, last allocated. 
+ e = allocated_.back(); + allocated_.pop_back(); + } else { + int index = static_cast(allocated_.size()) - 2; + for (; index >= 0 && allocated_[index].data != data; --index) { } - allocated_.clear(); - free_list_.clear(); + ICHECK_GE(index, 0) << "Attempt to free texture that has not been allocated"; + e = allocated_[index]; + allocated_.erase(allocated_.begin() + index); } + free_list_.push_back(e); +} - private: - struct Entry { - void* data; - size_t x; - size_t y; - DLDataType type; - }; - std::vector free_list_; - std::vector allocated_; -}; +// Release all resources immediately +void Pool2D::Release(Device dev, DeviceAPI* device) { + for (auto& e : allocated_) { + device->FreeDataSpace(dev, e.data); + } + for (auto& e : free_list_) { + device->FreeDataSpace(dev, e.data); + } + allocated_.clear(); + free_list_.clear(); +} TexturePool::TexturePool(DLDeviceType device_type, DeviceAPI* device) : device_type_(device_type), device_(device) {} @@ -157,7 +156,7 @@ void* TexturePool::AllocTexture(Device dev, size_t width, size_t height, DLDataT array_.resize(dev.device_id + 1, nullptr); } if (array_[dev.device_id] == nullptr) { - array_[dev.device_id] = new Pool(); + array_[dev.device_id] = new Pool2D(); } return array_[dev.device_id]->Alloc(dev, device_, width, height, type_hint); } diff --git a/src/runtime/texture.h b/src/runtime/texture.h index 5f43c8cee8f3f..dc38101f0cd4f 100644 --- a/src/runtime/texture.h +++ b/src/runtime/texture.h @@ -94,6 +94,25 @@ inline bool IsTextureStorage(std::string scope) { return scope.find("texture") != std::string::npos; } +class TVM_DLL Pool2D { + public: + Pool2D() = default; + void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint); + void Free(void* data); + // Release all resources immediately + void Release(Device dev, DeviceAPI* device); + + protected: + struct Entry { + void* data; + size_t x; + size_t y; + DLDataType type; + }; + std::vector free_list_; + std::vector allocated_; +}; + 
/*! * \brief A two dimensional storage pool that recycles temporal workspace * allocations for dynamically allocated texture. See AllocTexture docstring @@ -136,9 +155,8 @@ class TVM_DLL TexturePool { void FreeTexture(Device dev, void* ptr); private: - class Pool; /*! \brief pool of device local array */ - std::vector array_; + std::vector array_; /*! \brief device type this pool support */ DLDeviceType device_type_; /*! \brief The device API */ diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index e6f322885e3a2..4a969dcee8bb9 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -43,6 +43,10 @@ #define TVM_INFO_USE_OPENCL "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_OPENCL_GTEST +#define TVM_INFO_USE_OPENCL_GTEST "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_VULKAN #define TVM_INFO_USE_VULKAN "NOT-FOUND" #endif @@ -286,6 +290,7 @@ TVM_DLL Map GetLibInfo() { {"USE_MSVC_MT", TVM_INFO_USE_MSVC_MT}, {"USE_NNPACK", TVM_INFO_USE_NNPACK}, {"USE_OPENCL", TVM_INFO_USE_OPENCL}, + {"USE_OPENCL_GTEST", TVM_INFO_USE_OPENCL_GTEST}, {"USE_OPENMP", TVM_INFO_USE_OPENMP}, {"USE_PAPI", TVM_INFO_USE_PAPI}, {"USE_PROFILER", TVM_INFO_USE_PROFILER}, diff --git a/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc new file mode 100644 index 0000000000000..2d3f43ddce6de --- /dev/null +++ b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include + +#include "../src/runtime/opencl/opencl_common.h" +#include "../src/runtime/texture.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::cl; + +// PoolWrapper is necessary because in class Pool2D we don't have an access to +// its protected members. In this class we add new methods which allow us to +// get and check internal state of class Pool +class PoolWrapper : public Pool2D { + public: + inline size_t FreeListSize() const { return free_list_.size(); } + inline size_t AllocatedListSize() const { return allocated_.size(); } + inline std::pair FreeListItemSize(size_t idx) const { + return std::make_pair(free_list_[idx].x, free_list_[idx].y); + } + inline std::pair AllocatedListItemSize(size_t idx) const { + return std::make_pair(allocated_[idx].x, allocated_[idx].y); + } +}; + +TEST(OpenCLTexturePool, textures_reallocation_optimal_size) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 1024, 768, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); + + pool.Alloc(t->device, workspace, 64, 12455, type); + EXPECT_EQ(pool.AllocatedListSize(), 2); + EXPECT_EQ(pool.FreeListSize(), 0); + item = pool.AllocatedListItemSize(1); + 
EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); + + pool.Alloc(t->device, workspace, 768, 1024, type); + EXPECT_EQ(pool.AllocatedListSize(), 2); + EXPECT_EQ(pool.FreeListSize(), 0); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 64); + EXPECT_EQ(item.second, 12455); + item = pool.AllocatedListItemSize(1); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 1024); +} + +TEST(OpenCLTexturePool, avoid_reusing_too_big_textures) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 12455, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + + pool.Alloc(t->device, workspace, 1024, 768, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 12455); + EXPECT_EQ(item.second, 64); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 768); +} + +TEST(OpenCLTexturePool, avoid_reusing_too_small_textures) { + OpenCLWorkspace* workspace = OpenCLWorkspace::Global(); + OpenCLThreadEntry* t = workspace->GetThreadEntry(); + PoolWrapper pool; + 
EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 0); + + DLDataType type{kDLFloat, 16, 1}; + void* data1 = pool.Alloc(t->device, workspace, 1024, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 0); + auto item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + + pool.Free(data1); + EXPECT_EQ(pool.AllocatedListSize(), 0); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + + pool.Alloc(t->device, workspace, 12544, 64, type); + EXPECT_EQ(pool.AllocatedListSize(), 1); + EXPECT_EQ(pool.FreeListSize(), 1); + item = pool.FreeListItemSize(0); + EXPECT_EQ(item.first, 1024); + EXPECT_EQ(item.second, 64); + item = pool.AllocatedListItemSize(0); + EXPECT_EQ(item.first, 12544); + EXPECT_EQ(item.second, 64); +} diff --git a/tests/cpp-runtime/opencl/run_gtests.cc b/tests/cpp-runtime/opencl/run_gtests.cc new file mode 100644 index 0000000000000..b16ae3efc74d9 --- /dev/null +++ b/tests/cpp-runtime/opencl/run_gtests.cc @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include + +#include +#include + +#include "../src/support/utils.h" + +namespace tvm { +namespace runtime { +namespace cl { + +TVM_REGISTER_GLOBAL("opencl.run_gtests").set_body([](TVMArgs args, TVMRetValue* rv) { + // gtest args are passed into this packed func as a singular string + // split gtest args using delimiter and build argument vector + std::vector parsed_args = tvm::support::Split(args[0], ' '); + std::vector argv; + + // add executable name + argv.push_back(const_cast("opencl_run_gtests")); + + // add parsed arguments + for (int i = 0; i < parsed_args.size(); ++i) { + argv.push_back(const_cast(parsed_args[i].data())); + } + + // end of parsed arguments + argv.push_back(nullptr); + + // set argument count + int argc = argv.size() - 1; + + // initialize gtest with arguments and run + ::testing::InitGoogleTest(&argc, argv.data()); + *rv = RUN_ALL_TESTS(); +}); + +} // namespace cl +} // namespace runtime +} // namespace tvm diff --git a/tests/python/contrib/test_opencl/conftest.py b/tests/python/contrib/test_opencl/conftest.py new file mode 100644 index 0000000000000..0a8b9e1c631f0 --- /dev/null +++ b/tests/python/contrib/test_opencl/conftest.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" OpenCL testing fixtures used to deduce testing argument + values from testing parameters """ + + +import pytest + +import tvm +import tvm.testing + +pytest_plugins = [ + "tvm.contrib.hexagon.pytest_plugin", +] diff --git a/tests/python/contrib/test_opencl/test_run_gtests.py b/tests/python/contrib/test_opencl/test_run_gtests.py new file mode 100644 index 0000000000000..4afcf7ee8d660 --- /dev/null +++ b/tests/python/contrib/test_opencl/test_run_gtests.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import os +import pytest +import numpy as np + +import tvm +from tvm import rpc + + +# use pytest -sv to observe gtest output +# use --gtest_args to pass arguments to gtest +# for example to run all "foo" tests twice and observe gtest output run +# pytest -sv --gtest_args="--gtest_filter=*foo* --gtest_repeat=2" +@tvm.testing.requires_opencl +def test_run_gtests(gtest_args): + if ( + "TVM_TRACKER_HOST" in os.environ + and "TVM_TRACKER_PORT" in os.environ + and "TVM_TRACKER_KEY" in os.environ + ): + rpc_tracker_host = os.environ["TVM_TRACKER_HOST"] + rpc_tracker_port = os.environ["TVM_TRACKER_PORT"] + rpc_tracker_port = int(rpc_tracker_port) + rpc_key = os.environ["TVM_TRACKER_KEY"] + tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port) + rpc_connection = tracker.request(rpc_key, priority=0, session_timeout=600) + else: + rpc_connection = rpc.LocalSession() + + try: + func = rpc_connection.get_function("opencl.run_gtests") + except: + print( + "This test requires TVM Runtime to be built with a OpenCL gtest version using OpenCL API cmake flag -DUSE_OPENCL_GTEST=/path/to/opencl/googletest/gtest" + ) + raise + + gtest_error_code = func(gtest_args) + np.testing.assert_equal(gtest_error_code, 0) diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py index a733782e420d9..07356bdf482f8 100644 --- a/tests/python/relay/test_conv2d_nchw_texture.py +++ b/tests/python/relay/test_conv2d_nchw_texture.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License.
+import re import tvm import numpy as np from tvm import relay @@ -736,4 +737,4 @@ def test_branch_textures(): if __name__ == "__main__": #test_branch_textures() - test_residual_block() \ No newline at end of file + test_residual_block() diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py index aa6ece287c4fc..c63d0864f814c 100644 --- a/tests/python/relay/test_conv2d_nhwc_texture.py +++ b/tests/python/relay/test_conv2d_nhwc_texture.py @@ -16,6 +16,7 @@ # under the License. import os +import re import tvm import numpy as np from tvm import relay @@ -553,3 +554,45 @@ def test_conv2d_yolov3_v2_nhwc_3c(): } build_run_compare(mod, params, {"data": input_shape}, dtype, target) + + +@tvm.testing.requires_opencl +def test_conv2d_vgg16_winograd_4d(): + target = "opencl --device=adreno" + dtype = "float16" + + input_shape = (1, 28, 28, 512) + filter_shape = (3, 3, 512, 512) + bias_shape = (1, 1, 1, 512) + A = relay.var("data", shape=input_shape, dtype=dtype) + B = relay.var("weight", shape=filter_shape, dtype=dtype) + bias = relay.var("bias", shape=bias_shape, dtype=dtype) + + conv = relay.nn.conv2d( + A, + B, + data_layout="NHWC", + kernel_layout="HWIO", + padding=[1, 1, 1, 1], + channels=512, + kernel_size=[3, 3], + out_dtype=dtype, + ) + D = relay.op.add(conv, bias) + D = relay.op.nn.relu(D) + + mod = relay.Function([A, B, bias], D) + np.random.seed(0) + initializer = relay.testing.init.Xavier() + filter_data = np.zeros(filter_shape).astype(dtype) + bias_data = np.zeros(bias_shape).astype(dtype) + initializer("weight", filter_data) + initializer("bias", bias_data) + params1 = { + "weight": tvm.nd.array(filter_data), + "bias": tvm.nd.array(bias_data), + } + + graph = build_run_compare(mod, params1, {"data": input_shape}, dtype, target) + matches = re.findall("winograd", graph) + assert len(matches) > 0 diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py index 
7d569d83e10d0..a47085042bb13 100644 --- a/tests/python/relay/utils/adreno_utils.py +++ b/tests/python/relay/utils/adreno_utils.py @@ -126,6 +126,7 @@ def build_run_compare( # print(index, output[index], x) np.testing.assert_allclose(output, ref_output, rtol=1e-1, atol=1e-1) + return graph def gpu_preprocess(tvm_mod):