[CUTLASS] Support more kernels: int8, tf32, and 3xtf32 #9899

Merged: 22 commits, Jan 13, 2022
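For context on the kernels this PR adds: tf32 runs fp32 GEMM/conv on tensor cores with a 10-bit mantissa, while 3xtf32 recovers near-fp32 accuracy by splitting each fp32 operand into a tf32 "big" part plus a residual and combining three tf32 products. The snippet below is only a rough numerical sketch of that idea, not code from this PR; the helper names are made up.

```python
import numpy as np

def to_tf32(x):
    # tf32 keeps fp32's 8-bit exponent but only 10 mantissa bits; emulate the
    # truncation by zeroing the low 13 mantissa bits of the fp32 encoding.
    bits = np.float32(x).view(np.uint32) & np.uint32(0xFFFFE000)
    return bits.view(np.float32)

def mul_3xtf32(a, b):
    # Split each operand into a tf32-representable "big" part and a residual,
    # then combine three products of the parts, accumulating in fp32
    # (a rough model of what the tensor-core emulation does).
    a_big, b_big = to_tf32(a), to_tf32(b)
    a_small, b_small = np.float32(a) - a_big, np.float32(b) - b_big
    return a_big * b_big + a_big * b_small + a_small * b_big

x, y = np.float32(1.2345678), np.float32(7.654321)
err_tf32 = abs(float(to_tf32(x)) * float(to_tf32(y)) - float(x) * float(y))
err_3xtf32 = abs(float(mul_3xtf32(x, y)) - float(x) * float(y))
print(err_3xtf32 <= err_tf32)  # expect True: 3xtf32 lands much closer to plain fp32
```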
91 changes: 83 additions & 8 deletions python/tvm/contrib/cutlass/build.py
@@ -94,12 +94,25 @@ def visit_call(self, call):


def select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
batched,
profile_all,
use_multiprocessing,
):
"""Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
workloads."""
if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
out = cutlass_profiler.get_default(
op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32, batched=batched
)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
@@ -109,6 +122,9 @@ def select_gemm_kernel(
NN,
KK,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
batched=batched,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
@@ -122,15 +138,35 @@ def select_gemm_kernel(


def handle_batch_matmul(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for batch_matmul op workload."""
MM = arg0_shape[1]
KK = arg0_shape[2]
NN = arg1_shape[1]

name, cutlass_op_def = select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
True,
profile_all,
use_multiprocessing,
)

return {
@@ -147,15 +183,35 @@ def handle_batch_matmul(


def handle_dense(
cutlass_profiler, op_type, arg0_shape, arg1_shape, out_dtype, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for dense op workload."""
MM = arg0_shape[0]
KK = arg0_shape[1]
NN = arg1_shape[0]

name, cutlass_op_def = select_gemm_kernel(
cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
cutlass_profiler,
op_type,
MM,
KK,
NN,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
False,
profile_all,
use_multiprocessing,
)

assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now."
@@ -178,12 +234,15 @@ def handle_conv2d(
strides,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
if any(isinstance(s, tvm.tir.Any) for s in d_shape):
out = cutlass_profiler.get_default(op_type, out_dtype)
out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
@@ -195,6 +254,9 @@ def handle_conv2d(
strides,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
@@ -209,7 +271,9 @@ def handle_conv2d(
}


def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
def tune_cutlass_kernels(
mod, sm, use_3xtf32=True, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"
):
"""Given a module partitioned for CUTLASS offloading, profile each workload to select which
kernels to emit.
@@ -258,6 +322,8 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
new_attrs.update(func.attrs)
arg0_shape = new_attrs["arg0_shape"]
arg1_shape = new_attrs["arg1_shape"]
arg0_dtype = new_attrs["arg0_dtype"]
arg1_dtype = new_attrs["arg1_dtype"]

if "conv2d" in op_type:
new_attrs["padding"] = annotator.op_attrs.padding
@@ -273,6 +339,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
@@ -285,6 +354,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
@@ -297,6 +369,9 @@ def tune_cutlass_kernels(mod, sm, profile_all=True, use_multiprocessing=False, tmp_dir="./tmp"):
arg0_shape,
arg1_shape,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all,
use_multiprocessing,
)
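Below is a hedged usage sketch of the new top-level flag added in this file. The import path and the handling of the return value are assumptions not shown in this diff, and the parameter comments are a best-effort reading of the profiler behavior.

```python
from tvm.contrib.cutlass import tune_cutlass_kernels  # re-export path assumed

# mod: a Relay module already partitioned for CUTLASS offloading
result = tune_cutlass_kernels(
    mod,
    sm=80,                     # Ampere; fp32 (tf32/3xtf32) defaults exist for sm 80
    use_3xtf32=True,           # prefer 3xtf32 kernels for float32 inputs
    profile_all=False,         # profile only until the first applicable kernel is found
    use_multiprocessing=True,  # compile profiler executables in parallel
    tmp_dir="./tmp",
)
```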
29 changes: 22 additions & 7 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -153,8 +153,13 @@ def __init__(self, sm, cutlass_path, binary_path):
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, op_type, out_dtype):
gemm_profile_result = self.gemm_profiler.get_default(op_type, out_dtype)
def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
gemm_profile_result = self.gemm_profiler.get_default(
op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32
)
tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
@@ -165,9 +170,10 @@ def get_default(self, op_type, out_dtype):

def check_align(self, op_name, C, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
align = int(aligns[0][-1])
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(match.groups()[0])
return all([dim % align == 0 for dim in [C, K]])
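The alignment parsing above is the reason for the regex change: the old `align[1|2|4|8]` character class can only capture a single digit, so a wider alignment such as the 16-element alignment commonly used by int8 kernels would be misread as 1. A small sketch of the difference, using a made-up kernel name:

```python
import re

name = "cutlass_tensorop_s8_i8832gemm_s8_256x128_128x2_tn_align16"  # hypothetical name

# Old pattern: the character class matches exactly one digit, so "align16" reads as "align1".
old = re.findall(r"align[1|2|4|8]", name)
print(old, int(old[0][-1]))      # ['align1'] 1

# New pattern: captures the full (possibly multi-digit) alignment.
new = re.match(r".*_align([1-9]+)", name)
print(int(new.groups()[0]))      # 16
```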

def select_op(
@@ -178,6 +184,9 @@ def select_op(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=True,
use_multiprocessing=False,
):
@@ -207,9 +216,9 @@ def select_op(
return self.cache[workload]

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
op_creator=enumerate_conv2d_operators,
out_dtype, data_dtype, weight_dtype, enumerate_conv2d_operators, use_3xtf32
)

ops = list(filter(lambda op: self.check_align(op["name"], IC, OC), ops))

if profile_all:
@@ -240,6 +249,9 @@ def profile(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32=True,
profile_all=True,
use_multiprocessing=False,
):
@@ -254,6 +266,9 @@ def profile(
stride,
dilation,
out_dtype,
data_dtype,
weight_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)
66 changes: 53 additions & 13 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -125,13 +125,18 @@ def enumerate_gemm_operators(
# TODO(masahi): A sensible way to pick reasonable default kernels
DEFAULT_KERNELS = {
75: {
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
},
# align1 variants do not seem to be available for sm80
80: {
"float16": "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
"float32": "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
("float16", "float16"): "cutlass_tensorop_h1688gemm_128x64_32x2_tn_align1",
("float16", "float32"): "cutlass_tensorop_s1688gemm_f16_64x64_32x2_tn_align1",
# two kernels for tf32 and 3xtf32
("float32", "float32"): (
"cutlass_tensorop_s1688gemm_128x64_32x3_tn_align1",
"cutlass_tensorop_s1688gemm_64x64_16x3_tn_align1",
),
},
}
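The table is now keyed by (input dtype, output dtype), and the float32 entry holds a pair so the default can follow the use_3xtf32 flag. A minimal sketch of how get_default (further down in this diff) indexes the table defined just above, restated here because the `if not use_3xtf32` form in the actual code is easy to misread:

```python
def default_fp32_gemm_kernel(use_3xtf32: bool) -> str:
    # Index 0 is the tf32 kernel, index 1 the 3xtf32 kernel (see the comment above).
    tf32_kernel, three_xtf32_kernel = DEFAULT_KERNELS[80][("float32", "float32")]
    return three_xtf32_kernel if use_3xtf32 else tf32_kernel
```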

@@ -147,21 +152,31 @@ def __init__(self, sm, cutlass_path, binary_path):

def check_align(self, op_name, M, N, K):
"""Filter out kernels that cannot be supported."""
aligns = re.findall(r"align[1|2|4|8]", op_name)
assert len(aligns) == 1
match = re.match(".*_align([1-9]+)", op_name)
assert match is not None and len(match.groups()) == 1
# The same alignment is used for all axes
align = int(aligns[0][-1])
align = int(match.groups()[0])
# TODO(masahi): CUTLASS alignment check on gemm kernels is too restrictive.
# See https://github.com/NVIDIA/cutlass/issues/362.
# When the above issue is resolved, we can remove the alignment check on M below.
return all([dim % align == 0 for dim in [M, N, K]])

def get_default(self, op_type, out_dtype, batched=False):
def get_default(
self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32=True, batched=False
):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
ops = GENERATOR_FUNC_TABLE[self.sm](out_dtype, op_creator=enumerate_gemm_operators)
default_kernel_name = DEFAULT_KERNELS[self.sm][out_dtype]
ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype, arg0_dtype, arg1_dtype, enumerate_gemm_operators, use_3xtf32
)
default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]

if arg0_dtype == "float32":
default_kernel_name = (
default_kernel_name[0] if not use_3xtf32 else default_kernel_name[1]
)

filtered = list(filter(lambda op: op["name"] == default_kernel_name, ops))
assert len(filtered) == 1
op = filtered[0]
@@ -176,7 +191,18 @@ def get_default(self, op_type, out_dtype, batched=False):
op.update({"name": name, "opdef": opdef})
return op

def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):
def select_op(
self,
M,
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=True,
use_multiprocessing=False,
):
"""
Profile and select the best kernel from candidate kernels.
See the documentation for the profile method below.
@@ -187,7 +213,10 @@ def select_op(self, M, N, K, out_dtype, profile_all=True, use_multiprocessing=False):

ops = GENERATOR_FUNC_TABLE[self.sm](
out_dtype,
op_creator=enumerate_gemm_operators,
arg0_dtype,
arg1_dtype,
enumerate_gemm_operators,
use_3xtf32=use_3xtf32,
)
ops = list(filter(lambda op: self.check_align(op["name"], M, N, K), ops))

@@ -212,6 +241,9 @@ def profile(
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32=True,
profile_all=True,
use_multiprocessing=False,
batched=False,
@@ -221,7 +253,15 @@ def profile(
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
op = self.select_op(
M, N, K, out_dtype, profile_all=profile_all, use_multiprocessing=use_multiprocessing
M,
N,
K,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
profile_all=profile_all,
use_multiprocessing=use_multiprocessing,
)

name, opdef = create_gemm_operator_with_epilogue(
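Finally, a hedged sketch of driving the GEMM profiler directly with the extended signature. The class name CutlassGemmProfiler and its import path are assumed (the class definition is not shown in this excerpt), the CUTLASS path is a placeholder, and the returned kernel definition is left unpacked because its structure is not visible in this diff:

```python
from tvm.contrib.cutlass.gen_gemm import CutlassGemmProfiler  # class name assumed

profiler = CutlassGemmProfiler(sm=80, cutlass_path="/path/to/cutlass", binary_path="./tmp")

# Profile a 1024x1024x1024 fp32 GEMM in 3xtf32 mode.
result = profiler.profile(
    1024, 1024, 1024,        # M, N, K
    out_dtype="float32",
    arg0_dtype="float32",
    arg1_dtype="float32",
    use_3xtf32=True,
    profile_all=False,
    use_multiprocessing=False,
)
```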