[CUTLASS] Support conv2d activation fusion (apache#9746)
* Add cutlass conv2d activation (bias, relu, sigmoid)

commit e4e273ae74a8e54ab1ae1414ce9b6bfcc2b3d530
Merge: 0489d14 77c9385
Author: Masahiro Masuda <[email protected]>
Date:   Mon Dec 13 11:58:54 2021 +0900

    Merge branch 'partition-constant-unbind' into cutlass-conv2d-fusion

commit 77c9385
Author: Masahiro Masuda <[email protected]>
Date:   Mon Dec 13 11:58:18 2021 +0900

    add test

commit ab01b3a
Author: Masahiro Masuda <[email protected]>
Date:   Mon Dec 13 11:55:06 2021 +0900

    make constant binding in PartitionGraph optional

commit 0489d14
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 21:52:29 2021 +0900

    support sigmoid fusion (only fp32 accum for now)

commit 3705bbd
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 20:50:58 2021 +0900

    conv2d fusion test worked

commit 05b51c9
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 20:34:10 2021 +0900

    fix bias stride

commit 7cf40e7
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 20:01:21 2021 +0900

    use nobetascaling

commit 274ec02
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 19:12:58 2021 +0900

    adding fusion support to codegen

commit 0de5ebd
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 18:39:08 2021 +0900

    partition working

commit c08bb38
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 17:24:42 2021 +0900

    update test

commit 81bf9e6
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 13:23:39 2021 +0900

    add fused conv2d pattern

commit 1c0bbb2
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 18:29:03 2021 +0900

    fix lint

commit 463574c
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 17:28:38 2021 +0900

    fixed conv2d check

commit 588c5ab
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 15:05:27 2021 +0900

    update test

commit a447b57
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 14:54:52 2021 +0900

    speed up profiling by removing initialization

commit 93cd039
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 08:26:29 2021 +0900

    fixed nhwc cudnn depthwise conv

commit 6db7172
Author: Masahiro Masuda <[email protected]>
Date:   Sat Dec 11 15:39:05 2021 +0900

    add cache

commit f7d17a1
Author: Masahiro Masuda <[email protected]>
Date:   Sat Dec 11 15:05:38 2021 +0900

    removed im2col profiling for conv2d

commit b724f44
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 22:57:54 2021 +0900

    black

commit fe4687b
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 22:49:13 2021 +0900

    fixed cmd argument

commit ab114f5
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 22:22:19 2021 +0900

    conv2d profiler working

commit 49ee61f
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 20:26:15 2021 +0900

    add conv2d profiler

commit 49e2c89
Author: Masahiro Masuda <[email protected]>
Date:   Sun Dec 12 08:03:36 2021 +0900

    do not offload depthwise conv2d

commit cd83677
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 13:20:01 2021 +0900

    lint fix

commit 870823c
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:54:38 2021 +0900

    add comment on IC == 3 case

commit 6b780db
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:48:33 2021 +0900

    check align on N dim

commit 308c4da
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:34:42 2021 +0900

    fixed check functions for fused cases, run infer type before mergecomposite

commit 8d6a1bf
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:10:59 2021 +0900

    test IC=3 convolution

commit ffce47d
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:10:16 2021 +0900

    use align1 kernel for unusual channel cases (IC = 3 etc)

commit 6cdf205
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 12:06:56 2021 +0900

    add dtype and layout check in pattern match

commit 7743cc6
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 10:40:53 2021 +0900

    add sm75 kernels to sm80 profilings

commit efceccb
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 10:40:42 2021 +0900

    skip legalize when batch size is dynamic

commit 65fbc0a
Author: Masahiro Masuda <[email protected]>
Date:   Fri Dec 10 10:36:36 2021 +0900

    bug fix in im2col encoding

* support batch norm fusion
masahi authored and baoxinqi committed Dec 27, 2021
1 parent 84c07fc commit 7de3bae
Showing 8 changed files with 196 additions and 37 deletions.
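
For orientation before the per-file diffs, here is a minimal sketch of the kind of Relay subgraph the new patterns offload: conv2d followed by bias add and an activation. It is not part of this commit; the shapes, layouts, and variable names are illustrative assumptions.

import tvm
from tvm import relay

# Assumed toy graph in NHWC/float16, the layout and dtype the CUTLASS backend targets:
# conv2d -> bias_add -> relu, which the new "cutlass.conv2d_bias_relu" pattern matches.
data = relay.var("data", shape=(1, 32, 32, 16), dtype="float16")
weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16",
)
out = relay.nn.relu(relay.nn.bias_add(conv, bias, axis=3))
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))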
9 changes: 9 additions & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -87,6 +87,9 @@ def visit_call(self, call):
if str(op) == "nn.conv2d":
self.op_attrs = call.attrs

for arg in call.args:
self.visit(arg)


def select_gemm_kernel(
cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
@@ -213,6 +216,12 @@ def handle_conv2d(

if op_type == "cutlass.conv2d":
cutlass_op_def = out["opdef"]
elif op_type == "cutlass.conv2d_bias":
cutlass_op_def = out["opdef_bias"]
elif op_type == "cutlass.conv2d_bias_relu":
cutlass_op_def = out["opdef_bias_relu"]
elif op_type == "cutlass.conv2d_bias_sigmoid":
cutlass_op_def = out["opdef_bias_sigmoid"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

35 changes: 27 additions & 8 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -143,6 +143,22 @@ class EmitConv2dInstance:
""" Responsible for emitting a CUTLASS template definition."""

def __init__(self):
self.epilogue_default = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>"""
self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue},
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
@@ -159,12 +175,7 @@ def __init__(self):
cutlass::gemm::GemmShape<${threadblock_shape_m}, ${threadblock_shape_n}, ${threadblock_shape_k}>,
cutlass::gemm::GemmShape<${warp_shape_m}, ${warp_shape_n}, ${warp_shape_k} >,
cutlass::gemm::GemmShape<${instruction_shape_m}, ${instruction_shape_n}, ${instruction_shape_k}>,
${epilogue_functor}<
${element_c},
${epilogue_vector_length},
${element_accumulator},
${element_epilogue}
>,
${epilogue},
${swizzling_functor}, // cutlass::gemm::threadblock::GemmSplitKIdentityThreadblockSwizzle<>,
${stages},
${math_operator},
Expand All @@ -175,7 +186,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation):
def emit(self, operation, no_beta_scaling=True):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -237,4 +248,12 @@ def emit(self, operation):
"align_b": str(operation.B.alignment),
}

return substitute_template(self.template, values)
template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
return substitute_template(template, values)
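
To make the two-pass substitution above concrete: the first call fills only the ${epilogue} slot (choosing the NoBetaScaling variant when requested), and the second fills the remaining operation-specific values. The small substitute helper below is a simplified stand-in written for illustration, not the project's substitute_template.

import re

def substitute(template, values):
    # Replace known ${key} placeholders; leave unknown ones untouched for a later pass.
    return re.sub(
        r"\$\{(\w+)\}",
        lambda m: str(values.get(m.group(1), m.group(0))),
        template,
    )

epilogue_no_beta_scaling = """${epilogue_functor}<
  ${element_c}, ${epilogue_vector_length},
  ${element_accumulator}, ${element_epilogue},
  cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

# Pass 1: decide which epilogue body occupies the ${epilogue} slot of the kernel template.
kernel_template = "  ${epilogue},\n  ${swizzling_functor},"
kernel_template = substitute(kernel_template, {"epilogue": epilogue_no_beta_scaling})

# Pass 2: fill the operation-specific values, as emit() does with `values`.
print(substitute(kernel_template, {
    "epilogue_functor": "cutlass::epilogue::thread::LinearCombination",
    "element_c": "cutlass::half_t",
    "epilogue_vector_length": "8",
    "element_accumulator": "float",
    "element_epilogue": "float",
    "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>",
}))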
9 changes: 5 additions & 4 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -83,15 +83,16 @@ def create_conv2d_operator(
op_entry["op"] = op
op_entry["src"] = profiler_emitter.emit(op_entry["opdef"], op.procedural_name())
op_entry["name"] = op.procedural_name()
op_entry["runtime"] = 9999999

# fused ops
for epilogue, opdef in zip(
for epilogue, opdef, no_bias_scaling in zip(
[
EpilogueFunctor.LinearCombinationBias,
EpilogueFunctor.LinearCombinationRelu,
EpilogueFunctor.LinearCombinationSigmoid,
],
["opdef_bias", "opdef_bias_relu"],
["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
[True, True, False],
):
op = Conv2dOperation(
ConvKind.Fprop,
@@ -107,7 +108,7 @@ def create_conv2d_operator(
swizzling_functor_,
)

op_entry[opdef] = kernel_emitter.emit(op)
op_entry[opdef] = kernel_emitter.emit(op, no_bias_scaling)

ret.append(op_entry)

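Read together with handle_conv2d in build.py above, each profiled op_entry now carries one kernel definition per supported epilogue, and the composite pattern name selects which one is emitted. The sketch below is an illustrative assumption of that mapping, with the actual kernel strings elided.

# Shape of one op_entry produced by create_conv2d_operator (illustrative only):
op_entry = {
    "opdef": "<plain conv2d kernel instance>",
    "opdef_bias": "<conv2d + bias, NoBetaScaling epilogue>",
    "opdef_bias_relu": "<conv2d + bias + relu, NoBetaScaling epilogue>",
    "opdef_bias_sigmoid": "<conv2d + bias + sigmoid, default epilogue with beta scaling>",
    "runtime": 9999999,
}

# handle_conv2d's if/elif chain amounts to this lookup from pattern name to entry key:
opdef_key = {
    "cutlass.conv2d": "opdef",
    "cutlass.conv2d_bias": "opdef_bias",
    "cutlass.conv2d_bias_relu": "opdef_bias_relu",
    "cutlass.conv2d_bias_sigmoid": "opdef_bias_sigmoid",
}
cutlass_op_def = op_entry[opdef_key["cutlass.conv2d_bias_relu"]]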
1 change: 0 additions & 1 deletion python/tvm/contrib/cutlass/gen_gemm.py
@@ -123,7 +123,6 @@ def create_gemm_operator(
DataTypeTag[element_c],
op.leading_dim(),
)
op_entry["runtime"] = 9999999
op_entry["tile_description"] = tile_description
op_entry["alignment"] = alignment
op_entry["data_type"] = data_type
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -148,13 +148,15 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationRelu = enum_auto()
LinearCombinationBias = enum_auto()
LinearCombinationGelu = enum_auto()
LinearCombinationSigmoid = enum_auto()


EpilogueFunctorTag = {
EpilogueFunctor.LinearCombination: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationRelu: "cutlass::epilogue::thread::LinearCombinationRelu",
EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
}


55 changes: 49 additions & 6 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -16,8 +16,9 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
from tvm.ir.transform import Sequential
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
from tvm.relay.build_module import bind_params_by_name
from ...dataflow_pattern import wildcard, is_op, is_constant


@@ -57,8 +58,25 @@ def make_batch_matmul_pattern():
return is_op("nn.batch_matmul")(wildcard(), wildcard())


def make_conv2d_pattern():
return is_op("nn.conv2d")(wildcard(), wildcard())
def make_conv2d_pattern(with_bias=False, with_act=None):
"""Create a pattern for dense op followed by activations."""
data = wildcard()
weight = wildcard()
bias = wildcard()
conv2d = is_op("nn.conv2d")(data, weight)
if with_bias:
add_or_bias_add = is_op("add") | is_op("nn.bias_add")
conv2d_out = add_or_bias_add(conv2d, bias)
else:
conv2d_out = conv2d

if with_act is not None:
if with_act == "relu":
return is_op("nn.relu")(conv2d_out)
if with_act == "sigmoid":
return is_op("sigmoid")(conv2d_out)

return conv2d_out


def check_dtype(lhs, rhs):
@@ -109,7 +127,7 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def partition_for_cutlass(mod):
def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
@@ -131,15 +149,40 @@ def partition_for_cutlass(mod):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
# TODO(masahi): Add more conv2d patterns
(
"cutlass.conv2d_bias_relu",
make_conv2d_pattern(with_bias=True, with_act="relu"),
check_conv2d,
),
(
"cutlass.conv2d_bias_sigmoid",
make_conv2d_pattern(with_bias=True, with_act="sigmoid"),
check_conv2d,
),
("cutlass.conv2d_bias", make_conv2d_pattern(with_bias=True), check_conv2d),
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]

if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
remove_bn_pass = Sequential(
[
transform.InferType(),
transform.SimplifyInference(),
transform.FoldConstant(),
transform.FoldScaleAxis(),
]
)
with PassContext(opt_level=3):
mod = remove_bn_pass(mod)

seq = Sequential(
[
transform.InferType(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"]),
transform.PartitionGraph(),
transform.PartitionGraph(bind_constants=False),
]
)

return seq(mod)
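
A minimal usage sketch of the updated entry point. Only partition_for_cutlass and its new params argument come from this file; the toy module, shapes, dtypes, and parameter values are assumptions. Passing params triggers the constant binding and BatchNorm folding added above before the graph is partitioned.

import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# Assumed toy module: conv2d + bias_add + sigmoid in NHWC/float16.
data = relay.var("data", shape=(1, 16, 16, 32), dtype="float16")
weight = relay.var("weight", shape=(32, 3, 3, 32), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")
conv = relay.nn.conv2d(
    data, weight, kernel_size=(3, 3), padding=(1, 1),
    data_layout="NHWC", kernel_layout="OHWI", out_dtype="float16",
)
out = relay.sigmoid(relay.nn.bias_add(conv, bias, axis=3))
mod = tvm.IRModule.from_expr(relay.Function(relay.analysis.free_vars(out), out))

params = {
    "weight": np.random.uniform(-1, 1, (32, 3, 3, 32)).astype("float16"),
    "bias": np.zeros((32,), dtype="float16"),
}
# Binds params as constants, folds BatchNorm/scale axes, then partitions; the fused
# subgraph should come out as a composite function tagged "cutlass.conv2d_bias_sigmoid".
partitioned = partition_for_cutlass(mod, params)
print(partitioned)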
56 changes: 51 additions & 5 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -263,6 +263,11 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";

std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
CutlassPrint(conv2d_decl, "using ElementInputB = " + attrs.at("ElementInputB") + ";\n");
@@ -307,10 +312,18 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
ICHECK(func_args.size() >= 2);
CutlassPrint(conv2d_decl, "void* ptr_a = (void*)(" + func_args[0] + "->data);\n");
CutlassPrint(conv2d_decl, "void* ptr_b = (void*)(" + func_args[1] + "->data);\n");
if (has_bias) {
ICHECK(func_args.size() >= 3);
CutlassPrint(conv2d_decl, "void* ptr_c_bias = (void*)(" + func_args[2] + "->data);\n");
}

CutlassPrint(conv2d_decl, "void* ptr_out = (void*)(out0->data);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue alpha = ElementComputeEpilogue(1);\n");
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");

if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(0);\n");
} else {
CutlassPrint(conv2d_decl, "ElementComputeEpilogue beta = ElementComputeEpilogue(1);\n");
}
CutlassPrint(conv2d_decl, "using cutlass::layout::TensorNHWC;\n");
CutlassPrint(conv2d_decl,
"TensorNHWC layout_A(TensorNHWC::packed(cutlass::make_Coord(N, H, W, C)));\n");
@@ -322,9 +335,19 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
CutlassPrint(conv2d_decl, " problem_size,\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputA*>(ptr_a), layout_A},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementInputB*>(ptr_b), layout_B},\n");
if (has_bias) {
CutlassPrint(
conv2d_decl,
" {static_cast<ElementOutput*>(ptr_c_bias), cutlass::layout::TensorNHWC::Stride(0)},\n");
} else {
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
}
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, " {static_cast<ElementOutput*>(ptr_out),layout_C},\n");
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
if (has_bias && no_bias_scaling) {
CutlassPrint(conv2d_decl, " {alpha}\n};\n");
} else {
CutlassPrint(conv2d_decl, "{alpha, beta}\n};\n");
}
CutlassPrint(conv2d_decl, "Conv2d conv2d_op;\n");

CutlassPrint(conv2d_decl, "size_t workspace_size = conv2d_op.get_workspace_size(arguments);\n");
@@ -461,6 +484,27 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
const auto* conv2d_call = GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d"});
return GenerateBody(conv2d_call, "cutlass_conv2d", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 1, {"nn.conv2d", add_or_bias_add});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_relu") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "nn.relu"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_relu", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_sigmoid") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -507,7 +551,9 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_conv2d") {
} else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
func_name == "cutlass_conv2d_bias_relu" ||
func_name == "cutlass_conv2d_bias_sigmoid") {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
