From 65fbc0a0813cb4e980e53240f34f3ba18754ab99 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 10 Dec 2021 10:36:36 +0900 Subject: [PATCH] bug fix in im2col encoding --- python/tvm/contrib/cutlass/gen_conv2d.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index d24e988ebe357..4064f5c0c10a4 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -128,16 +128,16 @@ def profile( If profile_all is False, return immediately after the first applicable kernel is found. If use_multiprocessing is True, compile all profiler executables in parallel. """ - B, H, W, C = d_shape - K, R, S, _ = w_shape + B, H, W, IC = d_shape + OC, R, S, _ = w_shape _, P, Q, _ = out_shape - M = B * H * W - K = R * S * C - N = B * P * Q + M = B * P * Q + N = OC + K = R * S * IC gemm_profile_result = self.gemm_profiler.profile( - M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing ) tile_description = gemm_profile_result["tile_description"]