Skip to content

Commit

Permalink
Support residual block fusion
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Dec 23, 2021
1 parent ce9d52f commit fda151b
Show file tree
Hide file tree
Showing 8 changed files with 283 additions and 38 deletions.
45 changes: 35 additions & 10 deletions python/tvm/contrib/cutlass/conv2d_operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __init__(self):
${element_accumulator},
${element_epilogue}
>"""

self.epilogue_no_beta_scaling = """
${epilogue_functor}<
${element_c},
Expand All @@ -159,10 +160,22 @@ def __init__(self):
cutlass::epilogue::thread::ScaleType::NoBetaScaling
>"""

self.epilogue_residual_block = """
${epilogue_functor}<
${element_c},
${element_accumulator},
${element_epilogue},
${element_c},
${epilogue_vector_length},
${activation},
${binary_op},
${unary_op}
>"""

self.template = """
// Conv2d${conv_kind_name} ${iterator_algorithm_name} kernel instance "${operation_name}"
using ${operation_name} =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}${conv_kernel_postfix}<
${element_a},
${layout_a},
${element_b},
Expand All @@ -186,7 +199,7 @@ def __init__(self):
>::Kernel;
"""

def emit(self, operation, no_beta_scaling=False):
def emit(self, operation, no_beta_scaling=False, residual_block_info=False):
"""Instantiate a Conv2d kernel from given `operation`."""
warp_shape = [
int(
Expand Down Expand Up @@ -246,14 +259,26 @@ def emit(self, operation, no_beta_scaling=False):
],
"align_a": str(operation.A.alignment),
"align_b": str(operation.B.alignment),
"conv_kernel_postfix": "",
}

template = substitute_template(
self.template,
{
"epilogue": self.epilogue_no_beta_scaling
if no_beta_scaling
else self.epilogue_default
},
)
if residual_block_info:
template = substitute_template(
self.template, {"epilogue": self.epilogue_residual_block}
)
values.update(
{
"unary_op": residual_block_info["unary_op"],
"binary_op": residual_block_info["binary_op"],
"activation": residual_block_info["activation"],
"conv_kernel_postfix": "WithBroadcast",
}
)
elif no_beta_scaling:
template = substitute_template(
self.template, {"epilogue": self.epilogue_no_beta_scaling}
)
else:
template = substitute_template(self.template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)
31 changes: 29 additions & 2 deletions python/tvm/contrib/cutlass/gen_conv2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,32 @@ def create_conv2d_operator_with_epilogue(
Instantiate a cutlass kernel from the given configuration,
along with the epilouge functor
"""
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]
if "residual" in op_type:
activation_map = {
"cutlass.conv2d_bias_hardswish": "cutlass::epilogue::thread::HardSwish",
"cutlass.conv2d_bias_silu": "cutlass::epilogue::thread::SiLu",
"cutlass.conv2d_bias_sigmoid": "cutlass::epilogue::thread::Sigmoid",
"cutlass.conv2d_bias_relu": "cutlass::epilogue::thread::ReLu",
"cutlass.conv2d_bias": "cutlass::epilogue::thread::Identity",
}
prefix = op_type[: op_type.find("_residual")]
activation = activation_map[prefix]
binary_op = "cutlass::multiplies" if "residual_multiply" in op_type else "cutlass::plus"
unary_op = (
"cutlass::epilogue::thread::ReLu"
if op_type.endswith("relu")
else "cutlass::epilogue::thread::Identity"
)
residual_block_info = {
"activation": activation,
"binary_op": binary_op,
"unary_op": unary_op,
}
epilogue = EpilogueFunctor.LinearCombinationResidualBlock
no_beta_scaling = False
else:
residual_block_info = None
epilogue, no_beta_scaling = EPILOGUE_MAP[op_type]

element_a, element_b, element_c, element_epilogue = data_type

Expand All @@ -62,7 +87,9 @@ def create_conv2d_operator_with_epilogue(
)

name = op.procedural_name()
opdef = EmitConv2dInstance().emit(op, no_beta_scaling=no_beta_scaling)
opdef = EmitConv2dInstance().emit(
op, no_beta_scaling=no_beta_scaling, residual_block_info=residual_block_info
)

