[CUTLASS] Conv2d dgrad (apache#10110)
* add conv2d transpose nhwc cudnn test

* support conv2d transpose nhwc direct offload to cudnn

* add cutlass dgrad support

* remove unused arg

* allow target none

* fix beta initialization condition

* disable dynamic dense fp16 test since it fails on cuda 11.6
masahi authored and ylc committed Feb 16, 2022
1 parent c4f343d commit b000c36
Showing 14 changed files with 474 additions and 115 deletions.
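
A minimal end-to-end sketch of the feature this commit enables: offloading an NHWC conv2d_transpose to a CUTLASS dgrad kernel. The Relay entry points exist in TVM of this period, but their exact signatures and the IHWO kernel layout are assumptions, not part of this diff:

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels

# Gradient-shaped input and an NHWC transpose convolution (layout assumed).
dy = relay.var("dy", shape=(8, 16, 16, 256), dtype="float16")
w = relay.var("w", shape=(256, 3, 3, 128), dtype="float16")
out = relay.nn.conv2d_transpose(
    dy, w, strides=(2, 2), padding=(1, 1), channels=128, kernel_size=(3, 3),
    data_layout="NHWC", kernel_layout="IHWO", out_dtype="float16",
)
mod = tvm.IRModule.from_expr(relay.Function([dy, w], out))
mod = partition_for_cutlass(mod)  # carve out cutlass.conv2d_transpose regions
mod, num_partitions = tune_cutlass_kernels(mod, sm=80)  # profile, pick a dgrad kernel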
31 changes: 27 additions & 4 deletions python/tvm/contrib/cutlass/build.py
@@ -24,6 +24,7 @@
from tvm.contrib.nvcc import find_cuda_path, get_cuda_version
from .gen_gemm import CutlassGemmProfiler
from .gen_conv2d import CutlassConv2DProfiler
from .library import ConvKind

logger = logging.getLogger("cutlass")

@@ -86,7 +87,7 @@ def visit_call(self, call):
self.signature["ret_dtype"] = op.ret_type.dtype
self.visit(op.body)

if str(op) == "nn.conv2d":
if str(op) in ["nn.conv2d", "nn.conv2d_transpose", "nn.conv2d_backward_weight"]:
self.op_attrs = call.attrs

for arg in call.args:
@@ -242,8 +243,17 @@ def handle_conv2d(
use_multiprocessing,
):
"""Profile and select a kernel for conv2d op workload."""
if "conv2d_transpose" in op_type:
conv_kind = ConvKind.Dgrad
elif "backward_weight" in op_type:
conv_kind = ConvKind.Wgrad
else:
conv_kind = ConvKind.Fprop

if any(isinstance(s, tvm.tir.Any) for s in d_shape):
out = cutlass_profiler.get_default(op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32)
out = cutlass_profiler.get_default(
op_type, out_dtype, data_dtype, weight_dtype, use_3xtf32, conv_kind, strides
)
name, cutlass_op_def = out["name"], out["opdef"]
logger.info("Picked the default kernel %s", name)
else:
@@ -258,6 +268,7 @@
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
@@ -329,6 +340,7 @@ def tune_cutlass_kernels(
if "cutlass" in fun_name:
num_cutlass_partition += 1
annotator.visit(func)
out_shape = annotator.signature["ret_shape"]
out_dtype = annotator.signature["ret_dtype"]
op_type = annotator.signature["op_type"]

@@ -344,12 +356,23 @@
new_attrs["padding"] = annotator.op_attrs.padding
new_attrs["strides"] = annotator.op_attrs.strides
new_attrs["dilation"] = annotator.op_attrs.dilation

if "conv2d_transpose" in op_type:
d_shape = out_shape
w_shape = arg1_shape
elif "conv2d_backward_weight" in op_type:
d_shape = arg1_shape
w_shape = out_shape
else:
d_shape = arg0_shape
w_shape = arg1_shape

new_attrs.update(
handle_conv2d(
conv2d_profiler,
op_type,
arg0_shape,
arg1_shape,
d_shape,
w_shape,
annotator.op_attrs.padding,
annotator.op_attrs.strides,
annotator.op_attrs.dilation,
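The build.py changes boil down to two dispatch decisions, shown here as a standalone sketch (not code from the diff); ConvKind is the enum this file now imports from python/tvm/contrib/cutlass/library.py:

from tvm.contrib.cutlass.library import ConvKind

def conv_kind_for(op_type):
    # conv2d_transpose profiles as a data-gradient (Dgrad) convolution and
    # conv2d_backward_weight as a weight-gradient (Wgrad) one.
    if "conv2d_transpose" in op_type:
        return ConvKind.Dgrad
    if "backward_weight" in op_type:
        return ConvKind.Wgrad
    return ConvKind.Fprop

def profiling_shapes(op_type, arg0_shape, arg1_shape, out_shape):
    # For dgrad the op's output is the activation-shaped tensor, so it becomes
    # d_shape; for wgrad the output is the weight gradient, so it fills w_shape.
    if "conv2d_transpose" in op_type:
        return out_shape, arg1_shape        # (d_shape, w_shape)
    if "conv2d_backward_weight" in op_type:
        return arg1_shape, out_shape
    return arg0_shape, arg1_shape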
18 changes: 13 additions & 5 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -29,6 +29,8 @@ def __init__(self):
#include <iostream>
#include "cutlass/cutlass.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/conv/kernel/default_conv2d_wgrad.h"
#include "cutlass/conv/kernel/default_conv2d_dgrad.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
@@ -89,11 +91,6 @@ def __init__(self):
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
auto oshape = options.output_size();
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(options.input_size);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(options.filter_size);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(oshape);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(oshape);
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
@@ -106,7 +103,18 @@ def __init__(self):
1
);
auto conv_kind = ImplicitGemm::kConvolutionalOperator;
auto a_extent = implicit_gemm_tensor_a_extent(conv_kind, problem_size);
auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
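The profiler template now sizes its host tensors with CUTLASS's implicit_gemm_tensor_{a,b,c}_extent helpers, which pick the right extents for each convolution kind instead of hard-coding the fprop roles. A rough Python rendering of the mapping those helpers implement, assuming fprop-style problem shapes x=(N,H,W,C), w=(K,R,S,C), y=(N,P,Q,K):

from tvm.contrib.cutlass.library import ConvKind

def implicit_gemm_extents(conv_kind, x, w, y):
    # Which problem tensor plays A, B and C depends on the convolution kind.
    if conv_kind == ConvKind.Fprop:
        return x, w, y   # A = activation,  B = filter,     C = output
    if conv_kind == ConvKind.Dgrad:
        return y, w, x   # A = output grad, B = filter,     C = activation grad
    return y, x, w       # Wgrad: A = output grad, B = activation, C = filter grad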
71 changes: 62 additions & 9 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Conv2d kernel generator and profiler for CUTLASS."""
from functools import partial
from .conv2d_operation import Conv2dOperation, EmitConv2dInstance
from .gen_gemm import CutlassGemmProfiler
from .conv2d_profiler import Conv2dProfilerEmitter
@@ -32,7 +33,13 @@


def create_conv2d_operator_with_epilogue(
op_type, tile_description, data_type, alignment, swizzling_functor
conv_kind,
stride_support,
op_type,
tile_description,
data_type,
alignment,
swizzling_functor,
):
"""
Instantiate a cutlass kernel from the given configuration,
@@ -72,15 +79,15 @@ def create_conv2d_operator_with_epilogue(
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
ConvKind.Fprop,
conv_kind,
IteratorAlgorithm.Optimized,
tile_description.minimum_compute_capability,
tile_description,
A,
B,
C,
element_epilogue,
StrideSupport.Strided,
stride_support,
epilogue,
swizzling_functor,
)
@@ -94,6 +101,8 @@


def enumerate_conv2d_operators(
conv_kind,
stride_support,
tile_descriptions,
data_type,
alignment_constraints,
@@ -107,6 +116,9 @@

element_a, element_b, element_c, element_epilogue = data_type

if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1

for tile in tile_descriptions:
for alignment in alignment_constraints:

@@ -115,15 +127,15 @@
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
ConvKind.Fprop,
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
StrideSupport.Strided,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
)
@@ -152,7 +164,16 @@ def __init__(self, sm, cutlass_path, binary_path):
self.engine = ProfilerEngine(sm, cutlass_path, binary_path)
self.cache = {}

def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
def get_default(
self,
op_type,
out_dtype,
arg0_dtype,
arg1_dtype,
use_3xtf32,
conv_kind=ConvKind.Fprop,
stride=(1, 1),
):
"""Return the default kernel for the requested architecture.
For now, the default kernel is picked arbitrarily.
"""
@@ -162,8 +183,21 @@ def get_default(self, op_type, out_dtype, arg0_dtype, arg1_dtype, use_3xtf32):
tile_description = gemm_profile_result["tile_description"]
alignment = gemm_profile_result["alignment"]
data_type = gemm_profile_result["data_type"]
stride_support = StrideSupport.Strided if stride[0] > 1 else StrideSupport.Unity

if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1
else:
swizzling_functor = SwizzlingFunctor.Identity4

name, opdef = create_conv2d_operator_with_epilogue(
op_type, tile_description, data_type, alignment, SwizzlingFunctor.Identity4
conv_kind,
stride_support,
op_type,
tile_description,
data_type,
alignment,
swizzling_functor,
)
return {"name": name, "opdef": opdef}

@@ -178,6 +212,8 @@ def select_op(
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
stride_support,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -188,6 +224,7 @@
"""
N, H, W, IC = d_shape
OC, R, S, _ = w_shape

workload = (
N,
H,
@@ -211,7 +248,7 @@
out_dtype,
data_dtype,
weight_dtype,
enumerate_conv2d_operators,
partial(enumerate_conv2d_operators, conv_kind, stride_support),
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
@@ -248,6 +285,7 @@ def profile(
data_dtype,
weight_dtype,
use_3xtf32=True,
conv_kind=ConvKind.Fprop,
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -256,6 +294,13 @@
If find_first_valid is True, return immediately after the first applicable kernel is found.
If use_multiprocessing is True, compile all profiler executables in parallel.
"""
# Dgrad requires Unity stride when stride == (1, 1)
stride_support = (
StrideSupport.Unity
if stride[0] == 1 and stride[1] == 1 and conv_kind == ConvKind.Dgrad
else StrideSupport.Strided
)

op = self.select_op(
d_shape,
w_shape,
@@ -266,13 +311,21 @@
data_dtype,
weight_dtype,
use_3xtf32,
conv_kind,
stride_support,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
)

name, opdef = create_conv2d_operator_with_epilogue(
op_type, op["tile_description"], op["data_type"], op["alignment"], op["swizzle_functor"]
conv_kind,
stride_support,
op_type,
op["tile_description"],
op["data_type"],
op["alignment"],
op["swizzle_functor"],
)

return name, opdef, op["runtime"]
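
Two decisions now thread through this file: a StrideSupport choice and a swizzling functor that depends on it. Condensed into one helper for illustration (a sketch, not code from the diff; note that get_default keys only on stride[0], while profile() also requires the kind to be Dgrad before taking the Unity path):

from tvm.contrib.cutlass.library import ConvKind, StrideSupport, SwizzlingFunctor

def choose_stride_and_swizzle(conv_kind, stride):
    # Mirrors profile(): only dgrad with stride (1, 1) may use the Unity
    # specialization; everything else is built as a strided kernel.
    if conv_kind == ConvKind.Dgrad and stride == (1, 1):
        stride_support = StrideSupport.Unity
    else:
        stride_support = StrideSupport.Strided
    # Strided dgrad needs the dedicated StridedDgrad swizzle; all other
    # kernels keep the default identity swizzle.
    if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
        swizzle = SwizzlingFunctor.StridedDgradIdentity1
    else:
        swizzle = SwizzlingFunctor.Identity4
    return stride_support, swizzle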
1 change: 1 addition & 0 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -291,6 +291,7 @@ def get_tile_descriptions(math_inst):
"cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
"cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
"cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
}


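The new entry gives cutlass.conv2d_transpose the plain LinearCombination epilogue, i.e. no fused bias or activation. A hedged sketch of the lookup; the dict's actual name and the boolean's exact meaning sit outside this hunk:

# epilogue_map stands in for the dict this hunk extends (name assumed).
epilogue, flag = epilogue_map["cutlass.conv2d_transpose"]
# -> (EpilogueFunctor.LinearCombination, False): the epilogue computes only
#    alpha * accumulator + beta * C, with nothing extra fused in.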
10 changes: 10 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -189,6 +189,8 @@ class SwizzlingFunctor(enum.Enum):
Identity4 = enum_auto()
Identity8 = enum_auto()
Batched = enum_auto()
StridedDgradIdentity1 = enum_auto()
StridedDgradIdentity4 = enum_auto()


SwizzlingFunctorTag = {
@@ -197,20 +199,28 @@
SwizzlingFunctor.Identity4: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
SwizzlingFunctor.Identity8: "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>",
SwizzlingFunctor.Batched: "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle",
SwizzlingFunctor.StridedDgradIdentity1: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>",
SwizzlingFunctor.StridedDgradIdentity4: "cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<4>",
}


class ConvKind(enum.Enum):
Fprop = enum_auto()
Dgrad = enum_auto()
Wgrad = enum_auto()


ConvKindTag = {
ConvKind.Fprop: "cutlass::conv::Operator::kFprop",
ConvKind.Dgrad: "cutlass::conv::Operator::kDgrad",
ConvKind.Wgrad: "cutlass::conv::Operator::kWgrad",
}


ConvKindNames = {
ConvKind.Fprop: "fprop",
ConvKind.Dgrad: "dgrad",
ConvKind.Wgrad: "wgrad",
}


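The new enum members resolve to C++ type names when a kernel is emitted. A short sketch using only definitions from this file:

from tvm.contrib.cutlass.library import (
    ConvKind, ConvKindNames, ConvKindTag, SwizzlingFunctor, SwizzlingFunctorTag,
)

kind = ConvKind.Dgrad
print(ConvKindNames[kind])  # "dgrad" (short name for generated kernel ids)
print(ConvKindTag[kind])    # cutlass::conv::Operator::kDgrad
print(SwizzlingFunctorTag[SwizzlingFunctor.StridedDgradIdentity1])
# cutlass::conv::threadblock::StridedDgradIdentityThreadblockSwizzle<1>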