[CUTLASS] Add parallel split-k support to wgrad (apache#10185)
* [CUTLASS] Add split-k support to wgrad

commit 60b73a91b79d644d8c95f682eedaf47a89abba0d
Author: Masahiro Masuda <[email protected]>
Date:   Tue Feb 8 10:43:11 2022 +0900

    pylint

commit ae2e718
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:51:52 2022 +0900

    Add split-k support for wgrad

    commit 43820d5
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 10:07:34 2022 +0900

        fix and add doc

    commit 446a95b
    Author: Masahiro Masuda <[email protected]>
    Date:   Sun Feb 6 09:48:38 2022 +0900

        dw conv2d properly supported for wgrad

    commit adc4e22
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:32:42 2022 +0900

        fix overwriting template

    commit 040eab0
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:06:27 2022 +0900

        black

    commit e5a07c2
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 16:03:10 2022 +0900

        add reduction in profiler

    commit be89334
    Author: Masahiro Masuda <[email protected]>
    Date:   Sat Feb 5 06:58:03 2022 +0900

        adding split k reduction to conv2d profiler

    commit ae09b0f
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 11:52:59 2022 +0900

        fixed conv2d_backward_weight typerel for dw conv2d

        commit 16fe531
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 12:59:22 2022 +0900

            wip

        commit 2167c25
        Author: Masahiro Masuda <[email protected]>
        Date:   Thu Feb 3 04:22:19 2022 +0900

            fix conv2d type rel for depth wise and grouped conv2d

    commit 14b12e5
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 05:01:03 2022 +0900

        remove split_k.py

    commit b141271
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 04:48:21 2022 +0900

        workaround for invalid split_k_slice

    commit 6e4c7e1
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:43:58 2022 +0900

        support split k in profiler

    commit 2eb1cf4
    Author: Masahiro Masuda <[email protected]>
    Date:   Fri Feb 4 02:31:03 2022 +0900

        improvement

    commit 0bce8f3
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 18:20:12 2022 +0900

        fixed for fp16 output

    commit 30df1bd
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 17:50:33 2022 +0900

        fp32 output works

    commit 7a51995
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:30:22 2022 +0900

        fix

    commit 4a383e2
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 14:05:24 2022 +0900

        update c++ codegen

    commit 6206e38
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 13:46:05 2022 +0900

        wip

    commit 0ece49b
    Author: Masahiro Masuda <[email protected]>
    Date:   Thu Feb 3 03:05:21 2022 +0900

        wip

    commit 08a6147
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 13:10:21 2022 +0900

        test worked with fp32 output

    commit 084d5c4
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:35:18 2022 +0900

        fix compile error for fprop

    commit 31f2543
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 12:18:06 2022 +0900

        compiled

    commit c2098e7
    Author: Masahiro Masuda <[email protected]>
    Date:   Wed Feb 2 11:11:43 2022 +0900

        wip

commit a145850
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:46:16 2022 +0900

    fixed for sm75

commit 6151506
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:32:46 2022 +0900

    all tests work

commit 041c094
Author: Masahiro Masuda <[email protected]>
Date:   Sun Feb 6 14:19:09 2022 +0900

    dw conv2d properly supported for wgrad

commit 2191918
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 09:14:05 2022 +0900

    wgrad tests now work under pytest

commit 78f76df
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 07:31:54 2022 +0900

    run black

commit 0a82149
Author: Masahiro Masuda <[email protected]>
Date:   Wed Feb 2 06:12:39 2022 +0900

    [CUTLASS] Add wgrad support (without split-k)

* pylint

* add more doc

* more doc clarification
masahi authored and ylc committed Feb 16, 2022
1 parent de8a0b7 commit 41135aa
Showing 7 changed files with 340 additions and 83 deletions.
14 changes: 13 additions & 1 deletion python/tvm/contrib/cutlass/build.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, dangerous-default-value
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
@@ -238,6 +238,7 @@ def handle_conv2d(
data_dtype,
weight_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
@@ -269,6 +270,7 @@ def handle_conv2d(
weight_dtype,
use_3xtf32,
conv_kind,
split_k_slices,
profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,
@@ -288,6 +290,7 @@ def tune_cutlass_kernels(
mod,
sm,
use_3xtf32=True,
split_k_slices=[1],
profile_all_alignments=False,
find_first_valid=False,
use_multiprocessing=False,
@@ -309,6 +312,14 @@ def tune_cutlass_kernels(
Whether or not to use the slower but very accurate (compared to tf32) 3xtf32 mode for
fp32 inputs on tensorcore.
split_k_slices : list of int
Split factor candidates for split-K GEMM. If split-K > 1, the GEMM K-loop is computed in
parallel across split-K blocks, and a separate global reduction kernel is launched to
accumulate the partial sums. The profiler will pick the best split-K factor from the
given candidate list. Note that a larger split-K factor requires a larger workspace.
Currently, parallel split-K has been tested only for wgrad. For GEMM and other conv2d
kinds, split_k_slices is ignored.
profile_all_alignments : bool
When True, profile all kernel variants with smaller alignments than the largest possible.
@@ -380,6 +391,7 @@ def tune_cutlass_kernels(
arg0_dtype,
arg1_dtype,
use_3xtf32,
split_k_slices,
profile_all_alignments,
find_first_valid,
use_multiprocessing,
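Usage note (an illustrative sketch, not part of this diff): the new `split_k_slices` argument to `tune_cutlass_kernels` is a list of candidate split-K factors, and the profiler keeps the fastest one per wgrad workload. The sketch below assumes the usual `partition_for_cutlass` then `tune_cutlass_kernels` flow and a Relay module `mod` that already contains a `conv2d_backward_weight` (wgrad) call; import paths and the return value follow this version of TVM and may differ elsewhere.

```python
# Hedged usage sketch: tune CUTLASS wgrad kernels while searching over
# several parallel split-K factors. `mod` is assumed to be an existing
# Relay module containing a fp16 conv2d_backward_weight op.
from tvm.relay.op.contrib.cutlass import partition_for_cutlass
from tvm.contrib.cutlass import tune_cutlass_kernels

mod = partition_for_cutlass(mod)  # offload supported ops to CUTLASS partitions
mod, num_cutlass_partition = tune_cutlass_kernels(
    mod,
    sm=80,                        # target compute capability
    split_k_slices=[1, 2, 4, 8],  # candidates; [1] disables parallel split-K
    profile_all_alignments=False,
    find_first_valid=False,
    use_multiprocessing=True,
)
```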
82 changes: 76 additions & 6 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -35,6 +35,7 @@ def __init__(
stride_support,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
split_k_slices=1,
):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
@@ -48,6 +49,7 @@ def __init__(
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.split_k_slices = split_k_slices

def accumulator_type(self):
return self.tile_description.math_instruction.element_accumulator
@@ -127,6 +129,9 @@ def procedural_name(self):
"_${layout}_align${alignment}"
)

if self.split_k_slices > 1:
configuration_name += "_splitk%d" % self.split_k_slices

return substitute_template(
configuration_name,
{
@@ -172,6 +177,14 @@ def __init__(self):
${unary_op}
>"""

self.epilogue_wgrad = """
${epilogue_functor}<
${element_c},
4,
float,
float
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
@@ -197,9 +210,31 @@ def __init__(self):
${align_a},
${align_b}
>::Kernel;
${reduction}
"""

self.reduction_template = """
using EpilogueOutputOp = ${epilogue};
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
${element_accumulator},
${element_accumulator},
EpilogueOutputOp::kCount
>;
using ReductionKernel = cutlass::reduction::kernel::ReduceSplitK<
cutlass::MatrixShape<4, 32 * EpilogueOutputOp::kCount>,
EpilogueOutputOp,
ReductionOp
>;
using ReductionDevice = cutlass::reduction::device::ReduceSplitK<ReductionKernel>;
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""

def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
def emit(
self, operation, no_beta_scaling=False, residual_block_info=False, emit_reduction=False
):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
@@ -214,6 +249,31 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
/ DataTypeSize[operation.C.element]
)

element_c = operation.C.element
use_split_k_wgrad = operation.conv_kind == ConvKind.Wgrad and operation.split_k_slices > 1
# GEMM output is always fp32 for wgrad with split-K
element_c_gemm = DataType.f32 if use_split_k_wgrad else element_c

if emit_reduction:
epilogue_reduction = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": DataTypeTag[element_c],
},
)
reduction = substitute_template(
self.reduction_template,
{
"epilogue": epilogue_reduction,
"operation_name": operation.procedural_name(),
"element_accumulator": DataTypeTag[operation.accumulator_type()],
},
)
gemm_template = substitute_template(self.template, {"reduction": reduction})
else:
gemm_template = substitute_template(self.template, {"reduction": ""})

values = {
"operation_name": operation.procedural_name(),
"conv_kind": ConvKindTag[operation.conv_kind],
@@ -222,7 +282,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[operation.C.element],
"element_c": DataTypeTag[element_c_gemm],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
@@ -262,9 +322,19 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"conv_kernel_postfix": "",
}

if residual_block_info:
if use_split_k_wgrad:
# Even if the output is fp16, the GEMM output is always fp32 for split-K wgrad.
epilogue_gemm = substitute_template(
self.epilogue_wgrad,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": "float",
},
)
template = substitute_template(gemm_template, {"epilogue": epilogue_gemm})
elif residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
gemm_template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
Expand All @@ -276,9 +346,9 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
gemm_template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})
template = substitute_template(gemm_template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
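Conceptual note (an illustrative NumPy sketch, not CUTLASS code): with parallel split-K, the implicit-GEMM K loop is sliced into `split_k_slices` chunks, each chunk writes a partial fp32 result into a workspace, and the separate `ReduceSplitK` kernel emitted above sums the partials and applies the epilogue cast to the final output dtype. The sketch below mirrors that flow; shapes and dtypes are arbitrary.

```python
# Conceptual sketch of parallel split-K: partition the K (reduction) dimension,
# compute partial GEMMs into an fp32 workspace, then run a separate reduction
# that sums the partials and casts to the output dtype.
import numpy as np

def split_k_gemm(a, b, split_k_slices, out_dtype="float16"):
    m, k = a.shape
    _, n = b.shape
    # Partial results are kept in fp32, matching the "GEMM output is always
    # fp32 for split-K wgrad" choice in the emitter above.
    workspace = np.zeros((split_k_slices, m, n), dtype="float32")
    bounds = np.linspace(0, k, split_k_slices + 1, dtype=int)
    for i in range(split_k_slices):
        lo, hi = bounds[i], bounds[i + 1]
        workspace[i] = a[:, lo:hi].astype("float32") @ b[lo:hi, :].astype("float32")
    # The separate reduction kernel (ReduceSplitK in the template) corresponds
    # to this sum over the slice axis, followed by the epilogue cast.
    return workspace.sum(axis=0).astype(out_dtype)

a = np.random.randn(64, 256).astype("float16")
b = np.random.randn(256, 32).astype("float16")
ref = (a.astype("float32") @ b.astype("float32")).astype("float16")
np.testing.assert_allclose(split_k_gemm(a, b, 4), ref, rtol=1e-3)
```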
58 changes: 51 additions & 7 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -17,13 +17,41 @@
# pylint: disable=import-outside-toplevel, invalid-name
"""Instantiate a C++ source for profiling CUTLASS kernels."""

from .library import DataTypeTag


class Conv2dProfilerEmitter(object):
"""Emit a C++ source for profiling CUTLASS kernels."""

def __init__(self):
from jinja2 import Template

self.reduction = """
ReductionDevice reduction_op;
static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemm::kConvolutionalOperator;
typename ReductionDevice::Arguments reduction_args(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
problem_size.split_k_slices,
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
},
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
);
reduction_op.initialize(reduction_args, nullptr);
reduction_op();
"""

self.template = Template(
"""
#include <iostream>
@@ -35,6 +63,8 @@ def __init__(self):
#include "cutlass/util/command_line.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"
#include "cutlass/reduction/device/reduce_split_k.h"
#include "cutlass/reduction/thread/reduction_operators.h"
#define CUTLASS_CHECK(status) \
{ \
@@ -88,10 +118,11 @@ def __init__(self):
};
double profile_convolution(Options const &options) {
using ElementOutput = typename ImplicitGemm::ElementC;
using ElementOutput = {{ElementOutput}};
using ElementInputA = typename ImplicitGemm::ElementA;
using ElementInputB = typename ImplicitGemm::ElementB;
int split_k_slices = {{SplitK}};
cutlass::conv::Conv2dProblemSize problem_size(
options.input_size,
options.filter_size,
@@ -100,28 +131,34 @@ def __init__(self):
options.dilation,
options.output_size(),
cutlass::conv::Mode::kCrossCorrelation,
1
split_k_slices
);
auto conv_kind = ImplicitGemm::kConvolutionalOperator;
auto a_extent = implicit_gemm_tensor_a_extent(conv_kind, problem_size);
auto b_extent = implicit_gemm_tensor_b_extent(conv_kind, problem_size);
auto c_extent = implicit_gemm_tensor_c_extent(conv_kind, problem_size);
using LayoutC = typename ImplicitGemm::LayoutC;
cutlass::HostTensor<ElementInputA, typename ImplicitGemm::LayoutA> tensor_a(a_extent);
cutlass::HostTensor<ElementInputB, typename ImplicitGemm::LayoutB> tensor_b(b_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_c(c_extent);
cutlass::HostTensor<ElementOutput, typename ImplicitGemm::LayoutC> tensor_ref_c(c_extent);
cutlass::HostTensor<ElementOutput, LayoutC> tensor_d(c_extent);
cutlass::HostTensor<ImplicitGemm::ElementC, LayoutC> tensor_c_gemm(c_extent);
using ElementComputeEpilogue = typename ImplicitGemm::ElementCompute;
cutlass::conv::SplitKMode const split_k_mode = split_k_slices > 1 ?
cutlass::conv::SplitKMode::kParallel : cutlass::conv::SplitKMode::kSerial;
typename ImplicitGemm::Arguments arguments{
problem_size,
tensor_a.device_ref(),
tensor_b.device_ref(),
tensor_c.device_ref(),
tensor_c.device_ref(),
tensor_c_gemm.device_ref(),
tensor_c_gemm.device_ref(),
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
split_k_mode,
};
ImplicitGemm implicit_gemm_op;
@@ -144,6 +181,7 @@ def __init__(self):
for (int iteration = 0; iteration < 100; ++iteration) {
auto status = implicit_gemm_op();
CUTLASS_CHECK(status);
{{Reduction}}
}
cudaEventRecord(events[1]);
@@ -166,6 +204,12 @@ def __init__(self):
"""
)

def emit(self, op_def, op_name):
src = self.template.render(OperatorDef=op_def, OperatorName=op_name)
def emit(self, op_def, op_name, element_output, split_k_slices=1):
src = self.template.render(
OperatorDef=op_def,
OperatorName=op_name,
ElementOutput=DataTypeTag[element_output],
SplitK=split_k_slices,
Reduction=self.reduction if split_k_slices > 1 else "",
)
return src
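Usage note (illustrative, not part of the diff): `Conv2dProfilerEmitter.emit` now also takes the output element type and a split-K factor, and the reduction snippet is spliced into the generated benchmark only when `split_k_slices > 1`. A hedged sketch of a call from the generator side, with placeholder kernel definition and name:

```python
# The op_def / op_name values below are placeholders; in the real generator
# they come from the Conv2d instance emitter shown in conv2d_operation.py above.
from tvm.contrib.cutlass.conv2d_profiler import Conv2dProfilerEmitter
from tvm.contrib.cutlass.library import DataType

op_def = "/* C++ kernel instantiation emitted elsewhere */"   # placeholder
op_name = "cutlass_tensorop_..._wgrad_splitk4"                # placeholder procedural name

profiler_src = Conv2dProfilerEmitter().emit(
    op_def,
    op_name,
    element_output=DataType.f16,  # ElementOutput used by the host tensors
    split_k_slices=4,             # > 1 enables the ReduceSplitK snippet above
)
```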