[CUTLASS] Add parallel split-k support to wgrad #10185

Merged · 4 commits · Feb 8, 2022
14 changes: 13 additions & 1 deletion python/tvm/contrib/cutlass/build.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, dangerous-default-value
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
@@ -238,6 +238,7 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
@@ -269,6 +270,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
conv_kind,
split_k_slices,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
@@ -288,6 +290,7 @@ def tune_cutlass_kernels(
mod,
sm,
use_3xtf32=True,
split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -309,6 +312,14 @@ def tune_cutlass_kernels(
Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
fp32 inputs on tensorcore.

split_k_slices : list of int
Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
parallel across split-K blocks, and a separate global reduction kernel is launched to
accumulate partial reductions. The profiler will pick the best split-K factor from the
given candidate list. Note that a larger split-K factor requires a larger workspace.
Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
kinds, split_k_slices is ignored.

profile_all_alignments : bool
When True, profile all kernel variants with smaller alignments than the largest possible.

@@ -380,6 +391,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
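For orientation, a minimal sketch of how the new knob is driven (hedged: `mod` is assumed to be a Relay module already partitioned for CUTLASS offload, the `sm` value and candidate list are illustrative, and the tuple return shape is assumed here rather than shown in this diff):

from tvm.contrib.cutlass import tune_cutlass_kernels

mod = ...  # assumed: a partitioned Relay module containing a conv2d wgrad kernel

# The profiler compiles one variant per split-K candidate and keeps the
# fastest; candidates other than 1 only take effect for wgrad.
mod, num_cutlass_partition = tune_cutlass_kernels(
    mod,
    sm=80,                     # illustrative target architecture
    split_k_slices=[1, 4, 8],  # candidate split-K factors to profile
)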
82 changes: 76 additions & 6 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -35,6 +35,7 @@ def __init__(
stride_support,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
split_k_slices=1,
):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
@@ -48,6 +49,7 @@ def __init__(
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.split_k_slices = split_k_slices

def accumulator_type(self):
return self.tile_description.math_instruction.element_accumulator
@@ -127,6 +129,9 @@ def procedural_name(self):
"_${layout}_align${alignment}"
)

if self.split_k_slices > 1:
configuration_name += "_splitk%d" % self.split_k_slices

return substitute_template(
configuration_name,
{
@@ -172,6 +177,14 @@ def __init__(self):
${unary_op}
>"""

self.epilogue_wgrad = """
${epilogue_functor}<
${element_c},
4,
float,
float
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
@@ -197,9 +210,31 @@ def __init__(self):
${align_a},
${align_b}
>::Kernel;

${reduction}
"""

self.reduction_template = """
using EpilogueOutputOp = ${epilogue};
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_accumulator},
EpilogueOutputOp::kCount
>;

using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>;

using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""

def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
def emit(
self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False
):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -214,6 +249,31 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
/ DataTypeSize[operation.C.element]
)

element_c = operation.C.element
use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1
# GEMM output is always fp32 in wgrad with split-K
element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c

if emit_reduction:
epilogue_reduction = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": DataTypeTag[element_c],
},
)
reduction = substitute_template(
self.reduction_template,
{
"epilogue": epilogue_reduction,
"operation_name": operation.procedural_name(),
"element_accumulator": DataTypeTag[operation.accumulator_type()],
},
)
gemm_template = substitute_template(self.template, {"reduction": reduction})
else:
gemm_template = substitute_template(self.template, {"reduction": ""})

values = {
"operation_name": operation.procedural_name(),
"conv_kind": ConvKindTag[operation.conv_kind],
@@ -222,7 +282,7 @@
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[operation.C.element],
"element_c": DataTypeTag[element_c_gemm],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
@@ -262,9 +322,19 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"conv_kernel_postfix": "",
}

if residual_block_info:
if use_split_k_wgrad:
# Even if the output is fp16, the GEMM output is always fp32 for split-K wgrad.
epilogue_gemm = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": "float",
},
)
template = substitute_template(gemm_template, {"epilogue": epilogue_gemm})
elif residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
gemm_template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
@@ -276,9 +346,9 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})
template = substitute_template(gemm_template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
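To make the split-K math concrete: each slice accumulates a partial GEMM over its chunk of the K dimension into the fp32 workspace, and the ReduceSplitK kernel sums the partials. A NumPy sketch of the equivalent computation (illustrative only, not part of this PR):

import numpy as np

def split_k_matmul(A, B, slices):
    # Partition the K axis into `slices` chunks, compute each partial
    # product independently (in fp32, matching the workspace dtype),
    # then reduce the partials -- this mirrors GEMM + ReduceSplitK.
    K = A.shape[1]
    bounds = np.linspace(0, K, slices + 1, dtype=int)
    partials = [
        A[:, s:e].astype(np.float32) @ B[s:e, :].astype(np.float32)
        for s, e in zip(bounds[:-1], bounds[1:])
    ]
    return np.sum(partials, axis=0)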
58 changes: 51 additions & 7 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -17,13 +17,41 @@
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""

from .library import DataTypeTag


class Conv2dProfilerEmitter(object):
"""Emit a C++ source for profiling CUTLASS kernels."""

def __init__(self):
from jinja2 import Template

self.reduction = """
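// Reduce the split-K partial sums that the implicit GEMM wrote into the
// fp32 workspace, accumulating the result into tensor_d.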
ReductionDevice reduction_op;
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
typename ReductionDevice::Arguments reduction_args(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
problem_size.split_k_slices,
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
);

reduction_op.initialize(reduction_args, nullptr);
reduction_op();
"""

self.template = Template(
"""
#include <iostream>
@@ -35,6 +63,8 @@ def __init__(self):
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"

#define CUTLASS_CHECK(status) \
{ \
@@ -88,10 +118,11 @@ def __init__(self):
};

double profile_convolution(Options const &options) {
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementOutput = {{ElementOutput}};
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;

int split_k_slices = {{SplitK}};
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
@@ -100,28 +131,34 @@ def __init__(self):
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
1
split_k_slices
);

auto conv_kind = ImplicitGemm::kConvolutionalOperator;
auto a_extent = implicit_gemm_tensor_a_extent(conv_kind, problem_size);
auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);

using LayoutC = typename ImplicitGemm::LayoutC;
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);
cutlass::HostTensor<ElementOutput, LayoutC> tensor_d(c_extent);
cutlass::HostTensor<ImplicitGemm::ElementC, LayoutC> tensor_c_gemm(c_extent);

using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;

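// A single slice keeps the conventional serial accumulation path; more
// than one slice selects the parallel split-K mode added in this PR.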
cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial;

typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
tensor_c_gemm.device_ref(),
tensor_c_gemm.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
split_k_mode,
};

ImplicitGemm implicit_gemm_op;
@@ -144,6 +181,7 @@ def __init__(self):
for (int iteration = 0; iteration < 100; ++iteration) {
auto status = implicit_gemm_op();
CUTLASS_CHECK(status);
{{Reduction}}
}

cudaEventRecord(events[1]);
@@ -166,6 +204,12 @@ def __init__(self):
"""
)

def emit(self, op_def, op_name):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
def emit(self, op_def, op_name, element_output, split_k_slices=1):
src = self.template.render(
OperatorDef=op_def,
OperatorName=op_name,
ElementOutput=DataTypeTag[element_output],
SplitK=split_k_slices,
Reduction=self.reduction if split_k_slices > 1 else "",
)
return src
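A hypothetical call with the new signature (hedged: `op_def` and `op_name` would come from the Conv2d instance emitter above, and the dtype and split factor are illustrative):

from tvm.contrib.cutlass.conv2d_profiler import Conv2dProfilerEmitter
from tvm.contrib.cutlass.library import DataType

op_def = "..."   # kernel instance definition emitted by the Conv2d emitter
op_name = "..."  # procedural name of the operation

emitter = Conv2dProfilerEmitter()
# split_k_slices > 1 makes the generated profiler source launch the
# parallel reduction after each implicit GEMM invocation.
src = emitter.emit(op_def, op_name, element_output=DataType.f16, split_k_slices=8)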