
[CUTLASS] Conv2d activation fusion, part 2: Sigmoid fp16, SiLU and HardSwish #9795

Merged (10 commits) on Dec 23, 2021
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 163 files
30 changes: 25 additions & 5 deletions python/tvm/contrib/cutlass/build.py
@@ -40,7 +40,7 @@ def _get_cutlass_path():
return cutlass_path


def _get_cutlass_compile_options(sm, threads):
def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
cutlass_root = _get_cutlass_path()
cutlass_include = os.path.join(cutlass_root, "include")
cutlass_util_include = os.path.join(cutlass_root, "tools/util/include")
@@ -58,6 +58,8 @@ def _get_cutlass_compile_options(sm, threads):
"-I" + cutlass_include,
"-I" + cutlass_util_include,
]
if use_fast_math:
kwargs["options"].append("-DCUTLASS_USE_TANH_FOR_SIGMOID")
cuda_path = find_cuda_path()
cuda_ver = get_cuda_version(cuda_path)
if cuda_ver >= 11.2:
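For context, a minimal sketch of how the new flag shows up in the generated nvcc options (the sm value is just an illustrative target; _get_cutlass_compile_options is an internal helper and needs a local CUDA toolkit):

```python
# Hedged sketch: with use_fast_math=True the define below is expected to appear
# among the usual include-path and architecture options returned by this helper.
kwargs = _get_cutlass_compile_options(sm=80, threads=-1, use_fast_math=True)
assert "-DCUTLASS_USE_TANH_FOR_SIGMOID" in kwargs["options"]
# The define presumably makes CUTLASS evaluate sigmoid via a fast tanh
# (sigmoid(x) == 0.5 * (1 + tanh(x / 2))), i.e. faster but slightly less
# accurate, matching the use_fast_math docstring added further down.
```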
@@ -222,6 +224,10 @@ def handle_conv2d(
cutlass_op_def = out["opdef_bias_relu"]
elif op_type == "cutlass.conv2d_bias_sigmoid":
cutlass_op_def = out["opdef_bias_sigmoid"]
elif op_type == "cutlass.conv2d_bias_silu":
cutlass_op_def = out["opdef_bias_silu"]
elif op_type == "cutlass.conv2d_bias_hardswish":
cutlass_op_def = out["opdef_bias_hardswish"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

@@ -339,7 +345,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
return mod, num_cutlass_partition


def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1):
def build_cutlass_kernels(
lib, sm, tmp_dir="./tmp", lib_path="compile.so", threads=-1, use_fast_math=False
):
"""Compile CUTLASS kernels in lib and return the runtime module ready to run.
Parameters
@@ -361,18 +369,27 @@ def build_cutlass_kernels(lib, sm, tmp_dir="./tmp", lib_path="compile.so", threa
The number of threads to use for compiling generated kernels. Only available for
CUDA 11.2 or later. Use all physical cores by default.
use_fast_math : bool, optional
Whether or not to use faster but less accurate math intrinsics.
Returns
-------
updated_lib : runtime.Module
The updated module with compiled cutlass kernels.
"""
kwargs = _get_cutlass_compile_options(sm, threads)
kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
return runtime.load_module(lib_path)
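A hedged end-to-end sketch of how the new parameter is intended to be used (mod, params, and the sm value are placeholders; the flow assumes a Relay module already partitioned for CUTLASS):

```python
# Sketch only: partition for CUTLASS, profile kernels, build, then compile the
# generated CUTLASS kernels with the fast sigmoid approximation enabled.
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

# `mod` and `params` are an existing Relay module and its weights (placeholders here).
mod = partition_for_cutlass(mod)  # offloads the conv2d + bias + activation patterns
mod, num_cutlass_partition = tune_cutlass_kernels(mod, sm=80)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda", params=params)
rt_mod = build_cutlass_kernels(lib, sm=80, use_fast_math=True)
```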


def build_cutlass_kernels_vm(
vm_exec, sm, tmp_dir="./tmp", lib_path="compile.so", vmcode_path="vmcode.ro", threads=-1
vm_exec,
sm,
tmp_dir="./tmp",
lib_path="compile.so",
vmcode_path="vmcode.ro",
threads=-1,
use_fast_math=False,
):
"""Compile CUTLASS kernels in vm_exec and return a VM executable ready to run.
@@ -398,13 +415,16 @@ def build_cutlass_kernels_vm(
The number of threads to use for compiling generated kernels. Only available for
CUDA 11.2 or later. Use all physical cores by default.
use_fast_math : bool, optional
Whether or not to use faster but less accurate math intrinsics.
Returns
-------
updated_vm_exec: vm.Executable
The updated executable with compiled cutlass kernels.
"""
code, lib = vm_exec.save()
kwargs = _get_cutlass_compile_options(sm, threads)
kwargs = _get_cutlass_compile_options(sm, threads, use_fast_math)
lib_path = os.path.join(tmp_dir, lib_path)
vmcode_path = os.path.join(tmp_dir, vmcode_path)
lib.export_library(lib_path, workspace_dir=tmp_dir, **kwargs)
12 changes: 10 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -90,9 +90,17 @@ def create_conv2d_operator(
EpilogueFunctor.LinearCombinationBias,
EpilogueFunctor.LinearCombinationRelu,
EpilogueFunctor.LinearCombinationSigmoid,
EpilogueFunctor.LinearCombinationSilu,
EpilogueFunctor.LinearCombinationHardSwish,
],
["opdef_bias", "opdef_bias_relu", "opdef_bias_sigmoid"],
[True, True, False],
[
"opdef_bias",
"opdef_bias_relu",
"opdef_bias_sigmoid",
"opdef_bias_silu",
"opdef_bias_hardswish",
],
[True, True, False, False, False],
):
op = Conv2dOperation(
ConvKind.Fprop,
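The three parallel lists above are consumed in lockstep: each epilogue functor is paired with the dictionary key under which its generated kernel definition is stored, plus a per-epilogue boolean that, judging by the no_bias_scaling logic in codegen.cc further down, distinguishes the epilogues whose bias needs no extra scaling (bias, relu) from sigmoid, silu, and hardswish. A rough illustration of the pairing (not the actual generator code):

```python
# Illustrative only; the real generator builds a Conv2dOperation per entry and
# renders a CUTLASS kernel definition string for each key.
epilogue_variants = {
    "opdef_bias":           ("LinearCombinationBias",      True),
    "opdef_bias_relu":      ("LinearCombinationRelu",      True),
    "opdef_bias_sigmoid":   ("LinearCombinationSigmoid",   False),
    "opdef_bias_silu":      ("LinearCombinationSilu",      False),
    "opdef_bias_hardswish": ("LinearCombinationHardSwish", False),
}
for key, (epilogue, no_bias_scaling) in epilogue_variants.items():
    print(f"{key}: epilogue={epilogue}, no_bias_scaling={no_bias_scaling}")
```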
6 changes: 5 additions & 1 deletion python/tvm/contrib/cutlass/library.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name,line-too-long
"""Various type definitions to help instantiate CUTLASS kernels."""
import re
import enum
@@ -149,6 +149,8 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationBias = enum_auto()
LinearCombinationGelu = enum_auto()
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()


EpilogueFunctorTag = {
@@ -157,6 +159,8 @@ class EpilogueFunctor(enum.Enum):
EpilogueFunctor.LinearCombinationBias: "cutlass::epilogue::thread::LinearCombination",
EpilogueFunctor.LinearCombinationGelu: "cutlass::epilogue::thread::LinearCombinationGELU",
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
}
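These tag strings are presumably what gets spliced into the emitted kernel definitions as the epilogue template class; a hedged sketch of that substitution (the EpilogueOp alias and its template arguments here are illustrative, not the real emitted text):

```python
# Sketch: look up the tag for the new SiLU epilogue and format it the way a
# kernel emitter might.
tag = EpilogueFunctorTag[EpilogueFunctor.LinearCombinationSilu]
epilogue_op = f"using EpilogueOp = {tag}<ElementOutput, 8, ElementAccumulator, ElementCompute>;"
print(epilogue_op)
# -> using EpilogueOp = cutlass::epilogue::thread::LinearCombinationSilu<...>;
```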


7 changes: 6 additions & 1 deletion python/tvm/relay/frontend/pytorch.py
@@ -1735,14 +1735,19 @@ def pad(inputs, input_types):
paddings = [paddings[i : i + 2] for i in range(0, len(paddings), 2)]

const_paddings = []
non_zero_found = False
for pad in paddings:
const_paddings.append([])
for p in pad:
if not isinstance(p, int):
p = int(_infer_value(p, {}).numpy())
const_paddings[-1].append(p)
if p != 0:
non_zero_found = True

if mode == "constant":
if not non_zero_found:
return data
Comment from the PR author (Member):
This is a minor optimization but it non-trivially helped performance on the DETR model. @comaniac

Reply from a Contributor:
Hmm interesting. I didn't notice that we may have pad ops that actually pad nothing.

elif mode == "constant":
return _op.nn.pad(data, const_paddings, pad_value=inputs[2], pad_mode=mode)
else:
return _op.nn.pad(data, const_paddings, pad_mode=mode)
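A tiny worked example of the early return added above: when every requested pad width is zero and the mode is "constant", the pad is a no-op, so the frontend now returns the input expression unchanged instead of emitting nn.pad (the padding values below are hypothetical):

```python
# Standalone illustration of the zero-padding check (mirrors the loop above).
paddings = [[0, 0], [0, 0], [0, 0], [0, 0]]  # e.g. an all-zero F.pad from PyTorch
non_zero_found = any(p != 0 for pad in paddings for p in pad)
if not non_zero_found:
    print("constant pad elided: return `data` unchanged")
```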
21 changes: 20 additions & 1 deletion python/tvm/relay/op/contrib/cutlass.py
@@ -75,6 +75,15 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return is_op("nn.relu")(conv2d_out)
if with_act == "sigmoid":
return is_op("sigmoid")(conv2d_out)
if with_act == "silu":
return is_op("multiply")(conv2d_out, is_op("sigmoid")(conv2d_out))
if with_act == "hardswish":
rhs = is_op("divide")(
is_op("clip")(is_op("add")(conv2d_out, is_constant())), is_constant()
)
return is_op("multiply")(conv2d_out, rhs)

raise ValueError("Unknown activation %s." % with_act)

return conv2d_out
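For concreteness, a hedged sketch of the Relay expressions the new "silu" and "hardswish" branches are meant to match (shapes, dtypes, and the 3/6 constants are illustrative; the actual constants come from the frontend lowering):

```python
from tvm import relay

data = relay.var("data", shape=(1, 32, 56, 56), dtype="float16")
weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float16")
bias = relay.var("bias", shape=(32,), dtype="float16")
conv_out = relay.nn.bias_add(relay.nn.conv2d(data, weight, padding=(1, 1)), bias)

# SiLU / swish: x * sigmoid(x), i.e. multiply(conv_out, sigmoid(conv_out)).
silu = conv_out * relay.sigmoid(conv_out)

# HardSwish as typically lowered: x * clip(x + 3, 0, 6) / 6, i.e.
# multiply(conv_out, divide(clip(add(conv_out, const)), const)).
three = relay.const(3.0, "float16")
six = relay.const(6.0, "float16")
hardswish = conv_out * (relay.clip(conv_out + three, 0.0, 6.0) / six)
```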

@@ -149,6 +158,16 @@ def partition_for_cutlass(mod, params=None):
dense_bias_pat,
dense_pat,
("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
(
"cutlass.conv2d_bias_hardswish",
make_conv2d_pattern(with_bias=True, with_act="hardswish"),
check_conv2d,
),
(
"cutlass.conv2d_bias_silu",
make_conv2d_pattern(with_bias=True, with_act="silu"),
check_conv2d,
),
(
"cutlass.conv2d_bias_relu",
make_conv2d_pattern(with_bias=True, with_act="relu"),
@@ -180,7 +199,7 @@ def partition_for_cutlass(mod, params=None):
[
transform.InferType(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"]),
transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
transform.PartitionGraph(bind_constants=False),
]
)
38 changes: 24 additions & 14 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -173,13 +173,9 @@ void AppendGemmExecute(std::ostringstream& gemm_decl, const std::string& kernel)

std::string DenseOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = false;
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool is_gelu =
attrs.at("op_type").find("cutlass.dense_bias_gelu") != std::string::npos; // fp32 or fp16
if (attrs.at("op_type") == "cutlass.dense_bias" ||
attrs.at("op_type") == "cutlass.dense_bias_relu" || is_gelu) {
has_bias = true;
}
std::ostringstream gemm_decl;
AppendPrologue(gemm_decl, attrs, func_args, "Gemm", has_bias, is_gelu, 0, 0, 1);

@@ -263,10 +259,10 @@ Str2StrMap Conv2dArgs(const Map<String, ObjectRef>& attrs) {

std::string Conv2dOp(std::string id, const Str2StrMap& attrs,
const std::vector<std::string>& func_args) {
bool has_bias = attrs.at("op_type") == "cutlass.conv2d_bias" ||
attrs.at("op_type") == "cutlass.conv2d_bias_relu" ||
attrs.at("op_type") == "cutlass.conv2d_bias_sigmoid";
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid";
bool has_bias = attrs.at("op_type").find("bias") != std::string::npos;
bool no_bias_scaling = attrs.at("op_type") != "cutlass.conv2d_bias_sigmoid" &&
attrs.at("op_type") != "cutlass.conv2d_bias_silu" &&
attrs.at("op_type") != "cutlass.conv2d_bias_hardswish";

std::ostringstream conv2d_decl;
CutlassPrint(conv2d_decl, "using ElementInputA = " + attrs.at("ElementInputA") + ";\n");
@@ -505,6 +501,20 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "sigmoid"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_sigmoid", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_silu") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_silu", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
} else if (pattern_name == "cutlass.conv2d_bias_hardswish") {
const CallNode* current_call = callee->body.as<CallNode>();
std::string add_or_bias_add = current_call->args[0].as<CallNode>()->op.as<OpNode>()->name;
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 2, {"nn.conv2d", add_or_bias_add, "multiply"});
return GenerateBody(conv2d_call, "cutlass_conv2d_bias_hardswish", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_)));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;
@@ -546,14 +556,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
ret.outputs.push_back(output);
}
decl_stream << ");";
if (func_name == "cutlass_dense" || func_name == "cutlass_dense_bias" ||
func_name == "cutlass_dense_bias_relu" || func_name == "cutlass_dense_bias_gelu") {
if (func_name.find("dense") != std::string::npos) {
ret.decl = DenseOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_batch_matmul") {
ret.decl = BatchMatmulOp(ext_func_id_, attribute_args, func_args);
} else if (func_name == "cutlass_conv2d" || func_name == "cutlass_conv2d_bias" ||
func_name == "cutlass_conv2d_bias_relu" ||
func_name == "cutlass_conv2d_bias_sigmoid") {
} else if (func_name.find("conv2d") != std::string::npos) {
ret.decl = Conv2dOp(ext_func_id_, attribute_args, func_args);
}
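With this change the codegen dispatches on substrings of the composite function name instead of enumerating every fused variant, so adding a new activation fusion no longer requires touching this switch. A rough Python analogue of the resulting dispatch (illustrative only):

```python
# Illustrative analogue of the substring-based dispatch above.
def select_kernel_emitter(func_name: str) -> str:
    if "dense" in func_name:
        return "DenseOp"
    if func_name == "cutlass_batch_matmul":
        return "BatchMatmulOp"
    if "conv2d" in func_name:
        return "Conv2dOp"
    raise ValueError(f"Unknown CUTLASS function: {func_name}")

assert select_kernel_emitter("cutlass_conv2d_bias_hardswish") == "Conv2dOp"
```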

@@ -613,6 +620,9 @@ class CutlassModuleCodegen : public CSourceModuleCodegenBase {
code_stream_ << "#include <cutlass/conv/device/implicit_gemm_convolution.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_bias_relu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_gelu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_sigmoid.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_silu.h>\n";
code_stream_ << "#include <cutlass/epilogue/thread/linear_combination_hardswish.h>\n";

ICHECK(ref->IsInstance<FunctionNode>());
auto res = GenCutlassFunc(Downcast<Function>(ref));
3 changes: 3 additions & 0 deletions src/relay/op/dyn/tensor/transform.cc
@@ -467,6 +467,9 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
int64_t num_axis = dshape.size();

const auto* begin = types[1].as<TensorTypeNode>();
if (begin == nullptr) {
return false;
}
Comment from the PR author (Member):
This and the change below in src/relay/op/tensor/transform.cc are the fix for the type inference issue mentioned in the "Known issues" section of #9746.

No test is added because it is hard to reproduce on a simple test case and the change is trivial.

ICHECK(begin);

// calculate output shape
5 changes: 5 additions & 0 deletions src/relay/op/tensor/transform.cc
@@ -289,6 +289,11 @@ bool StackRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
<< "cast: expect input type to be TupleType but get " << types[0];
return false;
}
for (auto field : tensor_tuple->fields) {
if (field.as<IncompleteTypeNode>()) {
return false;
}
}
const auto* param = attrs.as<StackAttrs>();
const auto& first = Downcast<TensorType>(tensor_tuple->fields[0]);
const int ndim = static_cast<int>(first->shape.size());