Skip to content

Commit

Permalink
Add split-k support for wgrad
Browse files Browse the repository at this point in the history
commit 43820d5
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 10:07:34 2022 +0900

    fix and add doc

commit 446a95b
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 09:48:38 2022 +0900

    dw conv2d properly supported for wgrad

commit adc4e22
Author: Masahiro Masuda <[email protected]>
Date:   Sat Feb 5 16:32:42 2022 +0900

    fix overwriting template

commit 040eab0
Author: Masahiro Masuda <[email protected]>
Date:   Sat Feb 5 16:06:27 2022 +0900

    black

commit e5a07c2
Author: Masahiro Masuda <[email protected]>
Date:   Sat Feb 5 16:03:10 2022 +0900

    add reduction in profiler

commit be89334
Author: Masahiro Masuda <[email protected]>
Date:   Sat Feb 5 06:58:03 2022 +0900

    adding split k reduction to conv2d profiler

commit ae09b0f
Author: Masahiro Masuda <[email protected]>
Date:   Fri Feb 4 11:52:59 2022 +0900

    fixed conv2d_backward_weight typerel for dw conv2d

    commit 16fe531
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 12:59:22 2022 +0900

        wip

    commit 2167c25
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 04:22:19 2022 +0900

        fix conv2d type rel for depth wise and grouped conv2d

commit 14b12e5
Author: Masahiro Masuda <[email protected]>
Date:   Fri Feb 4 05:01:03 2022 +0900

    remove split_k.py

commit b141271
Author: Masahiro Masuda <[email protected]>
Date:   Fri Feb 4 04:48:21 2022 +0900

    workaround for invalid split_k_slice

commit 6e4c7e1
Author: Masahiro Masuda <[email protected]>
Date:   Fri Feb 4 02:43:58 2022 +0900

    support split k in profiler

commit 2eb1cf4
Author: Masahiro Masuda <[email protected]>
Date:   Fri Feb 4 02:31:03 2022 +0900

    improvement

commit 0bce8f3
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 18:20:12 2022 +0900

    fixed for fp16 output

commit 30df1bd
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 17:50:33 2022 +0900

    fp32 output works

commit 7a51995
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 14:30:22 2022 +0900

    fix

commit 4a383e2
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 14:05:24 2022 +0900

    update c++ codegen

commit 6206e38
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 13:46:05 2022 +0900

    wip

commit 0ece49b
Author: Masahiro Masuda <[email protected]>
Date:   Thu Feb 3 03:05:21 2022 +0900

    wip

commit 08a6147
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 13:10:21 2022 +0900

    test worked with fp32 output

commit 084d5c4
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 12:35:18 2022 +0900

    fix compile error for fprop

commit 31f2543
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 12:18:06 2022 +0900

    compiled

commit c2098e7
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 11:11:43 2022 +0900

    wip
  • Loading branch information
masahi committed Feb 6, 2022
1 parent a145850 commit ae2e718
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 81 deletions.
10 changes: 10 additions & 0 deletions python/tvm/contrib/cutlass/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,7 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
Expand Down Expand Up @@ -269,6 +270,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
conv_kind,
split_k_slices,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
Expand All @@ -288,6 +290,7 @@ def tune_cutlass_kernels(
mod,
sm,
use_3xtf32=True,
split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
Expand All @@ -309,6 +312,12 @@ def tune_cutlass_kernels(
Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
fp32 inputs on tensorcore.
split_k_slices : list of int
Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
parallel across split-K blocks, and a separate global reduction kernel is launched to
accumulate partial reductions. The profiler will pick the best split-K factor from the
given candidate list. Note that a larger split-K factor requires a larger workspace.
profile_all_alignments : bool
When True, profile all kernel variants with smaller alignments than the largest possible.
Expand Down Expand Up @@ -380,6 +389,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
Expand Down
82 changes: 76 additions & 6 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
stride_support,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
split_k_slices=1,
):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
Expand All @@ -48,6 +49,7 @@ def __init__(
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.split_k_slices = split_k_slices

def accumulator_type(self):
return self.tile_description.math_instruction.element_accumulator
Expand Down Expand Up @@ -127,6 +129,9 @@ def procedural_name(self):
"_${layout}_align${alignment}"
)

if self.split_k_slices > 1:
configuration_name += "_splitk%d" % self.split_k_slices

return substitute_template(
configuration_name,
{
Expand Down Expand Up @@ -172,6 +177,14 @@ def __init__(self):
${unary_op}
>"""

self.epilogue_wgrad = """
${epilogue_functor}<
${element_c},
4,
float,
float
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
Expand All @@ -197,9 +210,31 @@ def __init__(self):
${align_a},
${align_b}
>::Kernel;
${reduction}
"""

self.reduction_template = """
using EpilogueOutputOp = ${epilogue};
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_accumulator},
EpilogueOutputOp::kCount
>;
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>;
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""

def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
def emit(
self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False
):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
Expand All @@ -214,6 +249,31 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
/ DataTypeSize[operation.C.element]
)

element_c = operation.C.element
use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1
# Gemm output always fp32 in wgrad with split k
element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c

if emit_reduction:
epilogue_reduction = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": DataTypeTag[element_c],
},
)
reduction = substitute_template(
self.reduction_template,
{
"epilogue": epilogue_reduction,
"operation_name": operation.procedural_name(),
"element_accumulator": DataTypeTag[operation.accumulator_type()],
},
)
gemm_template = substitute_template(self.template, {"reduction": reduction})
else:
gemm_template = substitute_template(self.template, {"reduction": ""})

