Commit c2098e7

wip

masahi committed Feb 4, 2022
1 parent 2191918 commit c2098e7

Showing 2 changed files with 97 additions and 12 deletions.
104 changes: 93 additions & 11 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -318,11 +318,20 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "int stride_w = " + attrs.at("stride_w") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_h = " + attrs.at("dilation_h") + ";\n");
CutlassPrint(conv2d_decl, "int dilation_w = " + attrs.at("dilation_w") + ";\n");
// TODO: make split_k_slices configurable instead of hard-coding it for this WIP.
const int split_k_slices = 8;
CutlassPrint(conv2d_decl, "int split_k_slices = " + std::to_string(split_k_slices) + ";\n");

CutlassPrint(
conv2d_decl,
"cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, "
"stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, 1);\n");
"stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);\n");

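// kSerial accumulates the k-slices within a single kernel launch; kParallel
// writes per-slice partial sums to a workspace and requires the separate
// reduction kernel emitted further below.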
const bool use_split_k = split_k_slices > 1;
const std::string split_k_mode = use_split_k ? "kParallel" : "kSerial";
CutlassPrint(conv2d_decl,
"const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::" +
split_k_mode + ";\n");

bool is_wgrad = op_type.find("backward_weight") != std::string::npos;
bool is_dgrad = op_type.find("conv2d_transpose") != std::string::npos;
@@ -372,22 +381,24 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "TensorNHWC layout_D(output_oshape);\n\n");
}

CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");

if (has_residual_block) {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual), layout_C},\n");
CutlassPrint(conv2d_decl, "TensorRef tensor_c{static_cast<ElementOutput*>(ptr_residual), layout_C};\n");
} else if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
"TensorRef tensor_c{static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)};\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out), layout_C},\n");
CutlassPrint(conv2d_decl, "TensorRef tensor_c{static_cast<ElementOutput*>(ptr_out), layout_C};\n");
}

CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_D},\n");
CutlassPrint(conv2d_decl, "TensorRef tensor_d{static_cast<ElementOutput*>(ptr_out),layout_D};\n");

CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
CutlassPrint(conv2d_decl, " tensor_c,\n");
CutlassPrint(conv2d_decl, " tensor_d},\n");

if (has_residual_block) {
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
@@ -397,9 +408,11 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
} else if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
CutlassPrint(conv2d_decl, "{alpha, beta};\n");
}

CutlassPrint(conv2d_decl, "split_k_mode\n};\n");

CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -412,10 +425,77 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
// Initialize CUTLASS kernel with arguments and workspace pointer
CutlassPrint(conv2d_decl, "status = conv2d_op.initialize(arguments, workspace.get());\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

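// For parallel split-k, repoint the conv output at the workspace and switch to
// an identity epilogue (alpha = 1, beta = 0); the real alpha/beta are applied
// once by the reduction kernel emitted below.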
if (use_split_k) {
CutlassPrint(conv2d_decl, "arguments.ref_D.reset(reinterpret_cast<ElementOutput*>(workspace.get())); \n");
CutlassPrint(conv2d_decl, "arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)}; \n");
CutlassPrint(conv2d_decl, "status = conv2d.update(arguments, workspace.get()); \n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");
}

// Launch initialized CUTLASS kernel
CutlassPrint(conv2d_decl, "status = conv2d_op();\n");
CutlassPrint(conv2d_decl, "CHECK(status == cutlass::Status::kSuccess);\n");

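// Build and launch a ReduceSplitK kernel that sums the per-slice partial
// accumulations from the workspace into the final output tensor.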
if (use_split_k) {
CutlassPrint(conv2d_decl, "using EpilogueOutputOp = Conv2d::EpilogueOutputOp\n");
CutlassPrint(conv2d_decl, "using ReductionOp = cutlass::reduction::thread::ReduceAdd<\n");
CutlassPrint(conv2d_decl, " Conv2d::ElementAccumulator,\n");
CutlassPrint(conv2d_decl, " typename EpilogueOutputOp::ElementComputeEpilogue,\n");
CutlassPrint(conv2d_decl, " EpilogueOutputOp::kCount\n");
CutlassPrint(conv2d_decl, " >;\n");
CutlassPrint(conv2d_decl, "\n");
CutlassPrint(conv2d_decl,
"using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<\n");
CutlassPrint(conv2d_decl, " cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,\n");
CutlassPrint(conv2d_decl, " EpilogueOutputOp,\n");
CutlassPrint(conv2d_decl, " ReductionOp\n");
CutlassPrint(conv2d_decl, " >;\n");
CutlassPrint(conv2d_decl, "\n");
CutlassPrint(
conv2d_decl,
"using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;\n");
CutlassPrint(conv2d_decl,
"using ReductionStrideIndex = typename ReductionDevice::StrideIndex;\n");
CutlassPrint(conv2d_decl, " ReductionDevice reduction_op;\n");
CutlassPrint(conv2d_decl,
" const static cutlass::conv::Operator kConvolutionalOperator = "
"Conv2d::kConvolutionalOperator;\n");
CutlassPrint(conv2d_decl, " typename ReductionDevice::Arguments reduction_args(\n");
CutlassPrint(conv2d_decl,
" cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, "
"problem_size).mn(),\n");
CutlassPrint(conv2d_decl, " problem_size.split_k_slices,\n");
CutlassPrint(conv2d_decl,
" cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, "
"problem_size),\n");
CutlassPrint(conv2d_decl, " {\n");
CutlassPrint(conv2d_decl, " reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),\n");
CutlassPrint(conv2d_decl,
" "
"ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
"kTensorCStrideIdx])\n");
CutlassPrint(conv2d_decl, " },\n");
CutlassPrint(conv2d_decl, " {\n");
CutlassPrint(conv2d_decl, " tensor_d.device_data(),\n");
CutlassPrint(conv2d_decl,
" "
"ReductionStrideIndex(tensor_d.stride()[Conv2d::ImplicitGemmKernel::"
"kTensorCStrideIdx])\n");
CutlassPrint(conv2d_decl, " },\n");
CutlassPrint(conv2d_decl, " {\n");
CutlassPrint(conv2d_decl, " tensor_c.device_data(),\n");
CutlassPrint(conv2d_decl,
" "
"ReductionStrideIndex(tensor_c.stride()[Conv2d::ImplicitGemmKernel::"
"kTensorCStrideIdx])\n");
CutlassPrint(conv2d_decl, " },\n");
CutlassPrint(conv2d_decl, " {alpha, beta}\n");
CutlassPrint(conv2d_decl, " );\n\n");
CutlassPrint(conv2d_decl, " status = reduction_op.initialize(reduction_args, nullptr);\n");
CutlassPrint(conv2d_decl, " status = reduction_op();\n");
}

return conv2d_decl.str();
}
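
For reference, a minimal sketch of the split-k host code these printers are
meant to emit once assembled. Names such as Conv2d, ElementInputA/B,
ElementOutput, ElementComputeEpilogue, alpha/beta, the layout_* objects, and
the ptr_* buffers are assumed to come from the surrounding generated prologue
(the string literals above reference them), and the TensorRef alias is an
assumption added for self-containment; this is a sketch, not the verbatim
generated output:

using TensorRef = cutlass::TensorRef<ElementOutput, cutlass::layout::TensorNHWC>;  // assumed alias

// Split-k partitions the reduction axis (C * R * S) across thread blocks.
cutlass::conv::Conv2dProblemSize problem_size(
    N, H, W, C, K, R, S, P, Q, pad_h, pad_w, stride_h, stride_w,
    dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);
const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::kParallel;

TensorRef tensor_c{static_cast<ElementOutput*>(ptr_out), layout_C};  // or ptr_residual / ptr_c_bias
TensorRef tensor_d{static_cast<ElementOutput*>(ptr_out), layout_D};

typename Conv2d::Arguments arguments{
    problem_size,
    {static_cast<ElementInputA*>(ptr_a), layout_A},
    {static_cast<ElementInputB*>(ptr_b), layout_B},
    tensor_c,
    tensor_d,
    {alpha, beta},
    split_k_mode};

Conv2d conv2d_op;
size_t workspace_size = conv2d_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = conv2d_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);

// kParallel: the conv kernel writes raw partial sums to the workspace with an
// identity epilogue; alpha/beta are applied later by the reduction kernel.
arguments.ref_D.reset(reinterpret_cast<ElementOutput*>(workspace.get()));
arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)};
status = conv2d_op.update(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);

status = conv2d_op();  // implicit-GEMM main kernel
CHECK(status == cutlass::Status::kSuccess);
// ...followed by the ReductionDevice launch constructed above to produce tensor_d.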

@@ -734,6 +814,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";
code_stream_ << "#include <cutlass/reduction/device/reduce_split_k.h>\n";
code_stream_ << "#include <cutlass/reduction/thread/reduction_operators.h>\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
5 changes: 4 additions & 1 deletion tests/python/contrib/test_cutlass.py
@@ -863,6 +863,8 @@ def test_conv2d_backward_weight():
data_dtype=dtype,
)

# TODO: add a test case that exercises the split-k path


def test_conv2d_bwd():
IC = 16
@@ -913,4 +915,5 @@ def test_conv2d_bwd():


if __name__ == "__main__":
pytest.main([__file__])
# pytest.main([__file__])
test_conv2d_backward_weight()
