From 084d5c47666df92ba6c2c1445d5a23de0193a119 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 2 Feb 2022 12:35:18 +0900 Subject: [PATCH] fix compile error for fprop --- src/relay/backend/contrib/cutlass/codegen.cc | 26 ++++++++++++-------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/relay/backend/contrib/cutlass/codegen.cc b/src/relay/backend/contrib/cutlass/codegen.cc index 48e4f32df606a..eafde7db6d28c 100644 --- a/src/relay/backend/contrib/cutlass/codegen.cc +++ b/src/relay/backend/contrib/cutlass/codegen.cc @@ -319,13 +319,14 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n"); CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n"); // TODO - const int split_k_slices = 8; + const int split_k_slices = 1; CutlassPrint(conv2d_decl, "int split_k_slices = " + std::to_string(split_k_slices) + ";\n"); CutlassPrint( conv2d_decl, "cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, " - "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);\n"); + "stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, " + "split_k_slices);\n"); const bool use_split_k = split_k_slices > 1; const std::string split_k_mode = use_split_k > 1 ? "kParallel" : "kSerial"; @@ -403,18 +404,19 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, " tensor_d,\n"); if (has_residual_block) { + ICHECK(use_split_k == false) << "Split-k not supported for residual block fusion"; CutlassPrint(conv2d_decl, "{alpha, beta},\n"); CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices CutlassPrint(conv2d_decl, "static_cast(ptr_bias),\n"); CutlassPrint(conv2d_decl, "nullptr, 0, K};\n"); } else if (has_bias && no_bias_scaling) { - CutlassPrint(conv2d_decl, " {alpha}\n},\n"); + CutlassPrint(conv2d_decl, " {alpha},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); } else { CutlassPrint(conv2d_decl, "{alpha, beta},\n"); + CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); } - CutlassPrint(conv2d_decl, "split_k_mode\n};\n"); - CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n"); CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n"); @@ -429,10 +431,13 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); if (use_split_k) { - CutlassPrint(conv2d_decl, "arguments.ref_D.reset(reinterpret_cast(workspace.get())); \n"); - CutlassPrint(conv2d_decl, "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); + CutlassPrint(conv2d_decl, + "\narguments.ref_D.reset(reinterpret_cast(workspace.get())); \n"); + CutlassPrint( + conv2d_decl, + "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n"); CutlassPrint(conv2d_decl, "status = conv2d_op.update(arguments, workspace.get()); \n"); - CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); + CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n\n"); } // Launch initialized CUTLASS kernel @@ -440,7 +445,7 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n"); if (use_split_k) { - CutlassPrint(conv2d_decl, "using EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n"); + CutlassPrint(conv2d_decl, "\nusing EpilogueOutputOp = Conv2d::EpilogueOutputOp;\n"); CutlassPrint(conv2d_decl, "using ReductionOp = cutlass::reduction::thread::ReduceAdd<\n"); CutlassPrint(conv2d_decl, " Conv2d::ElementAccumulator,\n"); CutlassPrint(conv2d_decl, " typename EpilogueOutputOp::ElementAccumulator,\n"); @@ -472,7 +477,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs, " cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, " "problem_size),\n"); CutlassPrint(conv2d_decl, " {\n"); - CutlassPrint(conv2d_decl, " reinterpret_cast (workspace.get()),\n"); + CutlassPrint(conv2d_decl, + " reinterpret_cast (workspace.get()),\n"); CutlassPrint(conv2d_decl, " " "ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"