Skip to content

Commit

Permalink
wip
Browse files (browse the repository at this point in the history)
  • Loading branch information
masahi committed Feb 4, 2022
1 parent 0ece49b commit 6206e38
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 38 deletions.
69 changes: 62 additions & 7 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
stride_support,
epilogue_functor=EpilogueFunctor.LinearCombination,
swizzling_functor=SwizzlingFunctor.Identity1,
split_k_slices=1,
):
self.operation_kind = OperationKind.Conv2d
self.arch = arch
Expand All @@ -48,6 +49,7 @@ def __init__(
self.iterator_algorithm = iterator_algorithm
self.stride_support = stride_support
self.swizzling_functor = swizzling_functor
self.split_k_slices = split_k_slices

def accumulator_type(self):
return self.tile_description.math_instruction.element_accumulator
Expand Down Expand Up @@ -172,6 +174,22 @@ def __init__(self):
${unary_op}
>"""

self.epilogue_wgrad_split_k = """
${epilogue_functor}<
${element_c},
4,
float,
float,
>"""

self.epilogue_wgrad_split_k_tmp = """
${epilogue_functor}<
float,
4,
float,
float,
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
Expand All @@ -197,12 +215,14 @@ def __init__(self):
${align_a},
${align_b}
>::Kernel;
${reduction}
"""
self.reduction_template =
"""
self.reduction_template = """
using EpilogueOutputOp = ${epilogue};
using ReductionOp = cutlass::reduction::thread::ReduceAdd<
ElementAccumulator,
typename EpilogueOutputOp::ElementAccumulator,
${element_accumulator},
${element_accumulator},
EpilogueOutputOp::kCount
>;
Expand All @@ -216,7 +236,7 @@ def __init__(self):
using ReductionStrideIndex = typename ReductionDevice::StrideIndex;
"""

def emit(self, operation, no_beta_scaling=False, residual_block_info=False, split_k_slices=1):
def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
Expand All @@ -231,6 +251,39 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False, spli
/ DataTypeSize[operation.C.element]
)

element_c = operation.C.element
use_split_k = (
operation.split_k_slices > 1
and operation.conv_kind == ConvKind.Wgrad
and operation.C.element == DataType.f16
)

if use_split_k:
# split k
element_c = DataType.f32
epilogue_reduction = substitute_template(
self.epilogue_wgrad_split_k,
{
"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor],
"element_c": DataTypeTag[element_c],
},
)
epilogue_gemm = substitute_template(
self.epilogue_wgrad_split_k_tmp,
{"epilogue_functor": EpilogueFunctorTag[operation.epilogue_functor]},
)
reduction = substitute_template(
self.reduction_template,
{
"epilogue": epilogue_reduction,
"operation_name": operation.procedural_name(),
"element_accumulator": DataTypeTag[operation.accumulator_type()],
},
)
self.template = substitute_template(self.template, {"reduction": reduction})
else:
self.template = substitute_template(self.template, {"reduction": ""})

values = {
"operation_name": operation.procedural_name(),
"conv_kind": ConvKindTag[operation.conv_kind],
Expand All @@ -239,7 +292,7 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False, spli
"layout_a": LayoutTag[operation.A.layout],
"element_b": DataTypeTag[operation.B.element],
"layout_b": LayoutTag[operation.B.layout],
"element_c": DataTypeTag[operation.C.element],
"element_c": DataTypeTag[element_c],
"layout_c": LayoutTag[operation.C.layout],
"element_accumulator": DataTypeTag[operation.accumulator_type()],
"opcode_class": OpcodeClassTag[
Expand Down Expand Up @@ -279,7 +332,9 @@ def emit(self, operation, no_beta_scaling=False, residual_block_info=False, spli
"conv_kernel_postfix": "",
}

if residual_block_info:
if use_split_k:
template = substitute_template(self.template, {"epilogue": epilogue_gemm})
elif residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
)
Expand Down
67 changes: 36 additions & 31 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def create_conv2d_operator_with_epilogue(
data_type,
alignment,
swizzling_functor,
split_k_slices=1,
):
"""
Instantiate a cutlass kernel from the given configuration,
Expand Down Expand Up @@ -90,6 +91,7 @@ def create_conv2d_operator_with_epilogue(
stride_support,
epilogue,
swizzling_functor,
split_k_slices,
)

name = op.procedural_name()
Expand All @@ -107,6 +109,7 @@ def enumerate_conv2d_operators(
data_type,
alignment_constraints,
swizzling_functor=SwizzlingFunctor.Identity4,
split_k_slices=[1],
):
"""Exhaustively instantiate all kernels from a given configuration."""
ret = []
Expand All @@ -119,37 +122,39 @@ def enumerate_conv2d_operators(
if conv_kind == ConvKind.Dgrad and stride_support == StrideSupport.Strided:
swizzling_functor = SwizzlingFunctor.StridedDgradIdentity1

for tile in tile_descriptions:
for alignment in alignment_constraints:

A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
)

ret.append(
{
"src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()),
"name": op.procedural_name(),
"tile_description": tile,
"alignment": alignment,
"data_type": data_type,
"swizzle_functor": swizzling_functor,
}
)
for split_k_slice in split_k_slices:
for tile in tile_descriptions:
for alignment in alignment_constraints:

A = TensorDescription(element_a, LayoutType.TensorNHWC, alignment)
B = TensorDescription(element_b, LayoutType.TensorNHWC, alignment)
C = TensorDescription(element_c, LayoutType.TensorNHWC, alignment)

op = Conv2dOperation(
conv_kind,
IteratorAlgorithm.Optimized,
tile.minimum_compute_capability,
tile,
A,
B,
C,
element_epilogue,
stride_support,
EpilogueFunctor.LinearCombination,
swizzling_functor,
split_k_slice
)

ret.append(
{
"src": profiler_emitter.emit(kernel_emitter.emit(op), op.procedural_name()),
"name": op.procedural_name(),
"tile_description": tile,
"alignment": alignment,
"data_type": data_type,
"swizzle_functor": swizzling_functor,
}
)

return ret

Expand Down

0 comments on commit 6206e38

Please sign in to comment.