From 19a7256e40e4dd9565ce6471a86cb98e5498585a Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 2 Dec 2021 10:34:46 +0900 Subject: [PATCH] [CUTLASS] Initial conv2d support (#9595) * Add initial conv generator * added conv2d pattern * profile by gemm profiler * remove conv2d profiler for now * remove unused code * add default * minor fix, profiling working * start codegen * generated code compiled * fixed layout initialization * matched with autotvm tensorcore result * test refactor * minor cleanup * remove iteration algo "Analytic" * add test for dynamic batch conv2d * pass dl tensor as output too * support conv2d dynamic shape in codegen * test working * lint * simplify codegen * fix weird formatting * typo fix * check if cutlass is enabled in the test * simplify gen_conv2d.py --- python/tvm/contrib/cutlass/build.py | 90 ++++++- .../tvm/contrib/cutlass/conv2d_operation.py | 240 ++++++++++++++++++ python/tvm/contrib/cutlass/gen_conv2d.py | 147 +++++++++++ python/tvm/contrib/cutlass/gen_gemm.py | 3 + python/tvm/contrib/cutlass/library.py | 57 ++++- python/tvm/relay/op/contrib/cutlass.py | 7 + python/tvm/relay/op/nn/_nn.py | 15 ++ .../backend/contrib/codegen_c/codegen_c.h | 12 +- src/relay/backend/contrib/cutlass/codegen.cc | 134 +++++++++- tests/python/contrib/test_cutlass.py | 96 +++++++ 10 files changed, 776 insertions(+), 25 deletions(-) create mode 100644 python/tvm/contrib/cutlass/conv2d_operation.py create mode 100644 python/tvm/contrib/cutlass/gen_conv2d.py diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 615b9003adc93..c3a8fdc1ad8ca 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -23,6 +23,7 @@ from tvm import runtime, relay from tvm.contrib.nvcc import find_cuda_path, get_cuda_version from .gen_gemm import CutlassGemmProfiler +from .gen_conv2d import CutlassConv2DProfiler logger = logging.getLogger("cutlass") @@ -65,7 +66,7 @@ def _get_cutlass_compile_options(sm, threads): return kwargs -class GemmAnnotator(tvm.relay.ExprVisitor): +class OpAnnotator(tvm.relay.ExprVisitor): """Annotates partitioned functions with shape and dtype information.""" def __init__(self): @@ -81,6 +82,10 @@ def visit_call(self, call): self.signature["arg%d_dtype" % i] = arg.checked_type.dtype self.signature["ret_shape"] = op.ret_type.shape self.signature["ret_dtype"] = op.ret_type.dtype + self.visit(op.body) + + if str(op) == "nn.conv2d": + self.op_attrs = call.attrs def select_gemm_kernel( @@ -125,6 +130,8 @@ def handle_batch_matmul( else: raise ValueError("%s pattern is not implemented." % op_type) + assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now." + return { "batch": arg0_shape[0], "batch_stride_A": arg0_shape[1] * arg0_shape[2], @@ -132,6 +139,9 @@ def handle_batch_matmul( "batch_stride_C": arg0_shape[1] * arg1_shape[1], "cutlass_op_def": cutlass_op_def, "cutlass_op_name": out["name"], + "lda": "K", + "ldb": "K", + "ldc": "N", } @@ -158,6 +168,50 @@ def handle_dense( else: raise ValueError("%s pattern is not implemented." % op_type) + assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now." 
+ + return { + "cutlass_op_def": cutlass_op_def, + "cutlass_op_name": out["name"], + "lda": "K", + "ldb": "K", + "ldc": "N", + } + + +def handle_conv2d( + cutlass_profiler, + op_type, + d_shape, + w_shape, + out_shape, + out_dtype, + profile_all, + use_multiprocessing, +): + """Profile and select a kernel for conv2d op workload.""" + if any(isinstance(s, tvm.tir.Any) for s in d_shape): + out = cutlass_profiler.get_default(out_dtype) + logger.info("Picked the default kernel %s", out["name"]) + else: + out = cutlass_profiler.profile( + d_shape, + w_shape, + out_shape, + out_dtype, + profile_all=profile_all, + use_multiprocessing=use_multiprocessing, + ) + if profile_all: + logger.info("The best kernel is %s", out["name"]) + else: + logger.info("Picked the first kernel found %s", out["name"]) + + if op_type == "cutlass.conv2d": + cutlass_op_def = out["opdef"] + else: + raise ValueError("%s pattern is not implemented." % op_type) + return { "cutlass_op_def": cutlass_op_def, "cutlass_op_name": out["name"], @@ -195,12 +249,13 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t num_cutlass_partition : int The number of partitioned functions created for CUTLASS. """ - cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir) + conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir) num_cutlass_partition = 0 for var in mod.get_global_vars(): fun_name = var.name_hint func = mod[fun_name] - annotator = GemmAnnotator() + annotator = OpAnnotator() if "cutlass" in fun_name: num_cutlass_partition += 1 annotator.visit(func) @@ -213,10 +268,26 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t arg0_shape = new_attrs["arg0_shape"] arg1_shape = new_attrs["arg1_shape"] - if "batch_matmul" in op_type: + if "conv2d" in op_type: + new_attrs["padding"] = annotator.op_attrs.padding + new_attrs["strides"] = annotator.op_attrs.strides + new_attrs["dilation"] = annotator.op_attrs.dilation + new_attrs.update( + handle_conv2d( + conv2d_profiler, + op_type, + arg0_shape, + arg1_shape, + annotator.signature["ret_shape"], + out_dtype, + profile_all, + use_multiprocessing, + ) + ) + elif "batch_matmul" in op_type: new_attrs.update( handle_batch_matmul( - cutlass_profiler, + gemm_profiler, op_type, arg0_shape, arg1_shape, @@ -228,7 +299,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t elif "dense" in op_type: new_attrs.update( handle_dense( - cutlass_profiler, + gemm_profiler, op_type, arg0_shape, arg1_shape, @@ -240,13 +311,6 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t else: raise ValueError("%s unsupported composite" % op_type) - if new_attrs["cutlass_op_name"].find("_tn_align") > 0: - new_attrs["lda"] = "K" - new_attrs["ldb"] = "K" - new_attrs["ldc"] = "N" - else: - raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"]) - new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs) new_func = relay.Function( func.params, diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py new file mode 100644 index 0000000000000..8a886ff260b81 --- /dev/null +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -0,0 +1,240 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name, unused-wildcard-import, wildcard-import +"""Generator for CUTLASS Conv2D kernels.""" +from .library import * + + +class Conv2dOperation: + """Describes various attributes for instantiating Conv2d kernels.""" + + def __init__( + self, + conv_kind, + iterator_algorithm, + arch, + tile_description, + A, + B, + C, + element_epilogue, + stride_support, + epilogue_functor=EpilogueFunctor.LinearCombination, + swizzling_functor=SwizzlingFunctor.Identity1, + ): + self.operation_kind = OperationKind.Conv2d + self.arch = arch + self.tile_description = tile_description + self.conv_kind = conv_kind + self.A = A + self.B = B + self.C = C + self.element_epilogue = element_epilogue + self.epilogue_functor = epilogue_functor + self.iterator_algorithm = iterator_algorithm + self.stride_support = stride_support + self.swizzling_functor = swizzling_functor + + def accumulator_type(self): + return self.tile_description.math_instruction.element_accumulator + + def core_name(self): + """ The basic operation kind is prefixed with a letter indicating the accumulation type. """ + intermediate_type = "" + + if self.tile_description.math_instruction.opcode_class == OpcodeClass.TensorOp: + inst_shape = "%d%d%d" % tuple(self.tile_description.math_instruction.instruction_shape) + if ( + self.tile_description.math_instruction.element_a != self.A.element + and self.tile_description.math_instruction.element_a != self.accumulator_type() + ): + intermediate_type = DataTypeNames[self.tile_description.math_instruction.element_a] + else: + inst_shape = "" + + return "%s%s%s%s_%s" % ( + ShortDataTypeNames[self.accumulator_type()], + inst_shape, + intermediate_type, + ConvKindNames[self.conv_kind], + IteratorAlgorithmNames[self.iterator_algorithm], + ) + + def extended_name(self): + """ Append data types if they differ from compute type. """ + if ( + self.C.element != self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${element_c}_${core_name}_${element_a}" + elif ( + self.C.element == self.tile_description.math_instruction.element_accumulator + and self.A.element != self.tile_description.math_instruction.element_accumulator + ): + extended_name = "${core_name}_${element_a}" + else: + extended_name = "${core_name}" + + extended_name = substitute_template( + extended_name, + { + "element_a": DataTypeNames[self.A.element], + "element_c": DataTypeNames[self.C.element], + "core_name": self.core_name(), + }, + ) + + return extended_name + + def layout_name(self): + return "%s" % (ShortLayoutTypeNames[self.A.layout]) + + def procedural_name(self): + """ + The full procedural name indicates architecture, extended name, tile size, and layout. 
+ """ + opcode_class_name = OpcodeClassNames[self.tile_description.math_instruction.opcode_class] + + threadblock = "%dx%d_%dx%d" % ( + self.tile_description.threadblock_shape[0], + self.tile_description.threadblock_shape[1], + self.tile_description.threadblock_shape[2], + self.tile_description.stages, + ) + + if self.stride_support == StrideSupport.Unity: + configuration_name = ( + "cutlass_${opcode_class}_${extended_name}_${threadblock}" + "_${layout}_align${alignment}_unity_stride" + ) + else: + configuration_name = ( + "cutlass_${opcode_class}_${extended_name}_${threadblock}" + "_${layout}_align${alignment}" + ) + + return substitute_template( + configuration_name, + { + "opcode_class": opcode_class_name, + "extended_name": self.extended_name(), + "threadblock": threadblock, + "layout": self.layout_name(), + "alignment": "%d" % self.A.alignment, + }, + ) + + +class EmitConv2dInstance: + """ Responsible for emitting a CUTLASS template definition.""" + + def __init__(self): + self.template = """ + // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" + using ${operation_name} = + typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accumulator}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>, + cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >, + cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>, + ${epilogue_functor}< + ${element_c}, + ${epilogue_vector_length}, + ${element_accumulator}, + ${element_epilogue} + >, + ${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; +""" + + def emit(self, operation): + """Instantiate a Conv2d kernel from given `operation`.""" + warp_shape = [ + int( + operation.tile_description.threadblock_shape[idx] + / operation.tile_description.warp_count[idx] + ) + for idx in range(3) + ] + + epilogue_vector_length = int( + min(operation.C.alignment * DataTypeSize[operation.C.element], 128) + / DataTypeSize[operation.C.element] + ) + + values = { + "operation_name": operation.procedural_name(), + "conv_kind": ConvKindTag[operation.conv_kind], + "conv_kind_name": ConvKindNames[operation.conv_kind].capitalize(), + "element_a": DataTypeTag[operation.A.element], + "layout_a": LayoutTag[operation.A.layout], + "element_b": DataTypeTag[operation.B.element], + "layout_b": LayoutTag[operation.B.layout], + "element_c": DataTypeTag[operation.C.element], + "layout_c": LayoutTag[operation.C.layout], + "element_accumulator": DataTypeTag[operation.accumulator_type()], + "opcode_class": OpcodeClassTag[ + operation.tile_description.math_instruction.opcode_class + ], + "arch": "cutlass::arch::Sm%d" % operation.arch, + "threadblock_shape_m": str(operation.tile_description.threadblock_shape[0]), + "threadblock_shape_n": str(operation.tile_description.threadblock_shape[1]), + "threadblock_shape_k": str(operation.tile_description.threadblock_shape[2]), + "warp_shape_m": str(warp_shape[0]), + "warp_shape_n": str(warp_shape[1]), + "warp_shape_k": str(warp_shape[2]), + "instruction_shape_m": str( + operation.tile_description.math_instruction.instruction_shape[0] + ), + "instruction_shape_n": str( + 
operation.tile_description.math_instruction.instruction_shape[1] + ), + "instruction_shape_k": str( + operation.tile_description.math_instruction.instruction_shape[2] + ), + "epilogue_vector_length": str(epilogue_vector_length), + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_epilogue": str(DataTypeTag[operation.element_epilogue]), + "swizzling_functor": SwizzlingFunctorTag[operation.swizzling_functor], + "stages": str(operation.tile_description.stages), + "iterator_algorithm": IteratorAlgorithmTag[operation.iterator_algorithm], + "iterator_algorithm_name": IteratorAlgorithmNames[ + operation.iterator_algorithm + ].capitalize(), + "stride_support": StrideSupportTag[operation.stride_support], + "math_operator": MathOperationTag[ + operation.tile_description.math_instruction.math_operation + ], + "align_a": str(operation.A.alignment), + "align_b": str(operation.B.alignment), + } + + return substitute_template(self.template, values) diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py new file mode 100644 index 0000000000000..d24e988ebe357 --- /dev/null +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -0,0 +1,147 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Conv2d kernel generator and profiler for CUTLASS.""" +from .conv2d_operation import Conv2dOperation, EmitConv2dInstance +from .gen_gemm import CutlassGemmProfiler +from .library import ( + EpilogueFunctor, + SwizzlingFunctor, + TensorDescription, + LayoutType, + ConvKind, + StrideSupport, + IteratorAlgorithm, +) + + +def create_conv2d_operator( + tile_descriptions, + data_type, + alignment_constraints, + swizzling_functor=SwizzlingFunctor.Identity4, +): + """Exhaustively instantiate all kernels from a given configuration.""" + ret = [] + + kernel_emitter = EmitConv2dInstance() + + element_a, element_b, element_c, element_epilogue = data_type + iterator_algorithms = [IteratorAlgorithm.Optimized] + + layout = (LayoutType.TensorNHWC, LayoutType.TensorNHWC, LayoutType.TensorNHWC) + for tile in tile_descriptions: + for alignment in alignment_constraints: + + alignment_c = min(8, alignment) + + A = TensorDescription(element_a, layout[0], alignment) + B = TensorDescription(element_b, layout[1], alignment) + C = TensorDescription(element_c, layout[2], alignment_c) + + swizzling_functor_ = swizzling_functor + + for iterator_algorithm in iterator_algorithms: + op_entry = {} + + op = Conv2dOperation( + ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + StrideSupport.Strided, + EpilogueFunctor.LinearCombination, + swizzling_functor_, + ) + + # TODO(masahi): Add profiler source here + op_entry["opdef"] = kernel_emitter.emit(op) + op_entry["op"] = op + op_entry["name"] = op.procedural_name() + op_entry["runtime"] = 9999999 + + # fused ops + for epilogue, opdef in zip( + [ + EpilogueFunctor.LinearCombinationBias, + EpilogueFunctor.LinearCombinationRelu, + ], + ["opdef_bias", "opdef_bias_relu"], + ): + op = Conv2dOperation( + ConvKind.Fprop, + iterator_algorithm, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + StrideSupport.Strided, + epilogue, + swizzling_functor_, + ) + + op_entry[opdef] = kernel_emitter.emit(op) + + ret.append(op_entry) + + return ret + + +class CutlassConv2DProfiler: + """Profile all candidate kernels and select the best one.""" + + def __init__(self, sm, cutlass_path, binary_path): + self.gemm_profiler = CutlassGemmProfiler(sm, cutlass_path, binary_path) + self.sm = sm + + def get_default(self, out_dtype): + gemm_profile_result = self.gemm_profiler.get_default(out_dtype) + tile_description = gemm_profile_result["tile_description"] + alignment = gemm_profile_result["alignment"] + data_type = gemm_profile_result["data_type"] + return create_conv2d_operator([tile_description], data_type, [alignment])[0] + + def profile( + self, d_shape, w_shape, out_shape, out_dtype, profile_all=True, use_multiprocessing=False + ): + """Profile and select the best kernel from candidate kernels. + If profile_all is False, return immediately after the first applicable kernel is found. + If use_multiprocessing is True, compile all profiler executables in parallel. 
+ """ + B, H, W, C = d_shape + K, R, S, _ = w_shape + _, P, Q, _ = out_shape + + M = B * H * W + K = R * S * C + N = B * P * Q + + gemm_profile_result = self.gemm_profiler.profile( + M, K, N, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing + ) + + tile_description = gemm_profile_result["tile_description"] + alignment = gemm_profile_result["alignment"] + data_type = gemm_profile_result["data_type"] + + return create_conv2d_operator([tile_description], data_type, [alignment])[0] diff --git a/python/tvm/contrib/cutlass/gen_gemm.py b/python/tvm/contrib/cutlass/gen_gemm.py index 4025354dc7398..cec64f0af974c 100644 --- a/python/tvm/contrib/cutlass/gen_gemm.py +++ b/python/tvm/contrib/cutlass/gen_gemm.py @@ -125,6 +125,9 @@ def create_gemm_operator( op.leading_dim(), ) op_entry["runtime"] = 9999999 + op_entry["tile_description"] = tile_description + op_entry["alignment"] = alignment + op_entry["data_type"] = data_type ret.append(op_entry) return ret diff --git a/python/tvm/contrib/cutlass/library.py b/python/tvm/contrib/cutlass/library.py index a3b90ff83d1f7..902dc57100a98 100644 --- a/python/tvm/contrib/cutlass/library.py +++ b/python/tvm/contrib/cutlass/library.py @@ -64,23 +64,27 @@ class MathOperation(enum.Enum): class LayoutType(enum.Enum): ColumnMajor = enum_auto() RowMajor = enum_auto() + TensorNHWC = enum_auto() LayoutTag = { LayoutType.ColumnMajor: "cutlass::layout::ColumnMajor", LayoutType.RowMajor: "cutlass::layout::RowMajor", + LayoutType.TensorNHWC: "cutlass::layout::TensorNHWC", } TransposedLayout = { LayoutType.ColumnMajor: LayoutType.RowMajor, LayoutType.RowMajor: LayoutType.ColumnMajor, + LayoutType.TensorNHWC: LayoutType.TensorNHWC, } ShortLayoutTypeNames = { LayoutType.ColumnMajor: "n", LayoutType.RowMajor: "t", + LayoutType.TensorNHWC: "nhwc", } @@ -105,11 +109,10 @@ class OpcodeClass(enum.Enum): class OperationKind(enum.Enum): Gemm = enum_auto() + Conv2d = enum_auto() -OperationKindNames = { - OperationKind.Gemm: "gemm", -} +OperationKindNames = {OperationKind.Gemm: "gemm", OperationKind.Conv2d: "conv2d"} class Target(enum.Enum): @@ -172,6 +175,54 @@ class SwizzlingFunctor(enum.Enum): } +class ConvKind(enum.Enum): + Fprop = enum_auto() + + +ConvKindTag = { + ConvKind.Fprop: "cutlass::conv::Operator::kFprop", +} + + +ConvKindNames = { + ConvKind.Fprop: "fprop", +} + + +class StrideSupport(enum.Enum): + Strided = enum_auto() + Unity = enum_auto() + + +StrideSupportTag = { + StrideSupport.Strided: "cutlass::conv::StrideSupport::kStrided", + StrideSupport.Unity: "cutlass::conv::StrideSupport::kUnity", +} + + +StrideSupportNames = { + StrideSupport.Strided: "", + StrideSupport.Unity: "unity_stride", +} + + +class IteratorAlgorithm(enum.Enum): + Analytic = enum_auto() + Optimized = enum_auto() + + +IteratorAlgorithmTag = { + IteratorAlgorithm.Analytic: "cutlass::conv::IteratorAlgorithm::kAnalytic", + IteratorAlgorithm.Optimized: "cutlass::conv::IteratorAlgorithm::kOptimized", +} + + +IteratorAlgorithmNames = { + IteratorAlgorithm.Analytic: "analytic", + IteratorAlgorithm.Optimized: "optimized", +} + + class MathInstruction: """Describe characteristics of a math instruction.""" diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py index 8ed371844a1ce..4ae529e18dc2c 100644 --- a/python/tvm/relay/op/contrib/cutlass.py +++ b/python/tvm/relay/op/contrib/cutlass.py @@ -55,6 +55,11 @@ def make_batch_matmul_pattern(): return is_op("nn.batch_matmul")(wildcard(), wildcard()) +def make_conv2d_pattern(): + # TODO(masahi): Check 
layout and alignment + return is_op("nn.conv2d")(wildcard(), wildcard()) + + def partition_for_cutlass(mod): """Partition the input module into CUTLASS-supported subgraphs.""" dense_pat = ("cutlass.dense", make_gemm_pattern(False, None)) @@ -72,6 +77,8 @@ def partition_for_cutlass(mod): dense_bias_pat, dense_pat, ("cutlass.batch_matmul", make_batch_matmul_pattern()), + # TODO(masahi): Add more conv2d patterns + ("cutlass.conv2d", make_conv2d_pattern()), ] mod = transform.MergeComposite(cutlass_patterns)(mod) mod = transform.AnnotateTarget(["cutlass"])(mod) diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index 17f75a07af642..8357f28103a06 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -1090,6 +1090,19 @@ def _conv_shape_func_nhwc_hwoi(dshape, kshape, strides, padding, dilation): return out +@script +def _conv_shape_func_nhwc_ohwi(dshape, kshape, strides, padding, dilation): + """Shape function for conv*d op with nhwc & ohwi layout.""" + out = output_tensor((dshape.shape[0],), "int64") + out[0] = dshape[0] + out[dshape.shape[0] - 1] = kshape[0] + + for i in const_range(dshape.shape[0] - 2): + dilated_k = (kshape[i + 1] - 1) * dilation[i] + 1 + out[i + 1] = (dshape[i + 1] + 2 * padding[i] - dilated_k) // strides[i] + 1 + return out + + def conv_shape_func(attrs, inputs, _): """Shape function for conv*d op.""" strides = get_const_tuple(attrs.strides) @@ -1103,6 +1116,8 @@ def conv_shape_func(attrs, inputs, _): shape_func = _conv_shape_func_nhwc_hwio elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "HWOI": shape_func = _conv_shape_func_nhwc_hwoi + elif attrs["data_layout"] == "NHWC" and attrs["kernel_layout"] == "OHWI": + shape_func = _conv_shape_func_nhwc_ohwi else: raise ValueError( "Unsupported data/kernel layout: %s, %s" diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 617da1e3fa818..49a5bca068d1b 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -191,10 +191,18 @@ class CodegenCBase { PrintIndents(); } for (size_t i = 0; i < outs.size() - 1; i++) { - code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n"; + if (pass_dl_tensor) { + code_stream_ << "out" << i << ",\n"; + } else { + code_stream_ << "(" << outs[i].dtype << "*)(out" << i << "->data),\n"; + } PrintIndents(); } - code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n"; + if (pass_dl_tensor) { + code_stream_ << "out" << outs.size() - 1 << ");\n"; + } else { + code_stream_ << "(" << outs.back().dtype << "*)(out" << outs.size() - 1 << "->data));\n"; + } PrintIndents(); code_stream_ << "return 0;\n"; ExitScope(); diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index f154f8641a645..c226da5864fcc 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -61,7 +61,7 @@ inline void CutlassPrint(std::ostringstream& os, const std::string& stmt, int in os << stmt; } -Str2StrMap GemmArgsCommon(const Map& attrs) { +Str2StrMap ArgsCommon(const Map& attrs) { Str2StrMap args; auto arg0_dtype = std::string(attrs["arg0_dtype"].as()->data); auto arg1_dtype = std::string(attrs["arg1_dtype"].as()->data); @@ -72,6 +72,11 @@ Str2StrMap GemmArgsCommon(const Map& attrs) { args["op_def"] = std::string(attrs["cutlass_op_def"].as()->data); args["op_name"] = 
      std::string(attrs["cutlass_op_name"].as<StringObj>()->data);
   args["op_type"] = std::string(attrs["op_type"].as<StringObj>()->data);
+  return args;
+}
+
+Str2StrMap GemmArgsCommon(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = ArgsCommon(attrs);
   args["lda"] = std::string(attrs["lda"].as<StringObj>()->data);
   args["ldb"] = std::string(attrs["ldb"].as<StringObj>()->data);
   args["ldc"] = std::string(attrs["ldc"].as<StringObj>()->data);
@@ -110,7 +115,7 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
   CutlassPrint(gemm_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
   CutlassPrint(gemm_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
   CutlassPrint(gemm_decl, attrs.at("op_def"));
-  CutlassPrint(gemm_decl, "using Gemm = Operation_" + attrs.at("op_name") + ";\n");
+  CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");

   auto get_dim = [&attrs, &func_args](const std::string& axis, int arg_idx, int axis_idx) {
     if (attrs.at(axis) == kAnyDim) {
@@ -139,9 +144,8 @@ void AppendPrologue(std::ostringstream& gemm_decl, const Str2StrMap& attrs,
     CutlassPrint(gemm_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
   }

-  CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0);\n");
+  CutlassPrint(gemm_decl, "void* ptr_out = (void*)(out0->data);\n");

-  CutlassPrint(gemm_decl, "using " + kernel + " = Operation_" + attrs.at("op_name") + ";\n");
   CutlassPrint(gemm_decl, "typename " + kernel + "::Arguments arguments{\n");
   CutlassPrint(gemm_decl, " problem_size,\n");
 }
@@ -234,6 +238,112 @@ std::string BatchMatmulOp(std::string id, const Str2StrMap& attrs,
   return gemm_decl.str();
 }

+Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
+  Str2StrMap args = ArgsCommon(attrs);
+  auto arg0_shape = attrs["arg0_shape"].as<ArrayNode>();
+  auto arg1_shape = attrs["arg1_shape"].as<ArrayNode>();
+  auto out_shape = attrs["ret_shape"].as<ArrayNode>();
+  args["N"] = GetDimAsStr(arg0_shape->at(0));
+  args["H"] = GetDimAsStr(arg0_shape->at(1));
+  args["W"] = GetDimAsStr(arg0_shape->at(2));
+  args["C"] = GetDimAsStr(arg0_shape->at(3));
+  args["K"] = GetDimAsStr(arg1_shape->at(0));
+  args["R"] = GetDimAsStr(arg1_shape->at(1));
+  args["S"] = GetDimAsStr(arg1_shape->at(2));
+  args["P"] = GetDimAsStr(out_shape->at(1));
+  args["Q"] = GetDimAsStr(out_shape->at(2));
+  args["pad_h"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(0));
+  args["pad_w"] = GetDimAsStr(attrs["padding"].as<ArrayNode>()->at(1));
+  args["stride_h"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(0));
+  args["stride_w"] = GetDimAsStr(attrs["strides"].as<ArrayNode>()->at(1));
+  args["dilation_h"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(0));
+  args["dilation_w"] = GetDimAsStr(attrs["dilation"].as<ArrayNode>()->at(1));
+  return args;
+}
+
+std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
+                     const std::vector<std::string>& func_args) {
+  std::ostringstream conv2d_decl;
+  CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
+  CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");
+
+  CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");
+  CutlassPrint(conv2d_decl, attrs.at("op_def"));
+  CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
+                                " = cutlass::conv::device::ImplicitGemmConvolution<" +
+                                attrs.at("op_name") + ">;\n");
+  CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n");
+
+  auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) {
+    if (attrs.at(axis) == kAnyDim) {
+      return var_name + "->shape[" + std::to_string(axis_idx) + "]";
+    } else {
+      return attrs.at(axis);
+    }
+  };
+
+  CutlassPrint(conv2d_decl, "int N = " + get_dim("N", func_args[0], 0) + ";\n");
+  CutlassPrint(conv2d_decl, "int H = " + get_dim("H", func_args[0], 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int W = " + get_dim("W", func_args[0], 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int C = " + attrs.at("C") + ";\n");
+  CutlassPrint(conv2d_decl, "int K = " + attrs.at("K") + ";\n");
+  CutlassPrint(conv2d_decl, "int R = " + attrs.at("R") + ";\n");
+  CutlassPrint(conv2d_decl, "int S = " + attrs.at("S") + ";\n");
+  CutlassPrint(conv2d_decl, "int P = " + get_dim("P", "out0", 1) + ";\n");
+  CutlassPrint(conv2d_decl, "int Q = " + get_dim("Q", "out0", 2) + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_h = " + attrs.at("pad_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int pad_w = " + attrs.at("pad_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_h = " + attrs.at("stride_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
+  CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
+
+  CutlassPrint(
+      conv2d_decl,
+      "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
+      "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
+
+  ICHECK(func_args.size() >= 2);
+  CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
+  CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
+  CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
+
+  CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
+  CutlassPrint(conv2d_decl,
+               "TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
+  CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
+  CutlassPrint(conv2d_decl, " problem_size,\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
+  CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
+  CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");
+
+  CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
+  // Allocate workspace memory
+  CutlassPrint(conv2d_decl,
+               "cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);\n");
+  // Check the problem size is supported or not
+  CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+  // Launch initialized CUTLASS kernel
+  CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
+  CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
+
+  return conv2d_decl.str();
+}
+
 class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
  public:
   CodegenCutlass(const std::string& id, const Map<String, ObjectRef>& attrs) {
@@ -268,9 +378,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       code_stream_ << "DLTensor* " << arg->name_hint() << ", ";
     }
     for (size_t i = 0; i < out.size() - 1; ++i) {
-      code_stream_ << out[i].dtype << "* out" << i << ", ";
+      code_stream_ << "DLTensor* out" << i << ", ";
     }
-    code_stream_ << out.back().dtype << "* out" << out.size() - 1 << ") {\n";
+    code_stream_ << "DLTensor* out" << out.size() - 1 << ") {\n";
     this->EnterScope();

     // Function body
@@ -347,7 +457,12 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       GetRootCall(callee->body.as<CallNode>(), 0, {"nn.batch_matmul"});
       return GenerateBody(batch_matmul_call, "cutlass_batch_matmul", GetArgumentNames(caller),
                           BatchMatmulArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d") {
+      const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }
+
     LOG(FATAL) << "Unknown composite function: " << pattern_name;
     return {};
   }
@@ -392,7 +507,10 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
       ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_batch_matmul") {
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
+    } else if (func_name == "cutlass_conv2d") {
+      ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }
+
     return ret;
   }
   /*! \brief The id of the external cutlass ext_func.
*/ @@ -441,10 +559,12 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { // cutlass header code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; - code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index 6f27d57d95d7a..a258da3c5d788 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -110,6 +110,22 @@ def get_batch_matmul(batch, M, N, K, out_dtype="float16"): return get_batch_matmul_with_shape((batch, M, K), (batch, N, K), out_dtype="float16") +def get_conv2d_nchw(d_shape, w_shape): + data = relay.var("data", shape=d_shape, dtype="float16") + weight = relay.var("weight", shape=w_shape, dtype="float16") + out_channel = w_shape[0] + return tvm.IRModule.from_expr( + relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(3, 3), + channels=out_channel, + padding=(1, 1), + out_dtype="float16", + ) + ) + + def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"): mod = partition_for_cutlass(mod) mod, num_cutlass_partition = tune_cutlass_kernels( @@ -289,5 +305,85 @@ def test_batch_matmul(): ) +def convert_conv2d_layout(mod, desired_layouts): + with tvm.transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) + return seq(mod) + + +def verify_conv2d( + mod_nchw, + mod_ref, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, +): + if not has_cutlass(): + return + + np_data = np.random.uniform(-1, 1, d_shape).astype("float16") + np_weight = np.random.uniform(-1, 1, w_shape).astype("float16") + + params = {"weight": np_weight} + + typ = relay.transform.InferType()(mod_nchw)["main"].body.checked_type + use_vm = any(isinstance(s, tvm.tir.Any) for s in typ.shape) + + if use_vm: + rt_mod, dev, num_cutlass_partition = profile_and_build_vm( + convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), params, sm + ) + out = get_output_vm(rt_mod, ["data"], [np_data]) + else: + rt_mod, dev, num_cutlass_partition = profile_and_build( + convert_conv2d_layout(mod_nchw, {"nn.conv2d": ["NHWC", "OHWI"]}), + params, + sm, + ) + out = get_output(rt_mod, ["data"], [np_data]) + + assert num_cutlass_partition > 0 + + rt_mod_ref, _ = get_ref_rt_mod( + convert_conv2d_layout(mod_ref, {"nn.conv2d": ["NHWC", "HWIO"]}), + params, + target="cuda", + ) + ref_out = get_output(rt_mod_ref, ["data"], [np_data]) + + np.testing.assert_allclose(out, ref_out, atol=atol, rtol=rtol) + + if run_benchmark: + print("CUTLASS:", rt_mod.benchmark(dev, number=1, repeat=600)) + print("TVM Tensorcore (no tuning):", rt_mod_ref.benchmark(dev, number=1, repeat=600)) + + +def test_conv2d(): + d_shape = (16, 16, 32, 32) + w_shape = (32, 16, 3, 3) + mod_nchw = get_conv2d_nchw(d_shape, w_shape) + + verify_conv2d( + mod_nchw, + mod_nchw, + d_shape, + w_shape, + sm=80, + atol=1e-5, + rtol=1e-5, + run_benchmark=False, + ) + + dyn_batch_shape = (relay.Any(),) + d_shape[1:] + mod_dyn = get_conv2d_nchw(dyn_batch_shape, w_shape) + + verify_conv2d( + mod_dyn, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False + ) + + if __name__ == "__main__": pytest.main([__file__])
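
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new conv2d path is meant to be driven end to end, mirroring profile_and_build() in tests/python/contrib/test_cutlass.py. It assumes an sm_80 GPU, a CUTLASS checkout discoverable by tvm.contrib.cutlass, a module already converted to NHWC data / OHWI kernel layout (e.g. via relay.transform.ConvertLayout, as verify_conv2d does), and that build_cutlass_kernels() from the existing GEMM support is available on this branch; the helper name build_conv2d_with_cutlass is hypothetical.

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass
    from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

    def build_conv2d_with_cutlass(mod, params, sm=80, tmp_dir="./tmp", lib_path="compile.so"):
        # NOTE: assumes NHWC data / OHWI kernel layout has already been applied,
        # as convert_conv2d_layout() does in the test.
        # Offload supported patterns (dense, batch_matmul, and now conv2d) to CUTLASS.
        mod = partition_for_cutlass(mod)
        # Profile candidate kernels and attach the chosen CUTLASS op definition
        # to each partitioned function.
        mod, num_cutlass_partition = tune_cutlass_kernels(
            mod, sm, profile_all=False, use_multiprocessing=False, tmp_dir=tmp_dir
        )
        with tvm.transform.PassContext(opt_level=3):
            lib = relay.build(mod, target="cuda", params=params)
        # Compile the generated CUTLASS C source with NVCC and export one shared library.
        lib = build_cutlass_kernels(lib, sm, tmp_dir, lib_path)
        return lib, num_cutlass_partition

For a dynamic batch dimension, the test instead goes through the VM build path (profile_and_build_vm), in which case handle_conv2d falls back to CutlassConv2DProfiler.get_default rather than profiling.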