diff --git a/python/tvm/contrib/cutlass/build.py b/python/tvm/contrib/cutlass/build.py index 918eeaf68b75..06c33f2f7ae0 100644 --- a/python/tvm/contrib/cutlass/build.py +++ b/python/tvm/contrib/cutlass/build.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, dangerous-default-value """Driver for partitioning and building a Relay module for CUTLASS offload.""" import logging import os @@ -238,6 +238,7 @@ def handle_conv2d( data_dtype, weight_dtype, use_3xtf32, + split_k_slices, profile_all_alignments, find_first_valid, use_multiprocessing, @@ -269,6 +270,7 @@ def handle_conv2d( weight_dtype, use_3xtf32, conv_kind, + split_k_slices, profile_all_alignments, find_first_valid=find_first_valid, use_multiprocessing=use_multiprocessing, @@ -288,6 +290,7 @@ def tune_cutlass_kernels( mod, sm, use_3xtf32=True, + split_k_slices=[1], profile_all_alignments=False, find_first_valid=False, use_multiprocessing=False, @@ -309,6 +312,14 @@ def tune_cutlass_kernels( Wheter or not use slower but very accurate (compared to tf32) 3xtf32 mode for fp32 inputs on tensorcore. + split_k_slices : list of int + Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in + parallel across split-K blocks, and a separate global reduction kernel is launched to + accumulate partial reductions. The profiler will pick the best split-K factor from the + given candidate list. Note that a larger split-K factor requires a larger workspace. + Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d + kinds, split_k_slices is ignored. + profile_all_alignments : bool When True, profile all kernal variants with smaller alignments than the largest possible. 
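Editor's note on the split_k_slices semantics documented above: the standalone NumPy sketch below is illustrative only (it is not TVM or CUTLASS code, and all names in it are made up for illustration). It shows what parallel split-K computes: the GEMM K-loop is chunked into split_k_slices partial fp32 accumulations held in a workspace, and a separate reduction sums them before the final cast, which is why the wgrad GEMM epilogue in this patch always writes fp32.

import numpy as np

def splitk_gemm_reference(A, B, split_k_slices=4, out_dtype=np.float16):
    """Illustrative split-K GEMM: partial K-chunks accumulated in fp32, then reduced."""
    M, K = A.shape
    K2, N = B.shape
    assert K == K2
    bounds = np.linspace(0, K, split_k_slices + 1, dtype=int)
    # "Workspace" holding one fp32 partial result per slice
    # (a larger split-K factor needs a larger workspace).
    partial = np.zeros((split_k_slices, M, N), dtype=np.float32)
    for i in range(split_k_slices):
        lo, hi = bounds[i], bounds[i + 1]
        partial[i] = A[:, lo:hi].astype(np.float32) @ B[lo:hi, :].astype(np.float32)
    # Stand-in for the ReduceSplitK kernel: accumulate the partials,
    # then apply the epilogue (here just a cast to the output dtype).
    return partial.sum(axis=0).astype(out_dtype)
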
@@ -380,6 +391,7 @@ def tune_cutlass_kernels( arg0_dtype, arg1_dtype, use_3xtf32, + split_k_slices, profile_all_alignments, find_first_valid, use_multiprocessing, diff --git a/python/tvm/contrib/cutlass/conv2d_operation.py b/python/tvm/contrib/cutlass/conv2d_operation.py index 5318cc7d74c4..7b78c5a375d2 100644 --- a/python/tvm/contrib/cutlass/conv2d_operation.py +++ b/python/tvm/contrib/cutlass/conv2d_operation.py @@ -35,6 +35,7 @@ def __init__( stride_support, epilogue_functor=EpilogueFunctor.LinearCombination, swizzling_functor=SwizzlingFunctor.Identity1, + split_k_slices=1, ): self.operation_kind = OperationKind.Conv2d self.arch = arch @@ -48,6 +49,7 @@ def __init__( self.iterator_algorithm = iterator_algorithm self.stride_support = stride_support self.swizzling_functor = swizzling_functor + self.split_k_slices = split_k_slices def accumulator_type(self): return self.tile_description.math_instruction.element_accumulator @@ -127,6 +129,9 @@ def procedural_name(self): "_${layout}_align${alignment}" ) + if self.split_k_slices > 1: + configuration_name += "_splitk%d" % self.split_k_slices + return substitute_template( configuration_name, { @@ -172,6 +177,14 @@ def __init__(self): ${unary_op} >""" + self.epilogue_wgrad = """ + ${epilogue_functor}< + ${element_c}, + 4, + float, + float + >""" + self.template = """ // Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}" using ${operation_name} = @@ -197,9 +210,31 @@ def __init__(self): ${align_a}, ${align_b} >::Kernel; + + ${reduction} +""" + + self.reduction_template = """ +using EpilogueOutputOp = ${epilogue}; +using ReductionOp = cutlass::reduction::thread::ReduceAdd< + ${element_accumulator}, + ${element_accumulator}, + EpilogueOutputOp::kCount + >; + +using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK< + cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>, + EpilogueOutputOp, + ReductionOp + >; + +using ReductionDevice = cutlass::reduction::device::ReduceSplitK; +using ReductionStrideIndex = typename ReductionDevice::StrideIndex; """ - def emit(self, operation, no_beta_scaling=False, residual_block_info=False): + def emit( + self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False + ): """Instantiate a Conv2d kernel from given `operation`.""" warp_shape = [ int( @@ -214,6 +249,31 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False): / DataTypeSize[operation.C.element] ) + element_c = operation.C.element + use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1 + # Gemm output always fp32 in wgrad with split k + element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c + + if emit_reduction: + epilogue_reduction = substitute_template( + self.epilogue_wgrad, + { + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_c": DataTypeTag[element_c], + }, + ) + reduction = substitute_template( + self.reduction_template, + { + "epilogue": epilogue_reduction, + "operation_name": operation.procedural_name(), + "element_accumulator": DataTypeTag[operation.accumulator_type()], + }, + ) + gemm_template = substitute_template(self.template, {"reduction": reduction}) + else: + gemm_template = substitute_template(self.template, {"reduction": ""}) + values = { "operation_name": operation.procedural_name(), "conv_kind": ConvKindTag[operation.conv_kind], @@ -222,7 +282,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False): "layout_a": 
LayoutTag[operation.A.layout], "element_b": DataTypeTag[operation.B.element], "layout_b": LayoutTag[operation.B.layout], - "element_c": DataTypeTag[operation.C.element], + "element_c": DataTypeTag[element_c_gemm], "layout_c": LayoutTag[operation.C.layout], "element_accumulator": DataTypeTag[operation.accumulator_type()], "opcode_class": OpcodeClassTag[ @@ -262,9 +322,19 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False): "conv_kernel_postfix": "", } - if residual_block_info: + if use_split_k_wgrad: + # Even if the output is fp16, gemm output is always fp32 for split k wgrad. + epilogue_gemm = substitute_template( + self.epilogue_wgrad, + { + "epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor], + "element_c": "float", + }, + ) + template = substitute_template(gemm_template, {"epilogue": epilogue_gemm}) + elif residual_block_info: template = substitute_template( - self.template, {"epilogue": self.epilogue_residual_block} + gemm_template, {"epilogue": self.epilogue_residual_block} ) values.update( { @@ -276,9 +346,9 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False): ) elif no_beta_scaling: template = substitute_template( - self.template, {"epilogue": self.epilogue_no_beta_scaling} + gemm_template, {"epilogue": self.epilogue_no_beta_scaling} ) else: - template = substitute_template(self.template, {"epilogue": self.epilogue_default}) + template = substitute_template(gemm_template, {"epilogue": self.epilogue_default}) return substitute_template(template, values) diff --git a/python/tvm/contrib/cutlass/conv2d_profiler.py b/python/tvm/contrib/cutlass/conv2d_profiler.py index 2f4e76943c64..1ed5550e0a66 100644 --- a/python/tvm/contrib/cutlass/conv2d_profiler.py +++ b/python/tvm/contrib/cutlass/conv2d_profiler.py @@ -17,6 +17,8 @@ # pylint: disable=import-outside-toplevel, invalid-name """Instantiate a C++ source for profiling CUTLASS kernels.""" +from .library import DataTypeTag + class Conv2dProfilerEmitter(object): """Emit a C++ source for profiling CUTLASS kernels.""" @@ -24,6 +26,32 @@ class Conv2dProfilerEmitter(object): def __init__(self): from jinja2 import Template + self.reduction = """ + ReductionDevice reduction_op; + static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator; + typename ReductionDevice::Arguments reduction_args( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(), + problem_size.split_k_slices, + cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size), + { + reinterpret_cast (workspace.get()), + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_d.device_data(), + ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + { + tensor_c.device_data(), + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + }, + {ElementComputeEpilogue(1), ElementComputeEpilogue(0)} + ); + + reduction_op.initialize(reduction_args, nullptr); + reduction_op(); +""" + self.template = Template( """ #include @@ -35,6 +63,8 @@ def __init__(self): #include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/reduction/device/reduce_split_k.h" +#include "cutlass/reduction/thread/reduction_operators.h" #define CUTLASS_CHECK(status) \ { \ @@ -88,10 +118,11 @@ def __init__(self): }; double 
profile_convolution(Options const &options) { - using ElementOutput = typename ImplicitGemm::ElementC; + using ElementOutput = {{ElementOutput}}; using ElementInputA = typename ImplicitGemm::ElementA; using ElementInputB = typename ImplicitGemm::ElementB; + int split_k_slices = {{SplitK}}; cutlass::conv::Conv2dProblemSize problem_size( options.input_size, options.filter_size, @@ -100,7 +131,7 @@ def __init__(self): options.dilation, options.output_size(), cutlass::conv::Mode::kCrossCorrelation, - 1 + split_k_slices ); auto conv_kind = ImplicitGemm::kConvolutionalOperator; @@ -108,20 +139,26 @@ def __init__(self): auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size); auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size); + using LayoutC = typename ImplicitGemm::LayoutC; cutlass::HostTensor tensor_a(a_extent); cutlass::HostTensor tensor_b(b_extent); cutlass::HostTensor tensor_c(c_extent); - cutlass::HostTensor tensor_ref_c(c_extent); + cutlass::HostTensor tensor_d(c_extent); + cutlass::HostTensor tensor_c_gemm(c_extent); using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute; + cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ? + cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial; + typename ImplicitGemm::Arguments arguments{ problem_size, tensor_a.device_ref(), tensor_b.device_ref(), - tensor_c.device_ref(), - tensor_c.device_ref(), + tensor_c_gemm.device_ref(), + tensor_c_gemm.device_ref(), {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}, + split_k_mode, }; ImplicitGemm implicit_gemm_op; @@ -144,6 +181,7 @@ def __init__(self): for (int iteration = 0; iteration < 100; ++iteration) { auto status = implicit_gemm_op(); CUTLASS_CHECK(status); + {{Reduction}} } cudaEventRecord(events[1]); @@ -166,6 +204,12 @@ def __init__(self): """ ) - def emit(self, op_def, op_name): - src = self.template.render(OperatorDef=op_def, OperatorName=op_name) + def emit(self, op_def, op_name, element_output, split_k_slices=1): + src = self.template.render( + OperatorDef=op_def, + OperatorName=op_name, + ElementOutput=DataTypeTag[element_output], + SplitK=split_k_slices, + Reduction=self.reduction if split_k_slices > 1 else "", + ) return src diff --git a/python/tvm/contrib/cutlass/gen_conv2d.py b/python/tvm/contrib/cutlass/gen_conv2d.py index 0d46000bab70..b51afdc8b586 100644 --- a/python/tvm/contrib/cutlass/gen_conv2d.py +++ b/python/tvm/contrib/cutlass/gen_conv2d.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-# pylint: disable=invalid-name +# pylint: disable=invalid-name, dangerous-default-value """Conv2d kernel generator and profiler for CUTLASS.""" from functools import partial from .conv2d_operation import Conv2dOperation, EmitConv2dInstance @@ -40,6 +40,7 @@ def create_conv2d_operator_with_epilogue( data_type, alignment, swizzling_functor, + split_k_slices, ): """ Instantiate a cutlass kernel from the given configuration, @@ -90,11 +91,15 @@ def create_conv2d_operator_with_epilogue( stride_support, epilogue, swizzling_functor, + split_k_slices, ) name = op.procedural_name() opdef = EmitConv2dInstance().emit( - op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info + op, + no_beta_scaling=no_beta_scaling, + residual_block_info=residual_block_info, + emit_reduction=split_k_slices > 1, ) return name, opdef @@ -103,6 +108,7 @@ def create_conv2d_operator_with_epilogue( def enumerate_conv2d_operators( conv_kind, stride_support, + split_k_slices, tile_descriptions, data_type, alignment_constraints, @@ -119,37 +125,45 @@ def enumerate_conv2d_operators( if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided: swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1 - for tile in tile_descriptions: - for alignment in alignment_constraints: - - A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) - B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) - C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) - - op = Conv2dOperation( - conv_kind, - IteratorAlgorithm.Optimized, - tile.minimum_compute_capability, - tile, - A, - B, - C, - element_epilogue, - stride_support, - EpilogueFunctor.LinearCombination, - swizzling_functor, - ) - - ret.append( - { - "src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()), - "name": op.procedural_name(), - "tile_description": tile, - "alignment": alignment, - "data_type": data_type, - "swizzle_functor": swizzling_functor, - } - ) + for split_k_slice in split_k_slices: + for tile in tile_descriptions: + for alignment in alignment_constraints: + + A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment) + B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment) + C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment) + + op = Conv2dOperation( + conv_kind, + IteratorAlgorithm.Optimized, + tile.minimum_compute_capability, + tile, + A, + B, + C, + element_epilogue, + stride_support, + EpilogueFunctor.LinearCombination, + swizzling_functor, + split_k_slice, + ) + + ret.append( + { + "src": profiler_emitter.emit( + kernel_emitter.emit(op, emit_reduction=split_k_slice > 1), + op.procedural_name(), + element_output=element_c, + split_k_slices=split_k_slice, + ), + "name": op.procedural_name(), + "tile_description": tile, + "alignment": alignment, + "data_type": data_type, + "swizzle_functor": swizzling_functor, + "split_k_slices": split_k_slice, + } + ) return ret @@ -198,6 +212,7 @@ def get_default( data_type, alignment, swizzling_functor, + split_k_slices=1, ) return {"name": name, "opdef": opdef} @@ -214,6 +229,7 @@ def select_op( use_3xtf32, conv_kind, stride_support, + split_k_slices, profile_all_alignments=False, find_first_valid=False, use_multiprocessing=False, @@ -248,7 +264,7 @@ def select_op( out_dtype, data_dtype, weight_dtype, - partial(enumerate_conv2d_operators, conv_kind, stride_support), + partial(enumerate_conv2d_operators, conv_kind, stride_support, split_k_slices), lambda align: all([dim % align == 0 for dim in [IC, OC]]), 
use_3xtf32, profile_all_alignments, @@ -288,6 +304,7 @@ def profile( weight_dtype, use_3xtf32=True, conv_kind=ConvKind.Fprop, + split_k_slices=[1], profile_all_alignments=False, find_first_valid=False, use_multiprocessing=False, @@ -315,6 +332,7 @@ def profile( use_3xtf32, conv_kind, stride_support, + split_k_slices, profile_all_alignments, find_first_valid, use_multiprocessing, @@ -328,6 +346,7 @@ def profile( op["data_type"], op["alignment"], op["swizzle_functor"], + op["split_k_slices"], ) return name, opdef, op["runtime"] diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py index 76d43834c7ae..b3f40f09419c 100644 --- a/python/tvm/contrib/cutlass/gen_tensor_op.py +++ b/python/tvm/contrib/cutlass/gen_tensor_op.py @@ -363,6 +363,9 @@ def evaluate(self, op, args): try: sp = subprocess.run(cmd, capture_output=True, check=True) rt = float(sp.stdout) + if rt == 0.0: + # This seems to happen with split-k using invalid split-k-slices + rt = float("inf") logger.info("%s, %f", op_name, rt) except subprocess.CalledProcessError: rt = float("inf") diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index fdd268d1d9d1..b12da1ac62cb 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -284,15 +284,15 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, op_type != "cutlass.conv2d_bias_silu" && op_type != "cutlass.conv2d_bias_hardswish"; + const std::string op_name = attrs.at("op_name"); std::ostringstream conv2d_decl; CutlassPrint(conv2d_decl, attrs.at("op_def")); - CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") + - " = cutlass::conv::device::ImplicitGemmConvolution<" + - attrs.at("op_name") + ">;\n"); - CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + attrs.at("op_name") + ";\n"); + CutlassPrint(conv2d_decl, "using Operation_" + op_name + + " = cutlass::conv::device::ImplicitGemmConvolution<" + op_name + + ">;\n"); + CutlassPrint(conv2d_decl, "using Conv2d = Operation_" + op_name + ";\n"); CutlassPrint(conv2d_decl, "using ElementInputA = Conv2d::ElementA;\n"); CutlassPrint(conv2d_decl, "using ElementInputB = Conv2d::ElementB;\n"); - CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n"); CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = Conv2d::ElementAccumulator;\n"); auto get_dim = [&attrs](const std::string& axis, const std::string& var_name, int axis_idx) { @@ -319,10 +319,25 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n"); CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n"); + const bool use_split_k = op_name.find("splitk") != std::string::npos; + + if (use_split_k) { + std::string split_k_slices = op_name.substr(op_name.find_last_not_of("0123456789") + 1); + CutlassPrint(conv2d_decl, "int split_k_slices = " + split_k_slices + ";\n"); + } else { + CutlassPrint(conv2d_decl, "int split_k_slices = 1;\n"); + } + CutlassPrint( conv2d_decl, "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, " - "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n"); + "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, " + "split_k_slices);\n"); + + const std::string split_k_mode = use_split_k ? 
"kParallel" : "kSerial"; + CutlassPrint(conv2d_decl, + "const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" + + split_k_mode + ";\n"); bool is_wgrad = op_type.find("backward_weight") != std::string::npos; bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos; @@ -372,32 +387,51 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n"); } + if (use_split_k) { + CutlassPrint(conv2d_decl, "using ElementOutput = EpilogueOutputOp::ElementOutput;\n"); + } else { + CutlassPrint(conv2d_decl, "using ElementOutput = Conv2d::ElementC;\n"); + } + + std::string tensor_c_init = "{static_cast(ptr_out), layout_C}"; + if (has_residual_block) { + tensor_c_init = "{static_cast(ptr_residual), layout_C}"; + } else if (has_bias) { + tensor_c_init = + "{static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)}"; + } + + CutlassPrint(conv2d_decl, + "cutlass::TensorRef tensor_c" + tensor_c_init + ";\n"); + CutlassPrint(conv2d_decl, + "cutlass::TensorRef " + "tensor_d{static_cast(ptr_out),layout_D};\n"); + CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n"); CutlassPrint(conv2d_decl, " problem_size,\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_b), layout_B},\n"); - if (has_residual_block) { - CutlassPrint(conv2d_decl, " {static_cast(ptr_residual), layout_C},\n"); - } else if (has_bias) { - CutlassPrint( - conv2d_decl, - " {static_cast(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n"); + if (use_split_k) { + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); } else { - CutlassPrint(conv2d_decl, " {static_cast(ptr_out), layout_C},\n"); + CutlassPrint(conv2d_decl, " tensor_c,\n"); + CutlassPrint(conv2d_decl, " tensor_d,\n"); } - CutlassPrint(conv2d_decl, " {static_cast(ptr_out),layout_D},\n"); - if (has_residual_block) { + ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion"; CutlassPrint(conv2d_decl, "{alpha, beta},\n"); CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); } else if (has_bias && no_bias_scaling) { - CutlassPrint(conv2d_decl, " {alpha}\n};\n"); + CutlassPrint(conv2d_decl, " {alpha},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); } else { - CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n"); + CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); } CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); @@ -408,13 +442,67 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, "cutlass::device_memory::allocation workspace(workspace_size);\n"); // Check the problem size is supported or not CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n"); - CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, + "arguments.ref_D.reset(reinterpret_cast(workspace.get())," + " layout_D);\n\n"); + } + // Initialize CUTLASS kernel with arguments and workspace pointer CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n"); - CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + 
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint( + conv2d_decl, + "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); + CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + } + // Launch initialized CUTLASS kernel CutlassPrint(conv2d_decl, "status = conv2d_op();\n"); - CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "ReductionDevice reduction_op;\n"); + CutlassPrint(conv2d_decl, + "const static cutlass::conv::Operator kConvolutionalOperator = " + "Conv2d::kConvolutionalOperator;\n"); + CutlassPrint(conv2d_decl, "typename ReductionDevice::Arguments reduction_args(\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, " + "problem_size).mn(),\n"); + CutlassPrint(conv2d_decl, "problem_size.split_k_slices,\n"); + CutlassPrint(conv2d_decl, + "cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, " + "problem_size),\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, + " reinterpret_cast (workspace.get()),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_d.data(),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, "{\n"); + CutlassPrint(conv2d_decl, "tensor_c.data(),\n"); + CutlassPrint(conv2d_decl, + "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::" + "kTensorCStrideIdx])\n"); + CutlassPrint(conv2d_decl, "},\n"); + CutlassPrint(conv2d_decl, " {alpha, beta}\n"); + CutlassPrint(conv2d_decl, ");\n\n"); + CutlassPrint(conv2d_decl, "status = reduction_op.initialize(reduction_args, nullptr);\n"); + CutlassPrint(conv2d_decl, "status = reduction_op();\n"); + } return conv2d_decl.str(); } @@ -720,6 +808,7 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; @@ -734,6 +823,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase { code_stream_ << "#include \n"; code_stream_ << "#include \n"; code_stream_ << "#include \n"; + code_stream_ << "#include \n"; + code_stream_ << "#include \n"; ICHECK(ref->IsInstance()); auto res = GenCutlassFunc(Downcast(ref)); diff --git a/tests/python/contrib/test_cutlass.py b/tests/python/contrib/test_cutlass.py index ef55c74dc3a5..ad75e73b26fc 100644 --- a/tests/python/contrib/test_cutlass.py +++ b/tests/python/contrib/test_cutlass.py @@ -253,16 +253,24 @@ def get_random_ndarray(shape, dtype): def profile_and_build( - mod, params, sm, tmp_dir="./tmp", lib_path="compile.so", use_fast_math=False, use_3xtf32=True + mod, + params, + sm, + split_k_slices=[1], + tmp_dir="./tmp", + lib_path="compile.so", + use_fast_math=False, + use_3xtf32=True, ): mod = partition_for_cutlass(mod) mod, num_cutlass_partition = tune_cutlass_kernels( mod, sm, 
use_3xtf32=use_3xtf32, + split_k_slices=split_k_slices, profile_all_alignments=False, find_first_valid=True, - use_multiprocessing=False, + use_multiprocessing=True, tmp_dir=tmp_dir, ) with tvm.transform.PassContext(opt_level=3): @@ -277,6 +285,7 @@ def profile_and_build_vm( mod, params, sm, + split_k_slices=[1], tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro", @@ -287,6 +296,7 @@ def profile_and_build_vm( mod, num_cutlass_partition = tune_cutlass_kernels( mod, sm, + split_k_slices=split_k_slices, use_3xtf32=use_3xtf32, profile_all_alignments=False, find_first_valid=True, @@ -508,6 +518,7 @@ def verify_conv2d_common( inputs, params, sm=80, + split_k_slices=[1], atol=1e-5, rtol=1e-5, use_cudnn_ref=False, @@ -543,7 +554,7 @@ def verify_conv2d_common( ) rt_mod, _, num_cutlass_partition = profile_and_build_func( - mod_weight_ohwi, params, sm, use_fast_math=use_fast_math + mod_weight_ohwi, params, sm, split_k_slices, use_fast_math=use_fast_math ) out = get_output_func(rt_mod, input_names, inputs) @@ -597,6 +608,8 @@ def verify_conv2d( np_bias = get_random_ndarray((w_shape[0],), typ.dtype) params = {"weight": np_weight, "bias": np_bias} + split_k_slices = [1] + return verify_conv2d_common( expr_nchw, expr_ref, @@ -604,6 +617,7 @@ def verify_conv2d( [np_data], params, sm, + split_k_slices, atol, rtol, use_cudnn_ref, @@ -620,6 +634,7 @@ def verify_conv2d_backward_weight( grad_shape, data_shape, sm=80, + split_k_slices=[1], atol=1e-5, rtol=1e-5, use_cudnn_ref=False, @@ -640,6 +655,7 @@ def verify_conv2d_backward_weight( [np_grad, np_data], params, sm, + split_k_slices, atol, rtol, use_cudnn_ref, @@ -838,18 +854,20 @@ def test_conv2d_backward_weight(): weight_dtype=dtype, ) - verify_conv2d_backward_weight( - mod_nchw, - mod_nchw, - o_shape, - d_shape, - sm=80, - atol=1e-3, - rtol=1e-3, - use_cudnn_ref=False, - grad_dtype=dtype, - data_dtype=dtype, - ) + for split_k_slices in [1, 8]: + verify_conv2d_backward_weight( + mod_nchw, + mod_nchw, + o_shape, + d_shape, + sm=80, + split_k_slices=[split_k_slices], + atol=1e-3, + rtol=1e-3, + use_cudnn_ref=False, + grad_dtype=dtype, + data_dtype=dtype, + ) def test_conv2d_bwd():
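
For reviewers, a minimal end-to-end sketch of how the new split_k_slices knob is meant to be used, modeled on the profile_and_build helper in the test changes above. The build_cutlass_kernels call and its exact arguments are assumed from the existing test utilities rather than introduced by this patch; only the split_k_slices argument to tune_cutlass_kernels is new here.

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

def build_with_split_k(mod, params, sm=80, split_k_slices=[1, 4, 8], tmp_dir="./tmp"):
    # Offload supported subgraphs to CUTLASS, then profile kernels; wgrad kernels
    # are profiled once per split-K candidate, other kernels ignore the list.
    mod = partition_for_cutlass(mod)
    mod, num_cutlass_partition = tune_cutlass_kernels(
        mod, sm, split_k_slices=split_k_slices, profile_all_alignments=False, tmp_dir=tmp_dir
    )
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="cuda", params=params)
    # Compile the generated CUTLASS sources, as in profile_and_build above.
    lib = build_cutlass_kernels(lib, sm, tmp_dir)
    return lib, num_cutlass_partition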
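
Also for context on the string contract between conv2d_operation.py (which appends "_splitk%d" to procedural_name() when split_k_slices > 1) and the find_last_not_of parsing added in codegen.cc: an equivalent Python sketch, using a hypothetical kernel name purely for illustration.

def split_k_slices_from_op_name(op_name: str) -> int:
    """Recover the split-K factor from a generated kernel name, or 1 if absent."""
    if "splitk" not in op_name:
        return 1
    # Equivalent to op_name.substr(op_name.find_last_not_of("0123456789") + 1)
    # in codegen.cc: take the trailing run of digits.
    digits = op_name[len(op_name.rstrip("0123456789")):]
    return int(digits)

# Hypothetical names shaped like procedural_name() output, with and without the postfix.
assert split_k_slices_from_op_name("cutlass_tensorop_wgrad_example_align4_splitk8") == 8
assert split_k_slices_from_op_name("cutlass_tensorop_fprop_example_align4") == 1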