[CUTLASS] Add wgrad support (without split-k)
masahi committed Feb 4, 2022
1 parent 96416c4 commit 0a82149
Showing 9 changed files with 222 additions and 3 deletions.
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/gen_conv2d.py
@@ -252,6 +252,8 @@ def select_op(
lambda align: all([dim % align == 0 for dim in [IC, OC]]),
use_3xtf32,
profile_all_alignments,
# Use fp32 accumulation for wgrad to align with cuDNN
accumlator_dtype="float32" if conv_kind == ConvKind.Wgrad else out_dtype,
)

if not find_first_valid:
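The fp32 accumulator chosen here follows the comment above: it aligns the CUTLASS wgrad result with cuDNN. As a rough illustration of why that matters (NumPy only, not part of this patch): each weight-gradient entry is a long reduction over N * H * W products, and accumulating it in fp16 drifts noticeably from an fp32 accumulation.

import numpy as np

rng = np.random.default_rng(0)
n = 1 << 16  # reduction length, roughly N * H_out * W_out
dy = rng.standard_normal(n).astype("float16")  # output-gradient slice
x = rng.standard_normal(n).astype("float16")   # input-activation slice

acc_fp16 = np.sum(dy * x, dtype=np.float16)                       # fp16 accumulation
acc_fp32 = np.sum(dy.astype(np.float32) * x.astype(np.float32))   # fp32 accumulation
print(float(acc_fp16), float(acc_fp32), abs(float(acc_fp16) - float(acc_fp32)))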
5 changes: 5 additions & 0 deletions python/tvm/contrib/cutlass/gen_gemm.py
@@ -164,6 +164,8 @@ def get_default(
lambda align: align == 1, # Only request align1 kernels
use_3xtf32,
profile_all_alignments=True, # To include all align1 kernels
# TODO(masahi): Investigate when fp32 accumulation is needed for gemm
accumlator_dtype=out_dtype,
)

default_kernel_name = DEFAULT_KERNELS[self.sm][(arg0_dtype, out_dtype)]
@@ -220,6 +222,8 @@ def select_op(
lambda align: all([dim % align == 0 for dim in [M, N, K]]),
use_3xtf32,
profile_all_alignments=profile_all_alignments,
# TODO(masahi): Investigate when fp32 accumulation is needed for gemm
accumlator_dtype=out_dtype,
)

if not find_first_valid:
@@ -266,6 +270,7 @@ def profile(
profile_all_alignments=profile_all_alignments,
find_first_valid=find_first_valid,
use_multiprocessing=use_multiprocessing,

)

name, opdef = create_gemm_operator_with_epilogue(
12 changes: 10 additions & 2 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -51,7 +51,7 @@ def generate_tensor_op_common(
data_type = [
math_inst.element_a,
math_inst.element_b,
math_inst.element_accumulator,
math_inst.element_c,
math_inst.element_accumulator,
]

@@ -63,7 +63,7 @@ def generate_sm75_tensor_op_1688(


def generate_sm75_tensor_op_1688(
out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False
out_dtype, arg0_dtype, arg1_dtype, op_creator, check_align, _, profile_all_alignments=False, accumlator_dtype="float32",
):
"""Generate GEMM or Conv2D kernels for Turing."""
assert out_dtype in ["float32", "float16", "int32"]
@@ -77,6 +77,7 @@ def generate_sm75_tensor_op_1688(
DataType.f16,
DataType.f16,
dtype_map[out_dtype],
dtype_map[accumlator_dtype],
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
@@ -100,6 +101,7 @@ def generate_sm75_tensor_op_1688(
dtype_map[arg0_dtype],
dtype_map[arg1_dtype],
DataType.s32,
DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_saturate,
),
@@ -141,6 +143,7 @@ def generate_sm80_tensor_op_16816(
check_align,
use_3xtf32=True,
profile_all_alignments=False,
accumlator_dtype="float32",
):
"""Generate GEMM or Conv2D kernels for Ampere."""
min_cc = 80
@@ -176,6 +179,7 @@ def get_default_tile_descriptions(block_k_factor):
DataType.f16,
DataType.f16,
dtype_map[out_dtype],
dtype_map[accumlator_dtype],
OpcodeClass.TensorOp,
MathOperation.multiply_add,
)
@@ -189,6 +193,7 @@ def get_default_tile_descriptions(block_k_factor):
DataType.f32,
DataType.f32,
DataType.f32,
DataType.f32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_fast_f32 if use_3xtf32 else MathOperation.multiply_add,
),
@@ -221,6 +226,7 @@ def get_default_tile_descriptions(block_k_factor):
dtype_map[arg0_dtype],
dtype_map[arg1_dtype],
DataType.s32,
DataType.s32,
OpcodeClass.TensorOp,
MathOperation.multiply_add_saturate,
),
@@ -248,6 +254,7 @@ def get_tile_descriptions(math_inst):
check_align,
False,
profile_all_alignments,
accumlator_dtype=accumlator_dtype,
)
else:
# TF32 (float32 + float32 case) is only supported on sm80
@@ -292,6 +299,7 @@ def get_tile_descriptions(math_inst):
"cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
"cutlass.conv2d": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_transpose": (EpilogueFunctor.LinearCombination, False),
"cutlass.conv2d_backward_weight": (EpilogueFunctor.LinearCombination, False),
}


2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
@@ -266,13 +266,15 @@ def __init__(
instruction_shape,
element_a,
element_b,
element_c,
element_accumulator,
opcode_class,
math_operation=MathOperation.multiply_add,
):
self.instruction_shape = instruction_shape
self.element_a = element_a
self.element_b = element_b
self.element_c = element_c
self.element_accumulator = element_accumulator
self.opcode_class = opcode_class
self.math_operation = math_operation
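With the extra element_c field above, the operand dtype, output dtype, and accumulator dtype of a math instruction can all differ, which is what the fp16-in / fp16-out / fp32-accumulate wgrad kernels rely on. A minimal sketch of the extended constructor, assuming the enum names exported by tvm.contrib.cutlass.library are the ones used in this diff (the [16, 8, 16] shape is inferred from the "16816" naming):

from tvm.contrib.cutlass.library import (
    DataType,
    MathInstruction,
    MathOperation,
    OpcodeClass,
)

# fp16 A/B operands, fp16 output (element_c), fp32 accumulation -- the wgrad configuration.
math_inst = MathInstruction(
    [16, 8, 16],               # instruction shape assumed for the sm80 16816 tensor op
    DataType.f16,              # element_a
    DataType.f16,              # element_b
    DataType.f16,              # element_c
    DataType.f32,              # element_accumulator
    OpcodeClass.TensorOp,
    MathOperation.multiply_add,
)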
13 changes: 13 additions & 0 deletions python/tvm/relay/op/contrib/cutlass.py
@@ -94,6 +94,10 @@ def make_conv2d_transpose_pattern():
return is_op("nn.conv2d_transpose")(wildcard(), wildcard())


def make_conv2d_backward_weight_pattern():
return is_op("nn.conv2d_backward_weight")(wildcard(), wildcard())


def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
"""Add pattern for residual blocks."""
residual_input = wildcard()
@@ -173,6 +177,10 @@ def check_conv2d_transpose(call):
return check_conv2d_common("nn.conv2d_transpose", "IHWO", call)


def check_conv2d_backward_weight(call):
return check_conv2d_common("nn.conv2d_backward_weight", "NHWC", call)


def check_conv2d_residual(call, binary_op):
"""Check if the given conv2d workload can be offloaded to CUTLASS."""
conv2d = get_root_call(call, "nn.conv2d")
@@ -245,6 +253,11 @@ def partition_for_cutlass(mod, params=None):
# For now, no fusion for grad kernels
conv2d_grad_patterns = [
("cutlass.conv2d_transpose", make_conv2d_transpose_pattern(), check_conv2d_transpose),
(
"cutlass.conv2d_backward_weight",
make_conv2d_backward_weight_pattern(),
check_conv2d_backward_weight,
),
]

residual_block_patterns = []
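For context, a rough end-to-end sketch of exercising the new pattern (shapes, layouts, and dtypes below are illustrative, not taken from this patch): build a Relay graph containing nn.conv2d_backward_weight and run partition_for_cutlass over it. The call is only wrapped in a cutlass.conv2d_backward_weight composite if check_conv2d_backward_weight accepts its layouts and dtypes.

import tvm
from tvm import relay
from tvm.relay.op.contrib.cutlass import partition_for_cutlass

# Gradient of a 3x3 / stride-1 / pad-1 fp16 convolution, plus the forward input (NHWC).
dy = relay.var("dy", shape=(8, 32, 32, 64), dtype="float16")  # output gradient
x = relay.var("x", shape=(8, 32, 32, 16), dtype="float16")    # forward input
wgrad = relay.nn.conv2d_backward_weight(
    dy,
    x,
    strides=(1, 1),
    padding=(1, 1),
    kernel_size=(3, 3),
    channels=64,
    grad_layout="NHWC",
    data_layout="NHWC",
    kernel_layout="OHWI",
    out_dtype="float32",
)
mod = tvm.IRModule.from_expr(wgrad)
mod = partition_for_cutlass(mod)
print(mod)  # inspect whether the call landed inside a "cutlass.conv2d_backward_weight" composite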
34 changes: 34 additions & 0 deletions python/tvm/relay/op/nn/_nn.py
@@ -1143,6 +1143,40 @@ def legalize_conv2d_backward_weight(attrs, inputs, types):
return backward_weight


@reg.register_convert_op_layout("nn.conv2d_backward_weight")
def convert_conv2d_backward_weight(attrs, inputs, _, desired_layouts):
"""Convert Layout pass registration for conv2d_backward_weight op.
Note that `desired_layouts` must be a pair [`data_layout`, `kernel_layouts`],
where `kernel_layouts` affects the output of this op (since the output of this op
is the weight gradient). The layout of the output gradient (the second input to this op)
is assumed to be the same as `data_layout`.
Parameters
----------
attrs : tvm.ir.Attrs
Attributes of current op
inputs : list of tvm.relay.Expr
The args of the Relay expr to be converted
tinfos : list of types
List of input and output types
desired_layouts : list of layout strings
List of layouts defining our desired
layout for the data and kernel inputs respectively.
Returns
-------
result : tvm.relay.Expr
The transformed expr
"""
new_attrs = dict(attrs)
assert len(desired_layouts) == 2, "A desired layout is expected for both data and gradient."
desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
assert desired_data_layout != "default", "Data layout cannot be default"
new_attrs["grad_layout"] = desired_data_layout
new_attrs["data_layout"] = desired_data_layout
new_attrs["kernel_layout"] = desired_kernel_layout
new_attrs.pop("out_layout")
return relay.nn.conv2d_backward_weight(inputs[0], inputs[1], **new_attrs)


#####################
# Shape functions #
#####################
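A small usage sketch for the registration above. The starting NCHW expression and the ["NHWC", "OHWI"] pair are illustrative; as the docstring explains, the second entry drives the layout of the produced weight gradient, and the op's layout-inference hooks registered elsewhere in TVM are assumed to be in place.

import tvm
from tvm import relay

dy = relay.var("dy", shape=(8, 64, 32, 32), dtype="float16")  # output gradient, NCHW
x = relay.var("x", shape=(8, 16, 32, 32), dtype="float16")    # forward input, NCHW
wgrad = relay.nn.conv2d_backward_weight(
    dy, x, strides=(1, 1), padding=(1, 1), kernel_size=(3, 3), channels=64,
    grad_layout="NCHW", data_layout="NCHW", kernel_layout="OIHW",
)
mod = tvm.IRModule.from_expr(wgrad)

desired = {"nn.conv2d_backward_weight": ["NHWC", "OHWI"]}  # [data_layout, kernel_layout]
with tvm.transform.PassContext(opt_level=3):
    mod = relay.transform.ConvertLayout(desired)(mod)
print(mod)  # grad_layout/data_layout become NHWC, kernel_layout becomes OHWI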
5 changes: 4 additions & 1 deletion python/tvm/topi/cuda/conv2d.py
@@ -130,6 +130,9 @@ def conv2d_backward_weight_cudnn(
):
"""Compute conv2d wgrad using CuDNN library"""
assert layout in ["NCHW", "NHWC"]
# cuDNN wgrad does not seem to support other dtype combinations.
assert output_dtype == "float16", "Only supports fp16 output for cuDNN wgrad."
conv_dtype = "float32"
return cudnn.conv_backward_filter(
dy,
x,
@@ -139,6 +142,6 @@
dilation,
conv_mode=1,
tensor_format=0 if layout == "NCHW" else 1,
conv_dtype=output_dtype,
conv_dtype=conv_dtype,
groups=groups,
)
5 changes: 5 additions & 0 deletions src/relay/backend/contrib/cutlass/codegen.cc
@@ -615,6 +615,11 @@ class CodegenCutlass : public MemoizedExprTranslator<std::vector<Output>>, publi
GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_transpose"});
return GenerateBody(conv2d_call, "cutlass_conv2d_transpose", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), true, false));
} else if (pattern_name == "cutlass.conv2d_backward_weight") {
const auto* conv2d_call =
GetRootCall(callee->body.as<CallNode>(), 0, {"nn.conv2d_backward_weight"});
return GenerateBody(conv2d_call, "cutlass_conv2d_backward_weight", GetArgumentNames(caller),
Conv2dArgs(std::ref(attrs_), false, true));
}

LOG(FATAL) << "Unknown composite function: " << pattern_name;