
[CUTLASS] Conv2d activation fusion, part 2: Sigmoid fp16, SiLU and HardSwish #9795

Merged: 10 commits, Dec 23, 2021
Commit: silu fusion supported
masahi committed on Dec 22, 2021
commit f23d38dec6d5bd858075289f6db2db5e4e00a08e
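This commit adds SiLU to the activations that can be fused into a CUTLASS conv2d epilogue. For reference, SiLU (also known as Swish) is the element-wise function silu(x) = x * sigmoid(x); the Relay pattern added below matches exactly that multiply-by-its-own-sigmoid shape. A minimal NumPy sketch of the math (illustrative only, not part of the PR):

    import numpy as np

    def silu(x):
        # SiLU / Swish: x * sigmoid(x), applied element-wise.
        return x * (1.0 / (1.0 + np.exp(-x)))

    # The fused kernel computes silu(conv2d(data, weight) + bias) in a single
    # pass, applying SiLU in the GEMM epilogue rather than a separate kernel.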
python/tvm/contrib/cutlass/build.py (2 additions, 0 deletions)

@@ -222,6 +222,8 @@ def handle_conv2d(
         cutlass_op_def = out["opdef_bias_relu"]
     elif op_type == "cutlass.conv2d_bias_sigmoid":
         cutlass_op_def = out["opdef_bias_sigmoid"]
+    elif op_type == "cutlass.conv2d_bias_silu":
+        cutlass_op_def = out["opdef_bias_silu"]
     else:
         raise ValueError("%s pattern is not implemented." % op_type)
python/tvm/contrib/cutlass/gen_conv2d.py (3 additions, 2 deletions)

@@ -90,9 +90,10 @@ def create_conv2d_operator(
             EpilogueFunctor.LinearCombinationBias,
             EpilogueFunctor.LinearCombinationRelu,
             EpilogueFunctor.LinearCombinationSigmoid,
+            EpilogueFunctor.LinearCombinationSilu,
         ],
-        ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
-        [True, True, False],
+        ["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid", "opdef_bias_silu"],
+        [True, True, False, False],
     ):
         op = Conv2dOperation(
             ConvKind.Fprop,
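The three parallel lists above are iterated together: each epilogue functor is paired with the key under which its generated op definition is stored, plus a boolean flag. Judging from the matching no_bias_scaling change in codegen.cc further down, the False entries for sigmoid and SiLU mark epilogues whose bias term does get beta scaling. An illustrative restatement of the pairing (a reading aid for the diff, not TVM code):

    # How the parallel lists above line up, entry by entry (sketch only):
    epilogue_table = {
        # opdef key             (epilogue functor,           bias skips beta scaling?)
        "opdef_bias":         ("LinearCombinationBias",    True),
        "opdef_bias_relu":    ("LinearCombinationRelu",    True),
        "opdef_bias_sigmoid": ("LinearCombinationSigmoid", False),
        "opdef_bias_silu":    ("LinearCombinationSilu",    False),  # added in this commit
    }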
python/tvm/contrib/cutlass/library.py (2 additions, 0 deletions)

@@ -149,6 +149,7 @@ class EpilogueFunctor(enum.Enum):
     LinearCombinationBias = enum_auto()
     LinearCombinationGelu = enum_auto()
     LinearCombinationSigmoid = enum_auto()
+    LinearCombinationSilu = enum_auto()


 EpilogueFunctorTag = {
@@ -157,6 +158,7 @@
     EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
     EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
     EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
+    EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
 }
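With both definitions in place, resolving the emitted C++ epilogue type for the new functor is a plain dictionary lookup; given the enum and tag table above:

    # Lookup of the C++ epilogue type for the new functor (uses the
    # EpilogueFunctor enum and EpilogueFunctorTag dict defined above):
    tag = EpilogueFunctorTag[EpilogueFunctor.LinearCombinationSilu]
    assert tag == "cutlass::epilogue::thread::LinearCombinationSilu"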
python/tvm/relay/op/contrib/cutlass.py (7 additions, 0 deletions)

@@ -75,6 +75,8 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
         return is_op("nn.relu")(conv2d_out)
     if with_act == "sigmoid":
         return is_op("sigmoid")(conv2d_out)
+    if with_act == "silu":
+        return is_op("multiply")(conv2d_out, is_op("sigmoid")(conv2d_out))

     return conv2d_out

@@ -149,6 +151,11 @@ def partition_for_cutlass(mod, params=None):
         dense_bias_pat,
         dense_pat,
         ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
+        (
+            "cutlass.conv2d_bias_silu",
+            make_conv2d_pattern(with_bias=True, with_act="silu"),
+            check_conv2d,
+        ),
         (
             "cutlass.conv2d_bias_relu",
             make_conv2d_pattern(with_bias=True, with_act="relu"),
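Two details worth noting: the SiLU pattern reuses the same conv2d_out pattern in both arguments of the multiply, so both operands must resolve to the same conv2d output; and the cutlass.conv2d_bias_silu entry is registered ahead of the other conv2d patterns shown, presumably so the more specific multiply(x, sigmoid(x)) shape is matched before a bare sigmoid epilogue can claim the inner sigmoid. A hedged end-to-end sketch of what partitioning does with such a graph (shapes and names are illustrative, not from the PR):

    import tvm
    from tvm import relay
    from tvm.relay.op.contrib.cutlass import partition_for_cutlass

    data = relay.var("data", shape=(1, 16, 32, 32), dtype="float16")
    weight = relay.var("weight", shape=(32, 16, 3, 3), dtype="float16")
    bias = relay.var("bias", shape=(32,), dtype="float16")

    conv = relay.nn.conv2d(data, weight, padding=(1, 1), out_dtype="float16")
    biased = relay.nn.bias_add(conv, bias)
    out = biased * relay.sigmoid(biased)  # SiLU written as x * sigmoid(x)

    mod = tvm.IRModule.from_expr(out)
    mod = partition_for_cutlass(mod)
    # The partitioned module should now contain a composite function tagged
    # "cutlass.conv2d_bias_silu", ready for the CUTLASS BYOC codegen.
    print(mod)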
src/relay/backend/contrib/cutlass/codegen.cc (15 additions, 3 deletions)

@@ -265,8 +265,10 @@ std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
                      const std::vector<std::string>& func_args) {
   bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
                   attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
-                  attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
-  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";
+                  attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid" ||
+                  attrs.at("op_type") == "cutlass.conv2d_bias_silu";
+  bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
+                         attrs.at("op_type") != "cutlass.conv2d_bias_silu";

   std::ostringstream conv2d_decl;
   CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");

@@ -505,6 +507,13 @@ class CodegenCutlass
           GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
       return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
                           Conv2dArgs(std::ref(attrs_)));
+    } else if (pattern_name == "cutlass.conv2d_bias_silu") {
+      const CallNode* current_call = callee->body.as<CallNode>();
+      std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
+      const auto* conv2d_call =
+          GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
+      return GenerateBody(conv2d_call, "cutlass_conv2d_bias_silu", GetArgumentNames(caller),
+                          Conv2dArgs(std::ref(attrs_)));
     }

     LOG(FATAL) << "Unknown composite function: " << pattern_name;

@@ -553,7 +562,8 @@ class CodegenCutlass
       ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
     } else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
                func_name == "cutlass_conv2d_bias_relu" ||
-               func_name == "cutlass_conv2d_bias_sigmoid") {
+               func_name == "cutlass_conv2d_bias_sigmoid" ||
+               func_name == "cutlass_conv2d_bias_silu") {
       ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
     }

@@ -613,6 +623,8 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
     code_stream_ << "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n";
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
     code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_gelu.h>\n";
+    code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
+    code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_silu.h>\n";

     ICHECK(ref->IsInstance<FunctionNode>());
     auto res = GenCutlassFunc(Downcast<Function>(ref));
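The new silu branch mirrors the sigmoid one; the only changes are the outer op name ("multiply" instead of "sigmoid") and the composite tag. Based on the pattern registered above, the composite body that GetRootCall walks should have the following shape (an illustrative sketch, not code from the PR):

    # Shape of a cutlass.conv2d_bias_silu composite body (illustrative):
    #
    #   multiply                           <- callee->body (depth 2)
    #   |-- add or nn.bias_add             <- args[0] (depth 1, add_or_bias_add)
    #   |   `-- nn.conv2d                  <- depth 0: the call GetRootCall returns
    #   `-- sigmoid(add or nn.bias_add)    <- args[1], the gating half of SiLU
    #
    # GetRootCall walks depth 2 along the first-argument chain, checking op
    # names against {"nn.conv2d", add_or_bias_add, "multiply"} along the way.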
tests/python/contrib/test_cutlass.py (12 additions, 1 deletion)

@@ -138,6 +138,11 @@ def get_conv2d_nchw_bias_sigmoid(d_shape, w_shape, padding, out_dtype="float16")
     return relay.sigmoid(get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype))


+def get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float16"):
+    conv_out = get_conv2d_nchw_bias(d_shape, w_shape, padding, out_dtype=out_dtype)
+    return conv_out * relay.sigmoid(conv_out)
+
+
 def profile_and_build(mod, params, sm, tmp_dir="./tmp", lib_path="compile.so"):
     mod = partition_for_cutlass(mod)
     mod, num_cutlass_partition = tune_cutlass_kernels(

@@ -443,6 +448,12 @@ def test_conv2d_fusion():
         mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
     )

+    mod_nchw = get_conv2d_nchw_bias_silu(d_shape, w_shape, padding, out_dtype="float32")
+    verify_conv2d(
+        mod_nchw, mod_nchw, d_shape, w_shape, sm=80, atol=1e-5, rtol=1e-5, run_benchmark=False
+    )
+
+
 if __name__ == "__main__":
-    pytest.main([__file__])
+    # pytest.main([__file__])
+    test_conv2d_fusion()