From edd94cbd91e5ddb8d3df54e4441963be9a8eea51 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Wed, 8 Jun 2022 12:30:45 +0300 Subject: [PATCH] Apply comments --- .../tvm/topi/adreno/conv2d_nchw_winograd.py | 378 +------------ .../tvm/topi/adreno/conv2d_nhwc_winograd.py | 377 +------------ .../tvm/topi/adreno/conv2d_winograd_common.py | 513 ++++++++++++++++++ python/tvm/topi/adreno/utils.py | 2 +- src/runtime/opencl/texture_pool.cc | 8 +- src/runtime/texture.h | 6 +- .../opencl/opencl_texture_pool_test.cc | 4 +- .../python/relay/test_conv2d_nhwc_texture.py | 2 +- tests/python/relay/utils/adreno_utils.py | 11 - 9 files changed, 536 insertions(+), 765 deletions(-) create mode 100644 python/tvm/topi/adreno/conv2d_winograd_common.py diff --git a/python/tvm/topi/adreno/conv2d_nchw_winograd.py b/python/tvm/topi/adreno/conv2d_nchw_winograd.py index 18d4fe3c4ba9..538fccf9c3e9 100644 --- a/python/tvm/topi/adreno/conv2d_nchw_winograd.py +++ b/python/tvm/topi/adreno/conv2d_nchw_winograd.py @@ -33,6 +33,7 @@ get_texture_storage, infer_tile_size, ) +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl logger = logging.getLogger("conv2d_nchw_winograd") @@ -56,12 +57,12 @@ def conv2d_nchw_winograd_acc32(cfg, data, kernel, strides, padding, dilation, ou @autotvm.register_topi_schedule("conv2d_nchw_winograd.image2d") def schedule_conv2d_nchw_winograd(cfg, outs): - return schedule_conv2d_nchw_winograd_impl(cfg, outs, tag="cast_from_acc16") + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") @autotvm.register_topi_schedule("conv2d_nchw_winograd_acc32.image2d") def schedule_conv2d_nchw_winograd_acc32(cfg, outs): - return schedule_conv2d_nchw_winograd_impl(cfg, outs, tag="cast_from_acc32") + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") @autotvm.register_topi_compute("conv2d_nchw_winograd_without_weight_transform.image2d") @@ -86,24 +87,12 @@ def conv2d_nchw_winograd_without_weight_transform_acc32( @autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform.image2d") def schedule_conv2d_nchw_winograd_without_weight_transform(cfg, outs): - return schedule_conv2d_nchw_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) @autotvm.register_topi_schedule("conv2d_nchw_winograd_without_weight_transform_acc32.image2d") def schedule_conv2d_nchw_winograd_without_weight_transform_acc32(cfg, outs): - return schedule_conv2d_nchw_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) - - -def schedule_conv2d_nchw_winograd_impl(cfg, outs, tag, pre_computed=False): - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in outs]) - - def _callback(op): - if op.tag == tag: - schedule_conv2d_winograd(cfg, s, op.output(0), pre_computed=pre_computed) - - traverse_inline(s, outs[0].op, _callback) - return s + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) def conv2d_nchw_winograd_comp( @@ -148,359 +137,4 @@ def conv2d_nchw_winograd_comp( output: tvm.te.Tensor 4-D or 5-D with shape NCHW or NCHW4c """ - tile_size = infer_tile_size(data, "NCHW") - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides - - convert_from4d = False - if len(data.shape) == 4: - N, DCI, H, W = get_const_tuple(data.shape) - if 
not pre_computed: - out_channels, CI, KH, KW = get_const_tuple(kernel.shape) - else: - alpha, _, CI, out_channels = get_const_tuple(kernel.shape) - KH = KW = alpha + 1 - tile_size - - in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4) - out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) - if autotvm.GLOBAL_SCOPE.in_tuning is True: - dshape = (N, in_channel_chunks, H, W, in_channel_block) - data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder") - if not pre_computed: # kernel tensor is raw tensor, do strict check - kshape = (out_channel_chunks, CI, KH, KW, out_channel_block) - kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder") - else: - kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block) - kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder") - else: - convert_from4d = True - data = pack_input( - data, "NCHW", N, in_channel_chunks, in_channel_block, in_channel_tail, H, W - ) - if not pre_computed: # kernel tensor is raw tensor, do strict check - kernel = pack_filter( - kernel, - "OIHW", - out_channel_chunks, - out_channel_block, - out_channel_tail, - CI, - in_channel_chunks, - in_channel_block, - in_channel_tail, - KH, - KW, - ) - else: - kernel = pack_filter( - kernel, - "HWIO", - out_channel_chunks, - out_channel_block, - out_channel_tail, - CI, - in_channel_chunks, - in_channel_block, - in_channel_tail, - alpha, - alpha, - ) - N, DCI, H, W, CB = get_const_tuple(data.shape) - if not pre_computed: # kernel tensor is raw tensor, do strict check - CO, CI, KH, KW, COB = get_const_tuple(kernel.shape) - alpha = KW + tile_size - 1 - assert HSTR == 1 and WSTR == 1 and KH == KW - else: - alpha, _, CI, CO, COB = get_const_tuple(kernel.shape) - KH = KW = alpha + 1 - tile_size - assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 - - if isinstance(N, tvm.tir.Any): - N = tvm.te.size_var("n") - - if not isinstance(H, int) or not isinstance(W, int): - raise RuntimeError( - "adreno winograd conv2d doesn't support dynamic input\ - height or width." 
- ) - - pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) - data_pad = nn.pad(data, (0, 0, pt, pl, 0), (0, 0, pb, pr, 0), name="data_pad") - - r = KW - m = tile_size - A, B, G = winograd_transform_matrices(m, r, out_dtype) - - H = (H + pt + pb - KH) // HSTR + 1 - W = (W + pl + pr - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - - P = N * nH * nW if isinstance(N, int) else nH * nW - - # transform kernel - if not pre_computed: - r_kh = te.reduce_axis((0, KH), name="r_kh") - r_kw = te.reduce_axis((0, KW), name="r_kw") - kernel_pack = te.compute( - (alpha, alpha, CI, CO, COB), - lambda eps, nu, ci, co, cob: te.sum( - kernel[co][ci][r_kh][r_kw][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] - ), - name="kernel_pack", - ) - else: - kernel_pack = kernel - - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod - N, CI, H, W, CB = get_const_tuple(data.shape) - - # pack input tile - input_tile = te.compute( - (alpha, alpha, CI, P, CB), - lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][c][ - idxmod(idxdiv(p, nW), nH) * m + eps - ][idxmod(p, nW) * m + nu][cb], - name="d", - ) - - # transform data - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_a") - data_pack = te.compute( - (P, CI, alpha, alpha, CB), - lambda p, ci, eps, nu, cb: te.sum( - input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] - ), - name="data_pack", - ) - - # repack transformed data - data_pack_trans = te.compute( - (alpha, alpha, CI, P, CB), - lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb], - name="data_pack_trans", - ) - - # do batch gemm - ci = te.reduce_axis((0, CI), name="ci") - cb = te.reduce_axis((0, CB), name="cb") - bgemm = te.compute( - (alpha, alpha, CO, P, COB), - lambda eps, nu, co, p, cob: te.sum( - ( - kernel_pack[eps][nu][ci * CB + cb][co][cob] * data_pack_trans[eps][nu][ci][p][cb] - ).astype(args["accumulator"]), - axis=[ci, cb], - ), - name="bgemm", - ) - - # inverse transform - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_a") - inverse = te.compute( - (CO, P, m, m, COB), - lambda co, p, vh, vw, cob: te.sum( - bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]), - axis=[r_a, r_b], - ), - name="inverse", - ) - - # output - if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: - output = te.compute( - (N, out_channels, H, W), - lambda n, c, h, w: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ - idxmod(h, m) - ][idxmod(w, m)][c % CB].astype(out_dtype), - name="output", - tag="cast_from_acc" + args["accumulator"][-2:], - ) - else: - output = te.compute( - (N, CO, H, W, COB), - lambda n, co, h, w, cob: inverse[co][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ - idxmod(h, m) - ][idxmod(w, m)][cob].astype(out_dtype), - name="output", - tag="cast_from_acc" + args["accumulator"][-2:], - ) - - if isinstance(N, int): - cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW) - - return output - - -def schedule_conv2d_winograd(cfg, s, output, pre_computed): - """Schedule winograd template""" - inverse = s[output].op.input_tensors[0] - bgemm, A = s[inverse].op.input_tensors - kernel_pack, data_pack_trans = s[bgemm].op.input_tensors - data_pack = s[data_pack_trans].op.input_tensors[0] - input_tile, B = s[data_pack].op.input_tensors - pad_data = s[input_tile].op.input_tensors[0] - - # data transform - s[B].compute_inline() - s[A].compute_inline() - - # probably will improve real topology execution - if autotvm.GLOBAL_SCOPE.in_tuning: - # 
Padding to texture - AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile]) - bind_data_copy(s[AA]) - - s[input_tile].compute_inline() - - OL = s.cache_write(data_pack, "local") - c, p, eps, nu, cb = s[data_pack].op.axis - fused = s[data_pack].fuse(c, p, eps, nu) - bx, tx = s[data_pack].split(fused, 128) - s[data_pack].vectorize(cb) - s[data_pack].bind(bx, te.thread_axis("blockIdx.x")) - s[data_pack].bind(tx, te.thread_axis("threadIdx.x")) - - c, p, eps, nu, cb = s[OL].op.axis - r_a, r_b = s[OL].op.reduce_axis - s[OL].unroll(eps) - s[OL].unroll(nu) - s[OL].unroll(r_a) - s[OL].unroll(r_b) - s[OL].vectorize(cb) - s[OL].compute_at(s[data_pack], tx) - s[data_pack].set_scope(get_texture_storage(data_pack.shape)) - - s[data_pack_trans].compute_inline() - - # transform kernel - if not pre_computed: - kernel, G = s[kernel_pack].op.input_tensors - eps, nu, ci, co, cob = s[kernel_pack].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # skip this part during tuning to make recrods accurate - # this part will be pre-computed during pre-compute optimization pass - s[G].pragma(s[G].op.axis[0], "debug_skip_region") - s[kernel_pack].pragma(eps, "debug_skip_region") - else: - s[G].compute_inline() - r_a, r_b = s[kernel_pack].op.reduce_axis - for axis in [eps, nu, r_a, r_b]: - s[kernel_pack].unroll(axis) - - fused = s[kernel_pack].fuse(ci, co) - bb, tt = s[kernel_pack].split(fused, 128) - s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob) - s[kernel_pack].vectorize(cob) - s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) - s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) - else: - kernel = kernel_pack - - if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag: - # manage scheduling of datacopy - pack_data = pad_data.op.input_tensors[0] - bind_data_copy(s[pack_data]) - bind_data_copy(s[kernel]) - elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - s[pad_data].compute_inline() - - ##### space definition begin ##### - cfg.define_knob("auto_unroll_max_step", [0, 4, 16]) - b1, b2, y, x, cb = s[bgemm].op.axis - rcc = s[bgemm].op.reduce_axis[0] - alpha = get_const_int(b1.dom.extent) - - cfg.define_split( - "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 8 - ) - cfg.define_split( - "tile_x", - x, - num_outputs=3, - filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= 4 and entry.size[1] <= 8, - ) - cfg.define_split("tile_rc", rcc, num_outputs=2) - # TODO: Uncomment the following lines when multi_filter will be introduced - # cfg.multi_filter( - # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024) - # ) - ##### space definition end ##### - - # batch gemm - OL = s.cache_write(bgemm, "local") - if ( - autotvm.GLOBAL_SCOPE.in_tuning - or isinstance(kernel.op, tvm.te.ComputeOp) - and "filter_pack" in kernel.op.tag - ): - BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL]) - bind_data_copy(s[BB]) - - by = s[bgemm].fuse(b1, b2, y) - - # tile and bind spatial axes - bgemm_scope, by = s[bgemm].split(by, nparts=1) - by, vy, ty = cfg["tile_y"].apply(s, bgemm, by) - bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x) - s[bgemm].bind(by, te.thread_axis("blockIdx.y")) - s[bgemm].bind(bx, te.thread_axis("blockIdx.x")) - s[bgemm].bind(vy, te.thread_axis("vthread")) - s[bgemm].bind(vx, te.thread_axis("vthread")) - s[bgemm].bind(ty, te.thread_axis("threadIdx.y")) - s[bgemm].bind(tx, 
te.thread_axis("threadIdx.x")) - s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb) - s[bgemm].vectorize(cb) - s[bgemm].set_scope(get_texture_storage(bgemm.shape)) - - # tile reduction axes - s[OL].compute_at(s[bgemm], tx) - b1, b2, y, x, cb = s[OL].op.axis - (rcc, rcb) = s[OL].op.reduce_axis - b = s[OL].fuse(b1, b2) - s[OL].reorder(b, y, x, rcc, rcb, cb) - # s[OL].unroll(rcb) - s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) - s[OL].pragma(rcb, "unroll_explicit", True) - s[OL].vectorize(cb) - - # schedule inverse, output and fusion - if output.op in s.outputs: - OL = None - else: - OL = output - s[OL].set_scope("local") - output = s.outputs[0] - - m = alpha - 3 + 1 - if len(s[output].op.axis) == 4: - n, co, h, w = s[output].op.axis - else: - n, co, h, w, _ = s[output].op.axis - ho, wo, hi, wi = s[output].tile(h, w, m, m) - inverse_scope, n = s[output].split(n, nparts=1) - - fused = s[output].fuse(n, co, ho, wo) - bb, tt = s[output].split(fused, 128) - - s[output].bind(bb, te.thread_axis("blockIdx.x")) - s[output].bind(tt, te.thread_axis("threadIdx.x")) - - if OL is not None: - s[OL].compute_at(s[output], tt) - - co, p, vh, vw, cb = s[inverse].op.axis - r_a, r_b = s[inverse].op.reduce_axis - for axis in [vh, vw, r_a, r_b]: - s[inverse].unroll(axis) - s[inverse].vectorize(cb) - s[inverse].compute_at(s[output], tt) - - return s + return conv2d_winograd_comp(cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NCHW") diff --git a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py index 478eeb006ec8..f3850fbec171 100644 --- a/python/tvm/topi/adreno/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/adreno/conv2d_nhwc_winograd.py @@ -33,6 +33,7 @@ get_texture_storage, infer_tile_size, ) +from .conv2d_winograd_common import conv2d_winograd_comp, schedule_conv2d_winograd_impl logger = logging.getLogger("conv2d_nhwc_winograd") @@ -56,12 +57,12 @@ def conv2d_nhwc_winograd_acc32(cfg, data, kernel, strides, padding, dilation, ou @autotvm.register_topi_schedule("conv2d_nhwc_winograd.image2d") def schedule_conv2d_nhwc_winograd(cfg, outs): - return schedule_conv2d_nhwc_winograd_impl(cfg, outs, tag="cast_from_acc16") + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16") @autotvm.register_topi_schedule("conv2d_nhwc_winograd_acc32.image2d") def schedule_conv2d_nhwc_winograd_acc32(cfg, outs): - return schedule_conv2d_nhwc_winograd_impl(cfg, outs, tag="cast_from_acc32") + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32") @autotvm.register_topi_compute("conv2d_nhwc_winograd_without_weight_transform.image2d") @@ -86,24 +87,12 @@ def conv2d_nhwc_winograd_without_weight_transform_acc32( @autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform.image2d") def schedule_conv2d_nhwc_winograd_without_weight_transform(cfg, outs): - return schedule_conv2d_nhwc_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc16", pre_computed=True) @autotvm.register_topi_schedule("conv2d_nhwc_winograd_without_weight_transform_acc32.image2d") def schedule_conv2d_nhwc_winograd_without_weight_transform_acc32(cfg, outs): - return schedule_conv2d_nhwc_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) - - -def schedule_conv2d_nhwc_winograd_impl(cfg, outs, tag, pre_computed=False): - outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs - s = te.create_schedule([x.op for x in 
outs]) - - def _callback(op): - if op.tag == tag: - schedule_conv2d_winograd(cfg, s, op.output(0), pre_computed=pre_computed) - - traverse_inline(s, outs[0].op, _callback) - return s + return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32", pre_computed=True) def conv2d_nhwc_winograd_comp( @@ -148,358 +137,4 @@ def conv2d_nhwc_winograd_comp( output: tvm.te.Tensor 4-D or 5-D with shape NCHW or NCHW4c """ - tile_size = infer_tile_size(data, "NHWC") - - if isinstance(dilation, int): - dilation_h = dilation_w = dilation - else: - dilation_h, dilation_w = dilation - HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides - - convert_from4d = False - if len(data.shape) == 4: - N, H, W, DCI = get_const_tuple(data.shape) - if not pre_computed: - KH, KW, CI, out_channels = get_const_tuple(kernel.shape) - else: - alpha, _, CI, out_channels = get_const_tuple(kernel.shape) - KH = KW = alpha + 1 - tile_size - - in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4) - out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4) - if autotvm.GLOBAL_SCOPE.in_tuning is True: - dshape = (N, H, W, in_channel_chunks, in_channel_block) - data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder") - if not pre_computed: # kernel tensor is raw tensor, do strict check - kshape = (KH, KW, CI, out_channel_chunks, out_channel_block) - else: - kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block) - kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder") - else: - convert_from4d = True - data = pack_input( - data, "NHWC", N, in_channel_chunks, in_channel_block, in_channel_tail, H, W - ) - if not pre_computed: # kernel tensor is raw tensor, do strict check - kernel = pack_filter( - kernel, - "HWIO", - out_channel_chunks, - out_channel_block, - out_channel_tail, - CI, - in_channel_chunks, - in_channel_block, - in_channel_tail, - KH, - KW, - ) - else: - kernel = pack_filter( - kernel, - "HWIO", - out_channel_chunks, - out_channel_block, - out_channel_tail, - CI, - in_channel_chunks, - in_channel_block, - in_channel_tail, - alpha, - alpha, - ) - N, H, W, DCI, CB = get_const_tuple(data.shape) - if not pre_computed: # kernel tensor is raw tensor, do strict check - KH, KW, CI, CO, COB = get_const_tuple(kernel.shape) - alpha = KW + tile_size - 1 - assert HSTR == 1 and WSTR == 1 and KH == KW - else: - alpha, _, CI, CO, COB = get_const_tuple(kernel.shape) - KH = KW = alpha + 1 - tile_size - assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 - - if isinstance(N, tvm.tir.Any): - N = tvm.te.size_var("n") - - if not isinstance(H, int) or not isinstance(W, int): - raise RuntimeError( - "adreno winograd conv2d doesn't support dynamic input\ - height or width." 
- ) - - pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) - data_pad = nn.pad(data, (0, pt, pl, 0, 0), (0, pb, pr, 0, 0), name="data_pad") - - r = KW - m = tile_size - A, B, G = winograd_transform_matrices(m, r, out_dtype) - - H = (H + pt + pb - KH) // HSTR + 1 - W = (W + pl + pr - KW) // WSTR + 1 - nH, nW = (H + m - 1) // m, (W + m - 1) // m - - P = N * nH * nW if isinstance(N, int) else nH * nW - - # transform kernel - if not pre_computed: - r_kh = te.reduce_axis((0, KH), name="r_kh") - r_kw = te.reduce_axis((0, KW), name="r_kw") - kernel_pack = te.compute( - (alpha, alpha, CI, CO, COB), - lambda eps, nu, ci, co, cob: te.sum( - kernel[r_kh][r_kw][ci][co][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw] - ), - name="kernel_pack", - ) - else: - kernel_pack = kernel - - idxdiv = tvm.tir.indexdiv - idxmod = tvm.tir.indexmod - N, H, W, CI, CB = get_const_tuple(data.shape) - - # pack input tile - input_tile = te.compute( - (alpha, alpha, CI, P, CB), - lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][ - idxmod(idxdiv(p, nW), nH) * m + eps - ][idxmod(p, nW) * m + nu][c][cb], - name="d", - ) - - # transform data - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_a") - data_pack = te.compute( - (P, CI, alpha, alpha, CB), - lambda p, ci, eps, nu, cb: te.sum( - input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] - ), - name="data_pack", - ) - - # repack transformed data - data_pack_trans = te.compute( - (alpha, alpha, CI, P, CB), - lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb], - name="data_pack_trans", - ) - - # do batch gemm - ci = te.reduce_axis((0, CI), name="ci") - cb = te.reduce_axis((0, CB), name="cb") - bgemm = te.compute( - (alpha, alpha, CO, P, COB), - lambda eps, nu, co, p, cob: te.sum( - ( - kernel_pack[eps][nu][ci * CB + cb][co][cob] * data_pack_trans[eps][nu][ci][p][cb] - ).astype(args["accumulator"]), - axis=[ci, cb], - ), - name="bgemm", - ) - - # inverse transform - r_a = te.reduce_axis((0, alpha), "r_a") - r_b = te.reduce_axis((0, alpha), "r_a") - inverse = te.compute( - (CO, P, m, m, COB), - lambda co, p, vh, vw, cob: te.sum( - bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]), - axis=[r_a, r_b], - ), - name="inverse", - ) - - # output - if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False: - output = te.compute( - (N, H, W, out_channels), - lambda n, h, w, c: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ - idxmod(h, m) - ][idxmod(w, m)][c % CB].astype(out_dtype), - name="output", - tag="cast_from_acc" + args["accumulator"][-2:], - ) - else: - output = te.compute( - (N, H, W, CO, COB), - lambda n, h, w, co, cob: inverse[co][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][ - idxmod(h, m) - ][idxmod(w, m)][cob].astype(out_dtype), - name="output", - tag="cast_from_acc" + args["accumulator"][-2:], - ) - - if isinstance(N, int): - cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW) - - return output - - -def schedule_conv2d_winograd(cfg, s, output, pre_computed): - """Schedule winograd template""" - inverse = s[output].op.input_tensors[0] - bgemm, A = s[inverse].op.input_tensors - kernel_pack, data_pack_trans = s[bgemm].op.input_tensors - data_pack = s[data_pack_trans].op.input_tensors[0] - input_tile, B = s[data_pack].op.input_tensors - pad_data = s[input_tile].op.input_tensors[0] - - # data transform - s[B].compute_inline() - s[A].compute_inline() - - # probably will improve real topology execution - if autotvm.GLOBAL_SCOPE.in_tuning: - # 
Padding to texture - AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile]) - bind_data_copy(s[AA]) - - s[input_tile].compute_inline() - - OL = s.cache_write(data_pack, "local") - p, c, eps, nu, cb = s[data_pack].op.axis - fused = s[data_pack].fuse(p, c, eps, nu) - bx, tx = s[data_pack].split(fused, 128) - s[data_pack].vectorize(cb) - s[data_pack].bind(bx, te.thread_axis("blockIdx.x")) - s[data_pack].bind(tx, te.thread_axis("threadIdx.x")) - - _, _, eps, nu, cb = s[OL].op.axis - r_a, r_b = s[OL].op.reduce_axis - s[OL].unroll(eps) - s[OL].unroll(nu) - s[OL].unroll(r_a) - s[OL].unroll(r_b) - s[OL].vectorize(cb) - s[OL].compute_at(s[data_pack], tx) - s[data_pack].set_scope(get_texture_storage(data_pack.shape)) - - s[data_pack_trans].compute_inline() - - # transform kernel - if not pre_computed: - kernel, G = s[kernel_pack].op.input_tensors - eps, nu, ci, co, cob = s[kernel_pack].op.axis - if autotvm.GLOBAL_SCOPE.in_tuning: - # skip this part during tuning to make recrods accurate - # this part will be pre-computed during pre-compute optimization pass - s[G].pragma(s[G].op.axis[0], "debug_skip_region") - s[kernel_pack].pragma(eps, "debug_skip_region") - else: - s[G].compute_inline() - r_a, r_b = s[kernel_pack].op.reduce_axis - for axis in [eps, nu, r_a, r_b]: - s[kernel_pack].unroll(axis) - - fused = s[kernel_pack].fuse(ci, co) - bb, tt = s[kernel_pack].split(fused, 128) - s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob) - s[kernel_pack].vectorize(cob) - s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x")) - s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x")) - else: - kernel = kernel_pack - - if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag: - # manage scheduling of datacopy - pack_data = pad_data.op.input_tensors[0] - bind_data_copy(s[pack_data]) - bind_data_copy(s[kernel]) - elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag: - s[kernel].compute_inline() - s[pad_data].compute_inline() - - ##### space definition begin ##### - cfg.define_knob("auto_unroll_max_step", [0, 4, 16]) - b1, b2, y, x, cb = s[bgemm].op.axis - rcc = s[bgemm].op.reduce_axis[0] - alpha = get_const_int(b1.dom.extent) - - cfg.define_split( - "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 8 - ) - cfg.define_split( - "tile_x", - x, - num_outputs=3, - filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= 4 and entry.size[1] <= 8, - ) - cfg.define_split("tile_rc", rcc, num_outputs=2) - # TODO: Uncomment the following lines when multi_filter will be introduced - # cfg.multi_filter( - # filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024) - # ) - ##### space definition end ##### - - # batch gemm - OL = s.cache_write(bgemm, "local") - if ( - autotvm.GLOBAL_SCOPE.in_tuning - or isinstance(kernel.op, tvm.te.ComputeOp) - and "filter_pack" in kernel.op.tag - ): - BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL]) - bind_data_copy(s[BB]) - - by = s[bgemm].fuse(b1, b2, y) - - # tile and bind spatial axes - bgemm_scope, by = s[bgemm].split(by, nparts=1) - by, vy, ty = cfg["tile_y"].apply(s, bgemm, by) - bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x) - s[bgemm].bind(by, te.thread_axis("blockIdx.y")) - s[bgemm].bind(bx, te.thread_axis("blockIdx.x")) - s[bgemm].bind(vy, te.thread_axis("vthread")) - s[bgemm].bind(vx, te.thread_axis("vthread")) - s[bgemm].bind(ty, te.thread_axis("threadIdx.y")) - s[bgemm].bind(tx, 
te.thread_axis("threadIdx.x")) - s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb) - s[bgemm].vectorize(cb) - s[bgemm].set_scope(get_texture_storage(bgemm.shape)) - - # tile reduction axes - s[OL].compute_at(s[bgemm], tx) - b1, b2, y, x, cb = s[OL].op.axis - (rcc, rcb) = s[OL].op.reduce_axis - b = s[OL].fuse(b1, b2) - s[OL].reorder(b, y, x, rcc, rcb, cb) - # s[OL].unroll(rcb) - s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val) - s[OL].pragma(rcb, "unroll_explicit", True) - s[OL].vectorize(cb) - - # schedule inverse, output and fusion - if output.op in s.outputs: - OL = None - else: - OL = output - s[OL].set_scope("local") - output = s.outputs[0] - - m = alpha - 3 + 1 - if len(s[output].op.axis) == 4: - n, co, h, w = s[output].op.axis - else: - n, co, h, w, _ = s[output].op.axis - ho, wo, hi, wi = s[output].tile(h, w, m, m) - inverse_scope, n = s[output].split(n, nparts=1) - - fused = s[output].fuse(n, co, ho, wo) - bb, tt = s[output].split(fused, 128) - - s[output].bind(bb, te.thread_axis("blockIdx.x")) - s[output].bind(tt, te.thread_axis("threadIdx.x")) - - if OL is not None: - s[OL].compute_at(s[output], tt) - - co, p, vh, vw, cb = s[inverse].op.axis - r_a, r_b = s[inverse].op.reduce_axis - for axis in [vh, vw, r_a, r_b]: - s[inverse].unroll(axis) - s[inverse].vectorize(cb) - s[inverse].compute_at(s[output], tt) - - return s + return conv2d_winograd_comp(cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, "NHWC") diff --git a/python/tvm/topi/adreno/conv2d_winograd_common.py b/python/tvm/topi/adreno/conv2d_winograd_common.py new file mode 100644 index 000000000000..1b10a8cc57e2 --- /dev/null +++ b/python/tvm/topi/adreno/conv2d_winograd_common.py @@ -0,0 +1,513 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name,unused-variable,unused-argument
+"""Common Winograd implementation for the Adreno backend"""
+
+import logging
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from tvm.topi import nn
+from tvm.topi.utils import get_const_int, get_const_tuple, traverse_inline
+from ..nn.winograd_util import winograd_transform_matrices
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    bind_data_copy,
+    get_texture_storage,
+    infer_tile_size,
+)
+
+
+def conv2d_winograd_comp(
+    cfg, data, kernel, strides, padding, dilation, out_dtype, args, pre_computed, layout
+):
+    """Compute declaration for winograd
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    data: tvm.te.Tensor
+        4-D or 5-D data tensor with shape NCHW, NHWC, NCHW4c or NHWC4c
+
+    kernel: tvm.te.Tensor
+        4-D or 5-D tensor with shape OIHW, HWIO, OIHW4o or HWIO4o
+
+    strides: int or a list/tuple of two ints
+        stride size, or [stride_height, stride_width]
+
+    padding: int or a list/tuple of 2 or 4 ints
+        padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    dilation: int or a list/tuple of two ints
+        dilation size, or [dilation_height, dilation_width]
+
+    out_dtype: str
+        The output type. This is used for mixed precision.
+
+    args: dict
+        Dictionary with additional arguments, e.g. accumulator type
+
+    pre_computed: bool
+        True if the kernel weights were already transformed into the winograd
+        domain, False if the transformation should be part of this computation
+
+    layout: str
+        Layout of the input data: "NCHW" or "NHWC"
+
+    Returns
+    -------
+    output: tvm.te.Tensor
+        4-D or 5-D with shape NCHW, NHWC, NCHW4c or NHWC4c
+    """
+    assert layout in ("NCHW", "NHWC")
+    tile_size = infer_tile_size(data, layout)
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+    HSTR, WSTR = (strides, strides) if isinstance(strides, int) else strides
+
+    convert_from4d = False
+    if len(data.shape) == 4:
+        if layout == "NCHW":
+            N, DCI, H, W = get_const_tuple(data.shape)
+        else:
+            N, H, W, DCI = get_const_tuple(data.shape)
+        if not pre_computed:
+            if layout == "NCHW":
+                out_channels, CI, KH, KW = get_const_tuple(kernel.shape)
+            else:
+                KH, KW, CI, out_channels = get_const_tuple(kernel.shape)
+        else:
+            alpha, _, CI, out_channels = get_const_tuple(kernel.shape)
+            KH = KW = alpha + 1 - tile_size
+
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(CI, 4)
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4)
+        if autotvm.GLOBAL_SCOPE.in_tuning is True:
+            if layout == "NCHW":
+                dshape = (N, in_channel_chunks, H, W, in_channel_block)
+            else:
+                dshape = (N, H, W, in_channel_chunks, in_channel_block)
+            if not pre_computed:  # kernel tensor is raw tensor, do strict check
+                if layout == "NCHW":
+                    kshape = (out_channel_chunks, CI, KH, KW, out_channel_block)
+                else:
+                    kshape = (KH, KW, CI, out_channel_chunks, out_channel_block)
+            else:
+                kshape = (alpha, alpha, CI, out_channel_chunks, out_channel_block)
+            data = tvm.te.placeholder(dshape, data.dtype, name="data_placeholder")
+            kernel = tvm.te.placeholder(kshape, kernel.dtype, name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            data = pack_input(
+                data, layout, N, in_channel_chunks, in_channel_block, in_channel_tail, H, W
+            )
+            kernel_layout = "OIHW" if layout == "NCHW" else "HWIO"
+            if not pre_computed:  # kernel tensor is raw tensor, do strict check
+                kernel = pack_filter(
+                    kernel,
+                    kernel_layout,
+                    out_channel_chunks,
+                    out_channel_block,
+                    out_channel_tail,
+                    CI,
+                    in_channel_chunks,
+                    in_channel_block,
+                    in_channel_tail,
+                    KH,
+                    KW,
+                )
+            else:
+                kernel = pack_filter(
+                    kernel,
+                    "HWIO",
+                    out_channel_chunks,
+                    out_channel_block,
+                    out_channel_tail,
+                    CI,
+                    in_channel_chunks,
+                    in_channel_block,
+                    in_channel_tail,
+                    alpha,
+                    alpha,
+                )
+    if layout == "NCHW":
+        N, DCI, H, W, CB = get_const_tuple(data.shape)
+    else:
+        N, H, W, DCI, CB = get_const_tuple(data.shape)
+    if not pre_computed:  # kernel tensor is raw tensor, do strict check
+        if layout == "NCHW":
+            CO, CI, KH, KW, COB = get_const_tuple(kernel.shape)
+        else:
+            KH, KW, CI, CO, COB = get_const_tuple(kernel.shape)
+        alpha = KW + tile_size - 1
+        assert HSTR == 1 and WSTR == 1 and KH == KW
+    else:
+        alpha, _, CI, CO, COB = get_const_tuple(kernel.shape)
+        KH = KW = alpha + 1 - tile_size
+        assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1
+
+    if isinstance(N, tvm.tir.Any):
+        N = tvm.te.size_var("n")
+
+    if not isinstance(H, int) or not isinstance(W, int):
+        raise RuntimeError(
+            "adreno winograd conv2d doesn't support dynamic "
+            "input height or width."
+        )
+
+    pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW))
+    if layout == "NCHW":
+        data_pad = nn.pad(data, (0, 0, pt, pl, 0), (0, 0, pb, pr, 0), name="data_pad")
+    else:
+        data_pad = nn.pad(data, (0, pt, pl, 0, 0), (0, pb, pr, 0, 0), name="data_pad")
+
+    r = KW
+    m = tile_size
+    A, B, G = winograd_transform_matrices(m, r, out_dtype)
+
+    H = (H + pt + pb - KH) // HSTR + 1
+    W = (W + pl + pr - KW) // WSTR + 1
+    nH, nW = (H + m - 1) // m, (W + m - 1) // m
+
+    P = N * nH * nW if isinstance(N, int) else nH * nW
+
+    # transform kernel
+    if not pre_computed:
+        r_kh = te.reduce_axis((0, KH), name="r_kh")
+        r_kw = te.reduce_axis((0, KW), name="r_kw")
+        if layout == "NCHW":
+            kernel_pack = te.compute(
+                (alpha, alpha, CI, CO, COB),
+                lambda eps, nu, ci, co, cob: te.sum(
+                    kernel[co][ci][r_kh][r_kw][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+                ),
+                name="kernel_pack",
+            )
+        else:
+            kernel_pack = te.compute(
+                (alpha, alpha, CI, CO, COB),
+                lambda eps, nu, ci, co, cob: te.sum(
+                    kernel[r_kh][r_kw][ci][co][cob] * G[eps][r_kh] * G[nu][r_kw], axis=[r_kh, r_kw]
+                ),
+                name="kernel_pack",
+            )
+    else:
+        kernel_pack = kernel
+
+    idxdiv = tvm.tir.indexdiv
+    idxmod = tvm.tir.indexmod
+    if layout == "NCHW":
+        N, CI, H, W, CB = get_const_tuple(data.shape)
+    else:
+        N, H, W, CI, CB = get_const_tuple(data.shape)
+
+    # pack input tile
+    if layout == "NCHW":
+        input_tile = te.compute(
+            (alpha, alpha, CI, P, CB),
+            lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][c][
+                idxmod(idxdiv(p, nW), nH) * m + eps
+            ][idxmod(p, nW) * m + nu][cb],
+            name="d",
+        )
+    else:
+        input_tile = te.compute(
+            (alpha, alpha, CI, P, CB),
+            lambda eps, nu, c, p, cb: data_pad[idxdiv(p, (nH * nW))][
+                idxmod(idxdiv(p, nW), nH) * m + eps
+            ][idxmod(p, nW) * m + nu][c][cb],
+            name="d",
+        )
+
+    # transform data
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    data_pack = te.compute(
+        (P, CI, alpha, alpha, CB),
+        lambda p, ci, eps, nu, cb: te.sum(
+            input_tile[r_a][r_b][ci][p][cb] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]
+        ),
+        name="data_pack",
+    )
+
+    # repack transformed data
+    data_pack_trans = te.compute(
+        (alpha, alpha, CI, P, CB),
+        lambda eps, nu, c, p, cb: data_pack[p][c][eps][nu][cb],
+        name="data_pack_trans",
+    )
+
+    # do batch gemm
+    ci = te.reduce_axis((0, CI), name="ci")
+    cb = te.reduce_axis((0, CB), name="cb")
+    bgemm = te.compute(
+        (alpha, alpha, CO, P, COB),
+        lambda eps, nu, co, p, cob: te.sum(
+            (
+                kernel_pack[eps][nu][ci * CB + cb][co][cob] * data_pack_trans[eps][nu][ci][p][cb]
+            ).astype(args["accumulator"]),
+            axis=[ci, cb],
+        ),
+        name="bgemm",
+    )
+
+    # inverse transform
+    r_a = te.reduce_axis((0, alpha), "r_a")
+    r_b = te.reduce_axis((0, alpha), "r_b")
+    inverse = te.compute(
+        (CO, P, m, m, COB),
+        lambda co, p, vh, vw, cob: te.sum(
+            bgemm[r_a][r_b][co][p][cob] * (A[r_a][vh] * A[r_b][vw]).astype(args["accumulator"]),
+            axis=[r_a, r_b],
+        ),
+        name="inverse",
+    )
+
+    # output
+    if layout == "NCHW":
+        if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False:
+            output = te.compute(
+                (N, out_channels, H, W),
+                lambda n, c, h, w: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][
+                    idxmod(h, m)
+                ][idxmod(w, m)][c % CB].astype(out_dtype),
+                name="output",
+                tag="cast_from_acc" + args["accumulator"][-2:],
+            )
+        else:
+            output = te.compute(
+                (N, CO, H, W, COB),
+                lambda n, co, h, w, cob: inverse[co][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][
+                    idxmod(h, m)
+                ][idxmod(w, m)][cob].astype(out_dtype),
+                name="output",
+                tag="cast_from_acc" + args["accumulator"][-2:],
+            )
+    else:
+        if convert_from4d and autotvm.GLOBAL_SCOPE.in_tuning is False:
+            output = te.compute(
+                (N, H, W, out_channels),
+                lambda n, h, w, c: inverse[c // CB][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][
+                    idxmod(h, m)
+                ][idxmod(w, m)][c % CB].astype(out_dtype),
+                name="output",
+                tag="cast_from_acc" + args["accumulator"][-2:],
+            )
+        else:
+            output = te.compute(
+                (N, H, W, CO, COB),
+                lambda n, h, w, co, cob: inverse[co][n * nH * nW + idxdiv(h, m) * nW + idxdiv(w, m)][
+                    idxmod(h, m)
+                ][idxmod(w, m)][cob].astype(out_dtype),
+                name="output",
+                tag="cast_from_acc" + args["accumulator"][-2:],
+            )
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * CO * COB * H * W * CI * CB * KH * KW)
+
+    return output
+
+
+def schedule_conv2d_winograd_impl(cfg, outs, tag, pre_computed=False):
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == tag:
+            schedule_conv2d_winograd(cfg, s, op.output(0), pre_computed=pre_computed)
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def schedule_conv2d_winograd(cfg, s, output, pre_computed):
+    """Schedule winograd template"""
+    inverse = s[output].op.input_tensors[0]
+    bgemm, A = s[inverse].op.input_tensors
+    kernel_pack, data_pack_trans = s[bgemm].op.input_tensors
+    data_pack = s[data_pack_trans].op.input_tensors[0]
+    input_tile, B = s[data_pack].op.input_tensors
+    pad_data = s[input_tile].op.input_tensors[0]
+
+    # data transform
+    s[B].compute_inline()
+    s[A].compute_inline()
+
+    # this will probably improve execution on real topologies
+    if autotvm.GLOBAL_SCOPE.in_tuning:
+        # Padding to texture
+        AA = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [input_tile])
+        bind_data_copy(s[AA])
+
+    s[input_tile].compute_inline()
+
+    OL = s.cache_write(data_pack, "local")
+    c, p, eps, nu, cb = s[data_pack].op.axis
+    fused = s[data_pack].fuse(c, p, eps, nu)
+    bx, tx = s[data_pack].split(fused, 128)
+    s[data_pack].vectorize(cb)
+    s[data_pack].bind(bx, te.thread_axis("blockIdx.x"))
+    s[data_pack].bind(tx, te.thread_axis("threadIdx.x"))
+
+    _, _, eps, nu, cb = s[OL].op.axis
+    r_a, r_b = s[OL].op.reduce_axis
+    s[OL].unroll(eps)
+    s[OL].unroll(nu)
+    s[OL].unroll(r_a)
+    s[OL].unroll(r_b)
+    s[OL].vectorize(cb)
+    s[OL].compute_at(s[data_pack], tx)
+    s[data_pack].set_scope(get_texture_storage(data_pack.shape))
+
+    s[data_pack_trans].compute_inline()
+
+    # transform kernel
+    if not pre_computed:
+        kernel, G = s[kernel_pack].op.input_tensors
+        eps, nu, ci, co, cob = s[kernel_pack].op.axis
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            # skip this part during tuning to make records accurate
+            # this part will be pre-computed during the pre-compute optimization pass
+            s[G].pragma(s[G].op.axis[0], "debug_skip_region")
+            s[kernel_pack].pragma(eps, "debug_skip_region")
+        else:
+            s[G].compute_inline()
+            r_a, r_b = s[kernel_pack].op.reduce_axis
+            for axis in [eps, nu, r_a, r_b]:
+                s[kernel_pack].unroll(axis)
+
+            fused = s[kernel_pack].fuse(ci, co)
+            bb, tt = s[kernel_pack].split(fused, 128)
+            s[kernel_pack].reorder(bb, tt, eps, nu, r_a, r_b, cob)
+            s[kernel_pack].vectorize(cob)
+            s[kernel_pack].bind(bb, te.thread_axis("blockIdx.x"))
+            s[kernel_pack].bind(tt, te.thread_axis("threadIdx.x"))
+    else:
+        kernel = kernel_pack
+
+    if isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag:
+        # manage scheduling of datacopy
+        pack_data = pad_data.op.input_tensors[0]
+        bind_data_copy(s[pack_data])
+        bind_data_copy(s[kernel])
+    elif isinstance(kernel.op, tvm.te.ComputeOp) and "dilate" in kernel.op.tag:
+        s[kernel].compute_inline()
+    s[pad_data].compute_inline()
+
+    ##### space definition begin #####
+    cfg.define_knob("auto_unroll_max_step", [0, 4, 16])
+    b1, b2, y, x, cb = s[bgemm].op.axis
+    rcc = s[bgemm].op.reduce_axis[0]
+    alpha = get_const_int(b1.dom.extent)
+
+    cfg.define_split(
+        "tile_y", y, num_outputs=3, filter=lambda entry: entry.size[2] <= 64 and entry.size[1] <= 8
+    )
+    cfg.define_split(
+        "tile_x",
+        x,
+        num_outputs=3,
+        filter=lambda entry: entry.size[2] <= 64 and entry.size[1] >= 4 and entry.size[1] <= 8,
+    )
+    cfg.define_split("tile_rc", rcc, num_outputs=2)
+    # TODO: Uncomment the following lines when multi_filter is introduced
+    # cfg.multi_filter(
+    #     filter=lambda entity: entity["tile_y"].size[2] * entity["tile_x"].size[2] in range(32,1024)
+    # )
+    ##### space definition end #####
+
+    # batch gemm
+    OL = s.cache_write(bgemm, "local")
+    if (
+        autotvm.GLOBAL_SCOPE.in_tuning
+        or isinstance(kernel.op, tvm.te.ComputeOp)
+        and "filter_pack" in kernel.op.tag
+    ):
+        BB = s.cache_read(kernel_pack, get_texture_storage(kernel_pack.shape), [OL])
+        bind_data_copy(s[BB])
+
+    by = s[bgemm].fuse(b1, b2, y)
+
+    # tile and bind spatial axes
+    bgemm_scope, by = s[bgemm].split(by, nparts=1)
+    by, vy, ty = cfg["tile_y"].apply(s, bgemm, by)
+    bx, vx, tx = cfg["tile_x"].apply(s, bgemm, x)
+    s[bgemm].bind(by, te.thread_axis("blockIdx.y"))
+    s[bgemm].bind(bx, te.thread_axis("blockIdx.x"))
+    s[bgemm].bind(vy, te.thread_axis("vthread"))
+    s[bgemm].bind(vx, te.thread_axis("vthread"))
+    s[bgemm].bind(ty, te.thread_axis("threadIdx.y"))
+    s[bgemm].bind(tx, te.thread_axis("threadIdx.x"))
+    s[bgemm].reorder(bgemm_scope, by, bx, vy, vx, ty, tx, cb)
+    s[bgemm].vectorize(cb)
+    s[bgemm].set_scope(get_texture_storage(bgemm.shape))
+
+    # tile reduction axes
+    s[OL].compute_at(s[bgemm], tx)
+    b1, b2, y, x, cb = s[OL].op.axis
+    (rcc, rcb) = s[OL].op.reduce_axis
+    b = s[OL].fuse(b1, b2)
+    s[OL].reorder(b, y, x, rcc, rcb, cb)
+    # s[OL].unroll(rcb)
+    s[OL].pragma(rcb, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[OL].pragma(rcb, "unroll_explicit", True)
+    s[OL].vectorize(cb)
+
+    # schedule inverse, output and fusion
+    if output.op in s.outputs:
+        OL = None
+    else:
+        OL = output
+        s[OL].set_scope("local")
+        output = s.outputs[0]
+
+    m = alpha - 3 + 1
+    if len(s[output].op.axis) == 4:
+        n, co, h, w = s[output].op.axis
+    else:
+        n, co, h, w, _ = s[output].op.axis
+    ho, wo, hi, wi = s[output].tile(h, w, m, m)
+    inverse_scope, n = s[output].split(n, nparts=1)
+
+    fused = s[output].fuse(n, co, ho, wo)
+    bb, tt = s[output].split(fused, 128)
+
+    s[output].bind(bb, te.thread_axis("blockIdx.x"))
+    s[output].bind(tt, te.thread_axis("threadIdx.x"))
+
+    if OL is not None:
+        s[OL].compute_at(s[output], tt)
+
+    co, p, vh, vw, cb = s[inverse].op.axis
+    r_a, r_b = s[inverse].op.reduce_axis
+    for axis in [vh, vw, r_a, r_b]:
+        s[inverse].unroll(axis)
+    s[inverse].vectorize(cb)
+    s[inverse].compute_at(s[output], tt)
+
+    return s
diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index 3798f9989f4d..78a992e56a0f 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -550,7 +550,7 @@ def get_texture_storage(shape):
 
 
 def infer_tile_size(data, layout):
-    """Compute the tile size
+    """Compute the tile size for the Winograd algorithm
 
     Parameters
     ----------
diff --git a/src/runtime/opencl/texture_pool.cc b/src/runtime/opencl/texture_pool.cc
index 56dd5e2b28d3..8eccacfea5e7 100644
--- a/src/runtime/opencl/texture_pool.cc
+++ b/src/runtime/opencl/texture_pool.cc
@@ -29,7 +29,7 @@
 namespace tvm {
 namespace runtime {
 
-void* Pool::Alloc(Device dev, DeviceAPI* device, size_t width, size_t height,
+void* Pool2D::Alloc(Device dev, DeviceAPI* device, size_t width, size_t height,
                 DLDataType type_hint) {
   Entry e;
   Entry new_mem;
@@ -107,7 +107,7 @@ void* Pool::Alloc(Device dev, DeviceAPI* device, size_t width, size_t height,
   return e.data;
 }
 
-void Pool::Free(void* data) {
+void Pool2D::Free(void* data) {
   Entry e;
   if (allocated_.back().data == data) {
     // quick path, last allocated.
@@ -125,7 +125,7 @@ void Pool::Free(void* data) {
 }
 
 // Release all resources immediately
-void Pool::Release(Device dev, DeviceAPI* device) {
+void Pool2D::Release(Device dev, DeviceAPI* device) {
   for (auto& e : allocated_) {
     device->FreeDataSpace(dev, e.data);
   }
@@ -156,7 +156,7 @@ void* TexturePool::AllocTexture(Device dev, size_t width, size_t height, DLDataT
     array_.resize(dev.device_id + 1, nullptr);
   }
   if (array_[dev.device_id] == nullptr) {
-    array_[dev.device_id] = new Pool();
+    array_[dev.device_id] = new Pool2D();
   }
   return array_[dev.device_id]->Alloc(dev, device_, width, height, type_hint);
 }
diff --git a/src/runtime/texture.h b/src/runtime/texture.h
index 47ff849f9de8..dc38101f0cd4 100644
--- a/src/runtime/texture.h
+++ b/src/runtime/texture.h
@@ -94,9 +94,9 @@ inline bool IsTextureStorage(std::string scope) {
   return scope.find("texture") != std::string::npos;
 }
 
-class TVM_DLL Pool {
+class TVM_DLL Pool2D {
  public:
-  Pool() = default;
+  Pool2D() = default;
   void* Alloc(Device dev, DeviceAPI* device, size_t width, size_t height, DLDataType type_hint);
   void Free(void* data);
   // Release all resources immediately
@@ -156,7 +156,7 @@ class TVM_DLL TexturePool {
 
  private:
   /*! \brief pool of device local array */
-  std::vector<Pool*> array_;
+  std::vector<Pool2D*> array_;
   /*! \brief device type this pool support */
   DLDeviceType device_type_;
   /*! \brief The device API */
diff --git a/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc
index 57e1635a2b53..2d3f43ddce6d 100644
--- a/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc
+++ b/tests/cpp-runtime/opencl/opencl_texture_pool_test.cc
@@ -26,10 +26,10 @@
 using namespace tvm::runtime;
 using namespace tvm::runtime::cl;
 
-// PoolWrapper is necessary because in class Pool we don't have an access to
+// PoolWrapper is necessary because in class Pool2D we don't have access to
 // its protected members. In this class we add new methods which allow us to
 // get and check internal state of class Pool
-class PoolWrapper : public Pool {
+class PoolWrapper : public Pool2D {
  public:
   inline size_t FreeListSize() const { return free_list_.size(); }
   inline size_t AllocatedListSize() const { return allocated_.size(); }
diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py
index c4bf1e027fbd..96227ca551cf 100644
--- a/tests/python/relay/test_conv2d_nhwc_texture.py
+++ b/tests/python/relay/test_conv2d_nhwc_texture.py
@@ -21,7 +21,7 @@
 import numpy as np
 from tvm import relay
 from tvm.relay import testing
-from utils.adreno_utils import gpu_preprocess, build_run_compare, gpu_preprocess_nhwc
+from utils.adreno_utils import gpu_preprocess, build_run_compare
 
 
 @tvm.testing.requires_opencl
diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py
index f7fefed619f8..3bb4a6ada4ec 100644
--- a/tests/python/relay/utils/adreno_utils.py
+++ b/tests/python/relay/utils/adreno_utils.py
@@ -117,14 +117,3 @@ def gpu_preprocess(tvm_mod):
     mod = tvm.IRModule.from_expr(tvm_mod)
     tvm_mod_nchwc = seq(mod)
     return tvm_mod_nchwc
-
-
-def gpu_preprocess_nhwc(tvm_mod):
-    layout_config = relay.transform.LayoutConfig()
-    desired_layouts = {"nn.conv2d": ["NHWC4c", "HWIO4o"]}
-    with layout_config:
-        seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
-    with tvm.transform.PassContext(opt_level=3):
-        mod = tvm.IRModule.from_expr(tvm_mod)
-        tvm_mod_nhwcc = seq(mod)
-    return tvm_mod_nhwcc
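Usage sketch (illustrative, not part of the patch): after this refactoring, each layout-specific TOPI entry point reduces to a thin wrapper over the shared implementation. The sketch below mirrors the acc32 wrappers in the diff; the "my_"-prefixed names are hypothetical, and args is assumed to carry only the accumulator dtype, which conv2d_winograd_comp reads as args["accumulator"] (so "float32" yields the output tag "cast_from_acc32" that the schedule matches on):

    # A minimal sketch, assuming args only needs the accumulator dtype.
    # The "my_" names are hypothetical; signatures follow the wrappers above.
    from tvm.topi.adreno.conv2d_winograd_common import (
        conv2d_winograd_comp,
        schedule_conv2d_winograd_impl,
    )


    def my_conv2d_nchw_winograd_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype):
        # "float32"[-2:] == "32", so the compute is tagged "cast_from_acc32"
        args = {"accumulator": "float32"}
        # pre_computed=False: the winograd weight transform is part of this compute
        return conv2d_winograd_comp(
            cfg, data, kernel, strides, padding, dilation, out_dtype, args, False, "NCHW"
        )


    def my_schedule_conv2d_nchw_winograd_acc32(cfg, outs):
        # the tag must match the output tag produced by the compute above
        return schedule_conv2d_winograd_impl(cfg, outs, tag="cast_from_acc32")

The NHWC variants differ only in the final layout argument ("NHWC"), which is the whole point of the consolidation into conv2d_winograd_common.py.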