diff --git a/python/tvm/relay/op/contrib/cutlass.py b/python/tvm/relay/op/contrib/cutlass.py
index 4ae529e18dc2c..80c8ebffeb0f9 100644
--- a/python/tvm/relay/op/contrib/cutlass.py
+++ b/python/tvm/relay/op/contrib/cutlass.py
@@ -56,19 +56,51 @@ def make_batch_matmul_pattern():
 
 
 def make_conv2d_pattern():
-    # TODO(masahi): Check layout and alignment
     return is_op("nn.conv2d")(wildcard(), wildcard())
 
 
+def check_dtype(lhs, rhs):
+    """Check if dtypes in the given workload are supported by CUTLASS."""
+    return lhs.dtype == rhs.dtype and lhs.dtype == "float16" and rhs.dtype == "float16"
+
+
+def check_gemm(call):
+    """Check if the given dense workload can be offloaded to CUTLASS."""
+    lhs = call.args[0].checked_type
+    rhs = call.args[1].checked_type
+    return check_dtype(lhs, rhs)
+
+
+def check_batch_matmul(call):
+    """Check if the given batch_matmul workload can be offloaded to CUTLASS."""
+    transpose_a = call.attrs.transpose_a
+    transpose_b = call.attrs.transpose_b
+    return check_gemm(call) and transpose_a == False and transpose_b == True
+
+
+def check_conv2d(call):
+    """Check if the given conv2d workload can be offloaded to CUTLASS."""
+    data_layout = call.attrs.data_layout
+    kernel_layout = call.attrs.kernel_layout
+    data = call.args[0].checked_type
+    weight = call.args[1].checked_type
+    return data_layout == "NHWC" and kernel_layout == "OHWI" and check_dtype(data, weight)
+
+
 def partition_for_cutlass(mod):
     """Partition the input module into CUTLASS-supported subgraphs."""
-    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None))
-    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None))
-    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"))
-    dense_bias_gelu_fp16_pat = ("cutlass.dense_bias_gelu_fp16", make_gemm_pattern(True, "gelu"))
+    dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
+    dense_bias_pat = ("cutlass.dense_bias", make_gemm_pattern(True, None), check_gemm)
+    dense_bias_relu_pat = ("cutlass.dense_bias_relu", make_gemm_pattern(True, "relu"), check_gemm)
+    dense_bias_gelu_fp16_pat = (
+        "cutlass.dense_bias_gelu_fp16",
+        make_gemm_pattern(True, "gelu"),
+        check_gemm,
+    )
     dense_bias_gelu_fp32_pat = (
         "cutlass.dense_bias_gelu_fp32",
         make_gemm_pattern(True, "gelu", out_dtype="float32"),
+        check_gemm,
     )
     cutlass_patterns = [
         dense_bias_gelu_fp16_pat,
@@ -76,9 +108,9 @@ def partition_for_cutlass(mod):
         dense_bias_relu_pat,
         dense_bias_pat,
         dense_pat,
-        ("cutlass.batch_matmul", make_batch_matmul_pattern()),
+        ("cutlass.batch_matmul", make_batch_matmul_pattern(), check_batch_matmul),
         # TODO(masahi): Add more conv2d patterns
-        ("cutlass.conv2d", make_conv2d_pattern()),
+        ("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
     ]
     mod = transform.MergeComposite(cutlass_patterns)(mod)
     mod = transform.AnnotateTarget(["cutlass"])(mod)