return name, opdef

Expand Down
15 changes: 15 additions & 0 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,21 @@ def get_tile_descriptions(math_inst):
80: generate_sm80_tensor_op_16816,
}

# Maps a CUTLASS composite op name to (EpilogueFunctor, no_beta_scaling).
# When no_beta_scaling is True, the kernel emitter instantiates the epilogue
# with cutlass::epilogue::thread::ScaleType::NoBetaScaling, i.e. the source
# tensor C is added without being scaled by beta.
EPILOGUE_MAP = {
    "cutlass.dense": (EpilogueFunctor.LinearCombination, True),
    "cutlass.dense_bias": (EpilogueFunctor.LinearCombinationBias, True),
    "cutlass.dense_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
    "cutlass.dense_bias_gelu_fp16": (EpilogueFunctor.LinearCombinationGelu, False),
    "cutlass.dense_bias_gelu_fp32": (EpilogueFunctor.LinearCombinationGelu, False),
    "cutlass.batch_matmul": (EpilogueFunctor.LinearCombination, True),
    "cutlass.conv2d_bias_hardswish": (EpilogueFunctor.LinearCombinationHardSwish, False),
    "cutlass.conv2d_bias_silu": (EpilogueFunctor.LinearCombinationSilu, False),
    "cutlass.conv2d_bias_sigmoid": (EpilogueFunctor.LinearCombinationSigmoid, False),
    "cutlass.conv2d_bias_relu": (EpilogueFunctor.LinearCombinationRelu, True),
    "cutlass.conv2d_bias": (EpilogueFunctor.LinearCombinationBias, True),
    "cutlass.conv2d": (EpilogueFunctor.LinearCombination, True),
}


