[CUTLASS] Initial conv2d support (apache#9595)
* Add initial conv generator

* added conv2d pattern

* profile by gemm profiler

* remove conv2d profiler for now

* remove unused code

* add default

* minor fix, profiling working

* start codegen

* generated code compiled

* fixed layout initialization

* matched with autotvm tensorcore result

* test refactor

* minor cleanup

* remove iteration algo "Analytic"

* add test for dynamic batch conv2d

* pass dl tensor as output too

* support conv2d dynamic shape in codegen

* test working

* lint

* simplify codegen

* fix weird formatting

* typo fix

* check if cutlass is enabled in the test

* simplify gen_conv2d.py
masahi authored and yangulei committed Jan 11, 2022
1 parent 89acc7b commit 19a7256
Showing 10 changed files with 776 additions and 25 deletions.
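
For context, a minimal usage sketch of the conv2d offload path this commit introduces. Only tune_cutlass_kernels appears in the diff below; partition_for_cutlass, build_cutlass_kernels, the sm=80 target, and the NHWC/OHWI fp16 workload are assumptions about the surrounding TVM CUTLASS BYOC flow, not something this commit itself guarantees.

```python
# Hypothetical sketch, not part of this commit: offload a single conv2d to CUTLASS.
# Assumes partition_for_cutlass and build_cutlass_kernels exist alongside the
# tune_cutlass_kernels function changed below, and that CUDA + CUTLASS are available.
import numpy as np
import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels, build_cutlass_kernels

# fp16 NHWC conv2d with OHWI weights, the layout targeted by tensor-core CUTLASS kernels.
data = relay.var("data", shape=(8, 32, 32, 16), dtype="float16")
weight = relay.var("weight", shape=(32, 3, 3, 16), dtype="float16")
conv = relay.nn.conv2d(
    data,
    weight,
    kernel_size=(3, 3),
    padding=(1, 1),
    data_layout="NHWC",
    kernel_layout="OHWI",
    out_dtype="float16",
)
mod = tvm.IRModule.from_expr(conv)
weight_np = np.random.uniform(-1, 1, (32, 3, 3, 16)).astype("float16")

mod = partition_for_cutlass(mod)                        # wrap conv2d in a "cutlass" composite function
mod, num_partition = tune_cutlass_kernels(mod, sm=80)   # profile and attach the chosen kernel as attrs
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target="cuda", params={"weight": weight_np})
lib = build_cutlass_kernels(lib, sm=80)                 # compile the generated CUTLASS source
```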
90 changes: 77 additions & 13 deletions python/tvm/contrib/cutlass/build.py
@@ -23,6 +23,7 @@
from tvm import runtime, relay
from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
from .gen_gemm import CutlassGemmProfiler
from .gen_conv2d import CutlassConv2DProfiler

logger = logging.getLogger("cutlass")

@@ -65,7 +66,7 @@ def _get_cutlass_compile_options(sm, threads):
return kwargs


class GemmAnnotator(tvm.relay.ExprVisitor):
class OpAnnotator(tvm.relay.ExprVisitor):
"""Annotates partitioned functions with shape and dtype information."""

def __init__(self):
@@ -81,6 +82,10 @@ def visit_call(self, call):
self.signature["arg%d_dtype" % i] = arg.checked_type.dtype
self.signature["ret_shape"] = op.ret_type.shape
self.signature["ret_dtype"] = op.ret_type.dtype
self.visit(op.body)

if str(op) == "nn.conv2d":
self.op_attrs = call.attrs


def select_gemm_kernel(
@@ -125,13 +130,18 @@ def handle_batch_matmul(
else:
raise ValueError("%s pattern is not implemented." % op_type)

assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."

return {
"batch": arg0_shape[0],
"batch_stride_A": arg0_shape[1] * arg0_shape[2],
"batch_stride_B": arg1_shape[1] * arg1_shape[2],
"batch_stride_C": arg0_shape[1] * arg1_shape[1],
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
"lda": "K",
"ldb": "K",
"ldc": "N",
}


Expand All @@ -158,6 +168,50 @@ def handle_dense(
else:
raise ValueError("%s pattern is not implemented." % op_type)

assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."

return {
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
"lda": "K",
"ldb": "K",
"ldc": "N",
}


def handle_conv2d(
cutlass_profiler,
op_type,
d_shape,
w_shape,
out_shape,
out_dtype,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
if any(isinstance(s, tvm.tir.Any) for s in d_shape):
out = cutlass_profiler.get_default(out_dtype)
logger.info("Picked the default kernel %s", out["name"])
else:
out = cutlass_profiler.profile(
d_shape,
w_shape,
out_shape,
out_dtype,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
if profile_all:
logger.info("The best kernel is %s", out["name"])
else:
logger.info("Picked the first kernel found %s", out["name"])

if op_type == "cutlass.conv2d":
cutlass_op_def = out["opdef"]
else:
raise ValueError("%s pattern is not implemented." % op_type)

return {
"cutlass_op_def": cutlass_op_def,
"cutlass_op_name": out["name"],
@@ -195,12 +249,13 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
num_cutlass_partition : int
The number of partitioned functions created for CUTLASS.
"""
cutlass_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
gemm_profiler = CutlassGemmProfiler(sm, _get_cutlass_path(), tmp_dir)
conv2d_profiler = CutlassConv2DProfiler(sm, _get_cutlass_path(), tmp_dir)
num_cutlass_partition = 0
for var in mod.get_global_vars():
fun_name = var.name_hint
func = mod[fun_name]
annotator = GemmAnnotator()
annotator = OpAnnotator()
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
@@ -213,10 +268,26 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]

if "batch_matmul" in op_type:
if "conv2d" in op_type:
new_attrs["padding"] = annotator.op_attrs.padding
new_attrs["strides"] = annotator.op_attrs.strides
new_attrs["dilation"] = annotator.op_attrs.dilation
new_attrs.update(
handle_conv2d(
conv2d_profiler,
op_type,
arg0_shape,
arg1_shape,
annotator.signature["ret_shape"],
out_dtype,
profile_all,
use_multiprocessing,
)
)
elif "batch_matmul" in op_type:
new_attrs.update(
handle_batch_matmul(
cutlass_profiler,
gemm_profiler,
op_type,
arg0_shape,
arg1_shape,
@@ -228,7 +299,7 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
elif "dense" in op_type:
new_attrs.update(
handle_dense(
cutlass_profiler,
gemm_profiler,
op_type,
arg0_shape,
arg1_shape,
@@ -240,13 +311,6 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, t
else:
raise ValueError("%s unsupported composite" % op_type)

if new_attrs["cutlass_op_name"].find("_tn_align") > 0:
new_attrs["lda"] = "K"
new_attrs["ldb"] = "K"
new_attrs["ldc"] = "N"
else:
raise ValueError("%s unsupported operation" % new_attrs["cutlass_op_name"])

new_attrs = tvm.ir.make_node("DictAttrs", **new_attrs)
new_func = relay.Function(
func.params,
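To make the new control flow concrete: based on OpAnnotator, the conv2d branch, and handle_conv2d above, the attribute dictionary attached to a "cutlass.conv2d" partition would look roughly like the sketch below. All values are illustrative placeholders; only the keys are implied by this diff.

```python
# Illustrative only: rough shape of new_attrs for a "cutlass.conv2d" partition.
# Shapes, dtypes, and attr values are placeholders, not real profiler output.
new_attrs = {
    "op_type": "cutlass.conv2d",
    "arg0_shape": [8, 32, 32, 16],   # NHWC activation shape collected by OpAnnotator
    "arg0_dtype": "float16",
    "arg1_shape": [32, 3, 3, 16],    # OHWI weight shape
    "arg1_dtype": "float16",
    "ret_shape": [8, 32, 32, 32],
    "ret_dtype": "float16",
    "padding": [1, 1, 1, 1],         # copied from the nn.conv2d op attrs
    "strides": [1, 1],
    "dilation": [1, 1],
    "cutlass_op_def": "<generated CUTLASS Conv2dOperation definition>",
    "cutlass_op_name": "<name of the selected CUTLASS kernel>",
}
```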
