From 4a383e2c7c37148a563e9cf34968fb7da3aaf91f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 3 Feb 2022 14:05:24 +0900 Subject: [PATCH] update c++ codegen --- src/relay/backend/contrib/cutlass/codegen.cc | 38 ++++++++------------ 1 file changed, 15 insertions(+), 23 deletions(-) diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index c55857c4a685d..1469a7e1e3836 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -400,8 +400,14 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, " problem_size,\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_a), layout_A},\n"); CutlassPrint(conv2d_decl, " {static_cast(ptr_b), layout_B},\n"); - CutlassPrint(conv2d_decl, " tensor_c,\n"); - CutlassPrint(conv2d_decl, " tensor_d,\n"); + + if (use_split_k) { + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + CutlassPrint(conv2d_decl, "{nullptr, TensorNHWC()},\n"); + } else { + CutlassPrint(conv2d_decl, " tensor_c,\n"); + CutlassPrint(conv2d_decl, " tensor_d,\n"); + } if (has_residual_block) { ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion"; @@ -426,13 +432,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, // Check the problem size is supported or not CutlassPrint(conv2d_decl, "cutlass::Status status = conv2d_op.can_implement(arguments);\n"); CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + + if (use_split_k) { + CutlassPrint( + conv2d_decl, + "arguments.ref_D.reset(reinterpret_cast(workspace.get()), layout_D);\n"); + } + // Initialize CUTLASS kernel with arguments and workspace pointer CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n"); CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); if (use_split_k) { - CutlassPrint(conv2d_decl, - "\narguments.ref_D.reset(reinterpret_cast(workspace.get())); \n"); CutlassPrint( conv2d_decl, "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); @@ -445,25 +456,6 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); if (use_split_k) { - CutlassPrint(conv2d_decl, "\nusing EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n"); - CutlassPrint(conv2d_decl, "using ReductionOp = cutlass::reduction::thread::ReduceAdd<\n"); - CutlassPrint(conv2d_decl, " Conv2d::ElementAccumulator,\n"); - CutlassPrint(conv2d_decl, " typename EpilogueOutputOp::ElementAccumulator,\n"); - CutlassPrint(conv2d_decl, " EpilogueOutputOp::kCount\n"); - CutlassPrint(conv2d_decl, " >;\n"); - CutlassPrint(conv2d_decl, "\n"); - CutlassPrint(conv2d_decl, - "using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<\n"); - CutlassPrint(conv2d_decl, " cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,\n"); - CutlassPrint(conv2d_decl, " EpilogueOutputOp,\n"); - CutlassPrint(conv2d_decl, " ReductionOp\n"); - CutlassPrint(conv2d_decl, " >;\n"); - CutlassPrint(conv2d_decl, "\n"); - CutlassPrint( - conv2d_decl, - "using ReductionDevice = cutlass::reduction::device::ReduceSplitK;\n"); - CutlassPrint(conv2d_decl, - "using ReductionStrideIndex = typename ReductionDevice::StrideIndex;\n"); CutlassPrint(conv2d_decl, " ReductionDevice reduction_op;\n"); CutlassPrint(conv2d_decl, " const static cutlass::conv::Operator kConvolutionalOperator = "