From dbbc6783a8875c99a393a034051a6ef5291db590 Mon Sep 17 00:00:00 2001 From: Egor Churaev Date: Tue, 3 Aug 2021 13:40:24 +0300 Subject: [PATCH] Add vectorization to cuda conv2d_nhwc schedule Adding vectorization significantly improved performance. About 6-7x boost. --- python/tvm/topi/cuda/conv2d_nhwc.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/tvm/topi/cuda/conv2d_nhwc.py b/python/tvm/topi/cuda/conv2d_nhwc.py index e4361e30b5c3b..fc2d8b156291d 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc.py +++ b/python/tvm/topi/cuda/conv2d_nhwc.py @@ -54,6 +54,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): cfg.define_knob("vthread_n", [1] if dynamic_batch else [1, 2]) cfg.define_knob("vthread_c", [1, 2]) cfg.define_knob("step", [16, 3, 32, 64]) + cfg.define_knob("vectorize", [4, 2, 8, 16]) # fallback support target = tvm.target.Target.current() @@ -70,6 +71,7 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): vthread_n = cfg["vthread_n"].val vthread_c = cfg["vthread_c"].val step = cfg["step"].val + vec_factor = cfg["vectorize"].val block_factor_c = tile_c * num_thread_c * vthread_c offset = 8 @@ -86,7 +88,9 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): # Schedule for output ni, hi, wi, fi = s[output].op.axis - bz = s[output].fuse(hi, wi) + bz = wi + fi, vec = s[output].split(fi, factor=vec_factor) + s[output].vectorize(vec) tx, fi = s[output].split(fi, factor=tile_c) txz, tx = s[output].split(tx, factor=num_thread_c) bx, txz = s[output].split(txz, factor=vthread_c) @@ -125,6 +129,8 @@ def schedule_conv2d_nhwc_direct(cfg, s, Conv): _, _, ic, o = s[WW].op.axis t = s[WW].fuse(ic, o) s[WW].storage_align(ic, W_align - 1, W_align) + t, vec = s[WW].split(t, factor=vec_factor) + s[WW].vectorize(vec) ty, tx = s[WW].split(t, factor=num_thread_c) _, ty = s[WW].split(ty, factor=num_thread_n) s[WW].bind(tx, thread_x)