[CUTLASS] Residual connection fusion (#9820)

* [CUTLASS] Support residual block fusion for conv2d

commit d4a78a3
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:33:41 2021 +0900

    fixed residual block check condition

commit 6ee5a39
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:25:04 2021 +0900

    minor fix

commit 8af8b30
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:18:50 2021 +0900

    remove SimplifyExpr pass

commit 20ae2d8
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:16:46 2021 +0900

    fix bad merge

commit 17eed22
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:13:53 2021 +0900

    black

commit fda151b
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 16:09:45 2021 +0900

    Support residual block fusion

commit ce9d52f
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 15:56:32 2021 +0900

    Remove SimplifyExpr pass from the pipeline (makes DETR result nan)

commit d3b681d
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 15:47:07 2021 +0900

    fix no_beta_scaling values

commit 87b36db
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 14:59:40 2021 +0900

    fill in TODO doc

commit fd67595
Author: Masahiro Masuda <[email protected]>
Date:   Thu Dec 23 14:31:06 2021 +0900

    Refactor cutlass kernel generation and selection

* do not try to support broadcast binary op

* add comments

* remove residual input shape check
masahi authored Jan 3, 2022
1 parent 11379f7 commit e7f3648
Showing 8 changed files with 252 additions and 29 deletions.
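
For context, the kind of graph this change targets is a conv2d whose epilogue chains a bias add, an elementwise binary op with a residual input, and an activation. A minimal Relay sketch of such a block (shapes, names, and layouts here are illustrative, not taken from this PR; NHWC/OHWI layouts are used because the CUTLASS path expects them):

    import tvm
    from tvm import relay

    # Illustrative residual block: conv2d -> bias_add -> add(residual) -> relu.
    data = relay.var("data", shape=(1, 56, 56, 32), dtype="float16")
    weight = relay.var("weight", shape=(32, 3, 3, 32), dtype="float16")
    bias = relay.var("bias", shape=(32,), dtype="float16")

    conv = relay.nn.conv2d(
        data, weight, padding=(1, 1), kernel_size=(3, 3),
        data_layout="NHWC", kernel_layout="OHWI",
    )
    biased = relay.nn.bias_add(conv, bias, axis=3)
    out = relay.nn.relu(relay.add(biased, data))  # `data` doubles as the residual input

    mod = tvm.IRModule.from_expr(relay.Function([data, weight, bias], out))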
45 changes: 35 additions & 10 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -150,6 +150,7 @@ def __init__(self):
${element_accumulator},
${element_epilogue}
>"""

self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
@@ -159,10 +160,22 @@ def __init__(self):
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.epilogue_residual_block = """
${epilogue_functor}<
${element_c},
${element_accumulator},
${element_epilogue},
${element_c},
${epilogue_vector_length},
${activation},
${binary_op},
${unary_op}
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
${element_a},
${layout_a},
${element_b},
@@ -186,7 +199,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation, no_beta_scaling=False):
def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -246,14 +259,26 @@ def emit(self, operation, no_beta_scaling=False):
],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
"conv_kernel_postfix": "",
}

template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
if residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
"unary_op": residual_block_info["unary_op"],
"binary_op": residual_block_info["binary_op"],
"activation": residual_block_info["activation"],
"conv_kernel_postfix": "WithBroadcast",
}
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
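
A hedged sketch of how the new `residual_block_info` argument is meant to be fed to `emit` (the dict keys mirror the `${activation}`, `${binary_op}`, and `${unary_op}` placeholders in the epilogue template above; the `op` operator object is elided because constructing one requires a full tile description):

    # Sketch only: the shape of the dict that emit() consumes for residual fusion.
    residual_block_info = {
        "activation": "cutlass::epilogue::thread::Identity",
        "binary_op": "cutlass::plus",
        "unary_op": "cutlass::epilogue::thread::ReLu",
    }
    # opdef = EmitConv2dInstance().emit(op, residual_block_info=residual_block_info)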
31 changes: 29 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
Instantiate a cutlass kernel from the given configuration,
along with the epilouge functor
"""
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
if "residual" in op_type:
activation_map = {
"cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish",
"cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
"cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid",
"cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
"cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
}
prefix = op_type[: op_type.find("_residual")]
activation = activation_map[prefix]
binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus"
unary_op = (
"cutlass::epilogue::thread::ReLu"
if op_type.endswith("relu")
else "cutlass::epilogue::thread::Identity"
)
residual_block_info = {
"activation": activation,
"binary_op": binary_op,
"unary_op": unary_op,
}
epilogue = EpilogueFunctor.LinearCombinationResidualBlock
no_beta_scaling = False
else:
residual_block_info = None
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]

element_a, element_b, element_c, element_epilogue = data_type

@@ -62,7 +87,9 @@
)

name = op.procedural_name()
opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
opdef = EmitConv2dInstance().emit(
op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
)

return name, opdef

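To make the string dispatch above concrete, here is how one composite op name decomposes (a standalone illustration of the same parsing, not the function itself):

    op_type = "cutlass.conv2d_bias_residual_add_relu"

    prefix = op_type[: op_type.find("_residual")]
    assert prefix == "cutlass.conv2d_bias"
    # activation_map[prefix]           -> "cutlass::epilogue::thread::Identity"
    # "residual_multiply" not in name  -> binary_op = "cutlass::plus"
    # op_type.endswith("relu")         -> unary_op = "cutlass::epilogue::thread::ReLu"
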
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()
LinearCombinationResidualBlock = enum_auto()


EpilogueFunctorTag = {
@@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum):
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
}


48 changes: 47 additions & 1 deletion python/tvm/relay/op/contrib/cutlass.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
from functools import partial
from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
@@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return conv2d_out


def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
"""Add pattern for residual blocks."""
residual_input = wildcard()
binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
residual_input, tensor_op_out
)

if with_act is not None and with_act == "relu":
return is_op("nn.relu")(binary_out)

return binary_out
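
As a usage sketch, this helper composes with a conv2d pattern like so (a simplified conv2d + bias pattern is built inline here for illustration; `partition_for_cutlass` below uses `make_conv2d_pattern` instead):

    from tvm.relay.dataflow_pattern import is_op, wildcard

    # Simplified stand-in for make_conv2d_pattern(with_bias=True).
    conv_out = is_op("nn.bias_add")(is_op("nn.conv2d")(wildcard(), wildcard()), wildcard())
    # Matches conv_out + residual (either argument order) followed by relu.
    pat = make_residual_block_pattern(conv_out, binary_op="add", with_act="relu")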


def check_dtype(lhs, rhs):
"""Check if dtypes in the given workload are supported by CUTLASS."""
# Only fp16 inputs are supported for now.
@@ -139,6 +153,25 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def check_conv2d_residual(call, binary_op):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
if not check_conv2d(call):
return False

residual_binop = get_root_call(call, binary_op)
lhs = residual_binop.args[0]
rhs = residual_binop.args[1]

# residual_input is pattern-matched as a wildcard. Make sure it does not sit between
# residual binary op and the root conv2d of this pattern.
# If the root conv2d is the parent of both lhs and rhs, we should reject this pattern.
if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
return False

return all(x == y for (x, y) in zip(lhs.checked_type.shape, rhs.checked_type.shape))
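
An illustration of the corner case the check above rejects: when both arguments of the binary op trace back to the same conv2d, there is no external residual input (a Relay sketch with illustrative shapes and layouts):

    from tvm import relay

    data = relay.var("data", shape=(1, 56, 56, 32), dtype="float16")
    weight = relay.var("weight", shape=(32, 3, 3, 32), dtype="float16")
    conv_out = relay.nn.conv2d(
        data, weight, padding=(1, 1), kernel_size=(3, 3),
        data_layout="NHWC", kernel_layout="OHWI",
    )

    bad = relay.add(conv_out, conv_out)  # rejected: both args reach the same conv2d
    good = relay.add(conv_out, data)     # accepted: `data` is a genuine residual input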


def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
@@ -189,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]

cutlass_patterns = dense_patterns + conv2d_patterns
residual_block_patterns = []

for with_act, postfix in [("relu", "_relu"), (None, "")]:
for name, pat, _ in conv2d_patterns[:-1]:
for bin_op in ["add", "multiply"]:
residual_block_patterns.append(
(
name + "_residual_" + bin_op + postfix,
make_residual_block_pattern(pat, bin_op, with_act=with_act),
partial(check_conv2d_residual, binary_op=bin_op),
)
)

cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
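End to end, offloading goes through the usual BYOC entry point (a minimal usage sketch; `mod` and `params` come from whichever frontend importer produced the model):

    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    # `mod` / `params` as produced by a frontend importer, e.g. relay.frontend.from_onnx.
    mod = partition_for_cutlass(mod, params)
    # Residual blocks now appear as composite functions with names such as
    # "cutlass.conv2d_bias_residual_add_relu" inside the CUTLASS partitions.
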
79 changes: 70 additions & 9 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -258,7 +258,7 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {
}

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
const std::vector<std::string>& func_args, bool has_residual_block = false) {
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
@@ -268,8 +268,8 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
CutlassPrint(conv2d_decl, "using ElementOutput = " + attrs.at("ElementOutput") + ";\n");

CutlassPrint(conv2d_decl, "using ElementComputeEpilogue = " + attrs.at("ElementOutput") + ";\n");

CutlassPrint(conv2d_decl, attrs.at("op_def"));
CutlassPrint(conv2d_decl, "using Operation_" + attrs.at("op_name") +
" = cutlass::conv::device::ImplicitGemmConvolution<" +
@@ -308,14 +308,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
if (has_bias) {
if (has_residual_block) {
ICHECK(func_args.size() >= 4);
CutlassPrint(conv2d_decl, "void* ptr_bias = (void*)(" + func_args[2] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_residual = (void*)(" + func_args[3] + "->data);\n");
} else if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}

CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
if (has_bias && no_bias_scaling) {
if (has_bias && no_bias_scaling && !has_residual_block) {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
} else {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
@@ -326,24 +330,38 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl,
"TensorNHWC layout_B(TensorNHWC::packed(cutlass::make_Coord(K, R, S, C)));\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n");
"TensorNHWC layout_C(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_D(TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K)));\n\n");

CutlassPrint(conv2d_decl, "typename Conv2d::Arguments arguments{\n");
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
if (has_bias) {

if (has_residual_block) {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_residual), layout_C},\n");
} else if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out), layout_C},\n");
}
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
if (has_bias && no_bias_scaling) {

CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_D},\n");

if (has_residual_block) {
CutlassPrint(conv2d_decl, "{alpha, beta},\n");
CutlassPrint(conv2d_decl, "cutlass::conv::SplitKMode::kSerial,\n"); // split_k_slices
CutlassPrint(conv2d_decl, "static_cast<ElementOutput*>(ptr_bias),\n");
CutlassPrint(conv2d_decl, "nullptr, 0, K};\n");
} else if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}

CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -432,6 +450,21 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
return arg_names;
}

bool IsConv2dResidualBlock(const std::string& func_name) {
return func_name.find("conv2d") != std::string::npos &&
func_name.find("residual") != std::string::npos;
}

// Is node `x` an ancestor of `y`?
bool IsAncestor(const CallNode* x, const CallNode* y) {
if (x == y) return true;
for (auto arg : y->args) {
const CallNode* arg_ptr = arg.as<CallNode>();
if (arg_ptr && IsAncestor(x, arg_ptr)) return true;
}
return false;
}

GenerateBodyOutput GenerateCompositeFunctionCall(const FunctionNode* callee,
const CallNode* caller) {
const auto pattern_name = callee->GetAttr<runtime::String>(attr::kComposite);
@@ -515,6 +548,30 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (IsConv2dResidualBlock(pattern_name.value())) {
const CallNode* current_call = callee->body.as<CallNode>();
bool has_relu = current_call->args.size() == 1;
const CallNode* binop = has_relu ? current_call->args[0].as<CallNode>() : current_call;
ICHECK(binop->args.size() == 2);
// Figure out which of the first or second argument corresponds to the residual input
// The root conv2d call can be reached via the other input of the binary op
int residual_index;
if (binop->args[1].as<VarNode>()) {
residual_index = 1;
} else if (binop->args[0].as<VarNode>()) {
residual_index = 0;
} else {
const CallNode* lhs = binop->args[0].as<CallNode>();
const CallNode* rhs = binop->args[1].as<CallNode>();
ICHECK(lhs && rhs);
// The residual input should be an ancestor of the non-residual input
residual_index = IsAncestor(rhs, lhs) ? 1 : 0;
}
const auto* non_residual_input = binop->args[!residual_index].as<CallNode>();
const auto* conv2d_call = GetRootCall(non_residual_input, "nn.conv2d");
ICHECK(conv2d_call);
return GenerateBody(conv2d_call, pattern_name.value(), GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -560,6 +617,8 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, public CodegenCBase {
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (IsConv2dResidualBlock(func_name)) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args, true);
} else if (func_name.find("conv2d") != std::string::npos) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
@@ -623,6 +682,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_silu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_residual_block.h>\n";
code_stream_ << "#include <cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h>\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));