values = {
"operation_name": operation.procedural_name(),
"conv_kind": ConvKindTag[operation.conv_kind],
Expand All @@ -222,7 +282,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[operation.C.element],
"element_c": DataTypeTag[element_c_gemm],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
Expand Down Expand Up @@ -262,9 +322,19 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"conv_kernel_postfix": "",
}

if residual_block_info:
if use_split_k_wgrad:
# Even if the output is fp16, gemm output is always fp32 for split k wgrad.
epilogue_gemm = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": "float",
},
)
template = substitute_template(gemm_template, {"epilogue": epilogue_gemm})
elif residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
gemm_template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
Expand All @@ -276,9 +346,9 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})
template = substitute_template(gemm_template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
57 changes: 50 additions & 7 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,41 @@
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""

from .library import DataTypeTag


class Conv2dProfilerEmitter(object):
"""Emit a C++ source for profiling CUTLASS kernels."""

def __init__(self):
from jinja2 import Template

self.reduction = """
ReductionDevice reduction_op;
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
typename ReductionDevice::Arguments reduction_args(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
problem_size.split_k_slices,
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
);
reduction_op.initialize(reduction_args, nullptr);
reduction_op();
"""

self.template = Template(
"""
#include <iostream>
Expand All @@ -35,6 +63,8 @@ def __init__(self):
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#define CUTLASS_CHECK(status) \
{ \
Expand Down Expand Up @@ -88,10 +118,11 @@ def __init__(self):
};
double profile_convolution(Options const &options) {
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementOutput = {{ElementOutput}};
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
int split_k_slices = {{SplitK}};
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
Expand All @@ -100,7 +131,7 @@ def __init__(self):
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
1
split_k_slices
);
auto conv_kind = ImplicitGemm::kConvolutionalOperator;
Expand All @@ -111,17 +142,22 @@ def __init__(self):
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_d(c_extent);
cutlass::HostTensor<ImplicitGemm::ElementC, typename ImplicitGemm::LayoutC> tensor_c_gemm(c_extent);
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial;
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
tensor_c_gemm.device_ref(),
tensor_c_gemm.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
split_k_mode,
};
ImplicitGemm implicit_gemm_op;
Expand All @@ -144,6 +180,7 @@ def __init__(self):
for (int iteration = 0; iteration < 100; ++iteration) {
auto status = implicit_gemm_op();
CUTLASS_CHECK(status);
{{Reduction}}
}
cudaEventRecord(events[1]);
Expand All @@ -166,6 +203,12 @@ def __init__(self):
"""
)

def emit(self, op_def, op_name):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
def emit(self, op_def, op_name, element_output, split_k_slices=1):
src = self.template.render(
OperatorDef=op_def,
OperatorName=op_name,
ElementOutput=DataTypeTag[element_output],
SplitK=split_k_slices,
Reduction=self.reduction if split_k_slices > 1 else "",
)
return src
Loading

0 comments on commit ae2e718

Please sign in to comment.