# (Epilogue functor name, no_beta_scaling)
EPILOGUE_MAP = {
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/contrib/cutlass/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class EpilogueFunctor(enum.Enum):
LinearCombinationSigmoid = enum_auto()
LinearCombinationSilu = enum_auto()
LinearCombinationHardSwish = enum_auto()
LinearCombinationResidualBlock = enum_auto()


EpilogueFunctorTag = {
Expand All @@ -161,6 +162,7 @@ class EpilogueFunctor(enum.Enum):
EpilogueFunctor.LinearCombinationSigmoid: "cutlass::epilogue::thread::LinearCombinationSigmoid",
EpilogueFunctor.LinearCombinationSilu: "cutlass::epilogue::thread::LinearCombinationSilu",
EpilogueFunctor.LinearCombinationHardSwish: "cutlass::epilogue::thread::LinearCombinationHardSwish",
EpilogueFunctor.LinearCombinationResidualBlock: "cutlass::epilogue::thread::LinearCombinationResidualBlock",
}


Expand Down
59 changes: 48 additions & 11 deletions python/tvm/relay/op/contrib/cutlass.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
# pylint: disable=invalid-name
"""Patterns supported CUTLASS."""
from functools import partial
from tvm import relay
from tvm.ir.transform import Sequential, PassContext
from tvm.relay import transform
Expand Down Expand Up @@ -89,6 +90,19 @@ def make_conv2d_pattern(with_bias=False, with_act=None):
return conv2d_out


def make_residual_block_pattern(tensor_op_out, binary_op="add", with_act="relu"):
    """Add pattern for residual blocks.

    Parameters
    ----------
    tensor_op_out : relay.dataflow_pattern.DFPattern
        Pattern for the tensor-op output (e.g. a conv2d + epilogue pattern)
        that the residual input is combined with.
    binary_op : str
        Name of the elementwise binary op ("add" or "multiply") that merges
        the tensor-op output with the residual input.
    with_act : Optional[str]
        If "relu", additionally require a trailing nn.relu after the binary op.

    Returns
    -------
    pattern : relay.dataflow_pattern.DFPattern
        The composed residual-block pattern.
    """
    residual_input = wildcard()
    # The residual input may appear as either operand of the binary op.
    binary_out = is_op(binary_op)(tensor_op_out, residual_input) | is_op(binary_op)(
        residual_input, tensor_op_out
    )

    # `with_act is not None and with_act == "relu"` collapses to a single
    # equality test: None can never compare equal to "relu".
    if with_act == "relu":
        return is_op("nn.relu")(binary_out)

    return binary_out


def check_dtype(lhs, rhs):
"""Check if dtypes in the given workload are supported by CUTLASS."""
# Only fp16 inputs are supported for now.
Expand Down Expand Up @@ -139,6 +153,25 @@ def check_conv2d(call):
return not is_depthwise_conv2d(IC, OC, conv2d.attrs.groups)


def check_conv2d_residual(call, binary_op):
    """Check if the given conv2d residual-block workload can be offloaded to CUTLASS.

    Parameters
    ----------
    call : relay.Call
        The matched composite call whose root is expected to contain an
        nn.conv2d followed by `binary_op`.
    binary_op : str
        Name of the residual binary op ("add" or "multiply") in the pattern.

    Returns
    -------
    bool
        True if the workload is supported.
    """
    conv2d = get_root_call(call, "nn.conv2d")
    # The plain conv2d constraints must hold regardless of the residual part.
    if not check_conv2d(call):
        return False

    residual_binop = get_root_call(call, binary_op)
    lhs, rhs = residual_binop.args

    # residual_input is pattern-matched as a wildcard. Make sure it does not sit between
    # the residual binary op and the root conv2d of this pattern.
    # If both operands trace back to the same root conv2d, the wildcard matched inside
    # this pattern's own chain, so the match is accepted as-is.
    # (NOTE: the original comment said "we should reject this pattern", which
    # contradicted the `return True` below; the code's behavior is kept.)
    if get_root_call(lhs, "nn.conv2d") == conv2d and get_root_call(rhs, "nn.conv2d") == conv2d:
        return True

    # Otherwise the residual input comes from outside the pattern: require its
    # shape to match the conv2d branch elementwise. NOTE(review): zip truncates
    # to the shorter rank — presumably both sides have equal rank; confirm upstream.
    return all(x == y for x, y in zip(lhs.checked_type.shape, rhs.checked_type.shape))


def partition_for_cutlass(mod, params=None):
"""Partition the input module into CUTLASS-supported subgraphs."""
dense_pat = ("cutlass.dense", make_gemm_pattern(False, None), check_gemm)
Expand All @@ -165,16 +198,6 @@ def partition_for_cutlass(mod, params=None):
]

conv2d_patterns = [
(
"cutlass.conv2d_bias_hardswish",
make_conv2d_pattern(with_bias=True, with_act="hardswish"),
check_conv2d,
),
(
"cutlass.conv2d_bias_silu",
make_conv2d_pattern(with_bias=True, with_act="silu"),
check_conv2d,
),
(
"cutlass.conv2d_bias_hardswish",
make_conv2d_pattern(with_bias=True, with_act="hardswish"),
Expand All @@ -199,7 +222,20 @@ def partition_for_cutlass(mod, params=None):
("cutlass.conv2d", make_conv2d_pattern(), check_conv2d),
]

cutlass_patterns = dense_patterns + conv2d_patterns
residual_block_patterns = []

for with_act, postfix in [("relu", "_relu"), (None, "")]:
for name, pat, _ in conv2d_patterns[:-1]:
for bin_op in ["add", "multiply"]:
residual_block_patterns.append(
(
name + "_residual_" + bin_op + postfix,
make_residual_block_pattern(pat, bin_op, with_act=with_act),
partial(check_conv2d_residual, binary_op=bin_op),
)
)

cutlass_patterns = residual_block_patterns + dense_patterns + conv2d_patterns

if params is not None:
mod["main"] = bind_params_by_name(mod["main"], params)
Expand All @@ -217,6 +253,7 @@ def partition_for_cutlass(mod, params=None):
seq = Sequential(
[
transform.InferType(),
transform.SimplifyExpr(),
transform.MergeComposite(cutlass_patterns),
transform.AnnotateTarget(["cutlass"], include_non_call_ops=False),
transform.PartitionGraph(bind_constants=False),
Expand Down
Loading

0 comments on commit fda151b

Please sign in to comment.