Commit
[BYOC] Update CUTLASS backend (SIMT support and codegen clean up) (ap…
masahi authored Feb 20, 2023
1 parent 697c724 commit 5562d90
Showing 13 changed files with 765 additions and 700 deletions.
2 changes: 1 addition & 1 deletion 3rdparty/cutlass
Submodule cutlass updated 1629 files
8 changes: 1 addition & 7 deletions gallery/how_to/work_with_relay/using_pipeline_executor.py
@@ -29,12 +29,7 @@
from tvm import relay
from tvm.relay import testing
import tvm.testing
from tvm.contrib.cutlass import (
has_cutlass,
num_cutlass_partitions,
finalize_modules,
finalize_modules_vm,
)
from tvm.contrib.cutlass import finalize_modules

img_size = 8
#######################################################################
@@ -50,7 +45,6 @@ def get_network():
"dweight", relay.TensorType((batch_size, 16 * img_size * img_size), "float16")
)
weight = relay.var("weight")
second_weight = relay.var("second_weight")
bn_gamma = relay.var("bn_gamma")
bn_beta = relay.var("bn_beta")
bn_mmean = relay.var("bn_mean")
20 changes: 20 additions & 0 deletions python/tvm/contrib/cutlass/_ffi_api.py
@@ -0,0 +1,20 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""FFI API for CUTLASS BYOC."""
import tvm._ffi

tvm._ffi._init_api("contrib.cutlass", __name__)
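
For illustration only (not part of this diff): `_init_api("contrib.cutlass", __name__)` scans the global PackedFunc registry and attaches every function registered under the "contrib.cutlass." prefix as an attribute of this module. A minimal sketch, assuming a fresh process and a made-up registration named "contrib.cutlass.demo_echo":

import tvm

# Hypothetical registration for demonstration only; real entries under this prefix
# come from the C++ CUTLASS codegen and from build.py (e.g. "contrib.cutlass.compile").
tvm._ffi.register_func("contrib.cutlass.demo_echo", lambda x: x)

# Importing the module runs _init_api, which binds the prefixed globals as attributes.
from tvm.contrib.cutlass import _ffi_api

print(_ffi_api.demo_echo("hello"))  # -> "hello"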
65 changes: 48 additions & 17 deletions python/tvm/contrib/cutlass/build.py
@@ -14,17 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, dangerous-default-value
# pylint: disable=invalid-name, dangerous-default-value, arguments-differ
"""Driver for partitioning and building a Relay module for CUTLASS offload."""
import logging
import os
import multiprocessing
import tvm
from tvm import runtime, relay
from tvm.contrib.nvcc import get_cuda_version
from tvm import relay, runtime
from tvm._ffi.registry import register_func
from .gen_gemm import CutlassGemmProfiler
from tvm.contrib.nvcc import get_cuda_version

from .gen_conv2d import CutlassConv2DProfiler
from .gen_gemm import CutlassGemmProfiler
from .library import ConvKind

logger = logging.getLogger("cutlass")
@@ -93,7 +94,11 @@ def visit_call(self, call):
self.signature["ret_dtype"] = op.ret_type.dtype
self.visit(op.body)

if str(op) in ["nn.conv2d", "nn.conv2d_transpose", "nn.conv2d_backward_weight"]:
elif isinstance(op, tvm.ir.Op) and op.name in [
"nn.conv2d",
"nn.conv2d_transpose",
"nn.conv2d_backward_weight",
]:
self.op_attrs = call.attrs

for arg in call.args:
@@ -516,6 +521,42 @@ def tune_cutlass_function(
)


@register_func("contrib.cutlass.compile")
def compile_cutlass_module(c_source_module, options):
"""Compile all CUTLASS kernels in the given C-source module.
Parameters
----------
c_source_module: runtime.Module
A C-source module containing CUTLASS kernels.
options: dict
Compilation options. Currently recognizes
"sm": The target architecture (compute capability), for example 75 or 80 (default: 80)
"threads": The number of threads to use in NVCC parallel compilation (default:
use all logical cores)
"use_fast_math": Whether or not to use faster but approximate arithmetic in some
CUTLASS epilogues (default: False)
Returns
-------
rt_mod : runtime.Module
A runtime module where all cutlass kernels have been compiled.
"""
tmp_dir = options.get("tmp_dir", "./tmp")
defaults = {"sm": 80, "threads": -1, "use_fast_math": False}
compile_config = {key: options.get(key, val) for key, val in defaults.items()}

function_names = c_source_module.get_function("get_func_names")()
compile_options = _get_cutlass_compile_options(**compile_config)
lib_path = os.path.join(tmp_dir, "cutlass.o")
logger.info("Compiling generated CUTLASS code")
c_source_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)

# Recover static library
return tvm.runtime.load_static_library(lib_path, function_names)


@register_func("relay.ext.cutlass.compile_for_cutlass")
def compile_for_cutlass(mod, cutlass_target):
"""Given an IRModule with at least one Compiler='cutlass' Relay function, return a
@@ -549,6 +590,7 @@ def compile_for_cutlass(mod, cutlass_target):
key: cutlass_target.attrs.get(key) for key in ["sm", "threads", "use_fast_math"]
}
tmp_dir = cutlass_target.attrs.get("tmp_dir")
compile_config["tmp_dir"] = tmp_dir

# Tune
logger.info("Tuning for CUTLASS")
@@ -558,18 +600,7 @@
logger.info("Creating CSource module for CUTLASS")
create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
c_module = create_c_source_module(mod)
function_names = c_module.get_function("get_func_names")()
compile_options = _get_cutlass_compile_options(**compile_config)
lib_path = os.path.join(tmp_dir, "cutlass.o")
logger.info("Compiling generated CUTLASS code")
c_module.export_library(lib_path, workspace_dir=tmp_dir, **compile_options)

# Recover static library
logger.info("Loading compiled CUTLASS code")
final_mod = tvm.runtime.load_static_library(lib_path, function_names)

logger.info("Done with CUTLASS compilation")
return final_mod
return compile_cutlass_module(c_module, compile_config)


def finalize_modules(lib, lib_path="compile.so", tmp_dir="./tmp"):
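For illustration (not part of this diff), a minimal sketch of driving the new `compile_cutlass_module` entry point directly, assuming `mod` is a Relay IRModule that has already been partitioned for CUTLASS (e.g. with `partition_for_cutlass`) and tuned, so a C-source module can be generated from it:

import tvm
from tvm.contrib.cutlass.build import compile_cutlass_module

# `mod` is assumed to exist already: a tuned IRModule containing Compiler="cutlass" functions.
create_c_source_module = tvm._ffi.get_global_func("relay.ext.cutlass.create_c_source_module")
c_module = create_c_source_module(mod)

options = {
    "sm": 80,                # target compute capability
    "threads": -1,           # -1 lets nvcc parallel compilation use all logical cores
    "use_fast_math": False,  # approximate arithmetic in some CUTLASS epilogues
    "tmp_dir": "./tmp",      # workspace; cutlass.o is written here and reloaded
}
rt_mod = compile_cutlass_module(c_module, options)  # runtime.Module backed by a static library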
186 changes: 186 additions & 0 deletions python/tvm/contrib/cutlass/conv2d_operation.py
@@ -352,3 +352,189 @@ def emit(
template = substitute_template(gemm_template, {"epilogue": self.epilogue_default})

return substitute_template(template, values)


def instantiate_conv2d_template(attrs, func_args):
"""Return CUTLASS host code for conv2d based on a template and the provided attribute map."""
template = """
${cutlass_op_def}
using Conv2d = cutlass::conv::device::ImplicitGemmConvolution<${cutlass_op_name}>;
using ElementInputA = Conv2d::ElementA;
using ElementInputB = Conv2d::ElementB;
using ElementComputeEpilogue = Conv2d::ElementAccumulator;
int N = ${N};
int H = ${H};
int W = ${W};
int C = ${C};
int K = ${K};
int R = ${R};
int S = ${S};
int P = ${P};
int Q = ${Q};
int pad_h = ${pad_h};
int pad_w = ${pad_w};
int stride_h = ${stride_h};
int stride_w = ${stride_w};
int dilation_h = ${dilation_h};
int dilation_w = ${dilation_w};
int split_k_slices = ${split_k_slices};
cutlass::conv::Conv2dProblemSize problem_size(N, H, W, C, K, R, S, P, Q, pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, cutlass::conv::Mode::kCrossCorrelation, split_k_slices);
const cutlass::conv::SplitKMode split_k_mode = cutlass::conv::SplitKMode::${split_k_mode};
void* ptr_a = (void*)(${arg0}->data);
void* ptr_b = (void*)(${arg1}->data);
${bias_decl}
${residual_decl}
void* ptr_out = (void*)(out0->data);
ElementComputeEpilogue alpha = ElementComputeEpilogue(1);
ElementComputeEpilogue beta = ElementComputeEpilogue(${beta});
using cutlass::layout::TensorNHWC;
auto activation_shape = TensorNHWC::packed(cutlass::make_Coord(N, H, W, C));
auto weight_shape = TensorNHWC::packed(cutlass::make_Coord(K, R, S, C));
auto output_shape = TensorNHWC::packed(cutlass::make_Coord(N, P, Q, K));
TensorNHWC layout_A(${A_shape});
TensorNHWC layout_B(${B_shape});
TensorNHWC layout_C(${C_shape});
TensorNHWC layout_D(${C_shape});
using ElementOutput = ${ElementOutput};
cutlass::TensorRef<ElementOutput, TensorNHWC> tensor_c{static_cast<ElementOutput*>(${tensor_c}), ${tensor_c_layout}};
cutlass::TensorRef<ElementOutput, TensorNHWC> tensor_d{static_cast<ElementOutput*>(ptr_out), layout_D};
typename Conv2d::Arguments arguments{
problem_size,
{static_cast<ElementInputA*>(ptr_a), layout_A},
{static_cast<ElementInputB*>(ptr_b), layout_B},
${tensor_c_arg},
${tensor_d_arg},
{${alpha_beta}},
split_k_mode
${additional_args}
};
Conv2d conv2d_op;
size_t workspace_size = conv2d_op.get_workspace_size(arguments);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
cutlass::Status status = conv2d_op.can_implement(arguments);
CHECK(status == cutlass::Status::kSuccess);
${split_k_reset}
status = conv2d_op.initialize(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
${split_k_update}
status = conv2d_op();
CHECK(status == cutlass::Status::kSuccess);
${split_k_reduction}
"""

split_k_reset = """
arguments.ref_D.reset(reinterpret_cast<ElementComputeEpilogue*>(workspace.get()), layout_D);
"""

split_k_update = """
arguments.output_op = {ElementComputeEpilogue(1), ElementComputeEpilogue(0)};
status = conv2d_op.update(arguments, workspace.get());
CHECK(status == cutlass::Status::kSuccess);
"""

split_k_reduction = """
ReductionDevice reduction_op;
const static cutlass::conv::Operator kConvolutionalOperator = Conv2d::kConvolutionalOperator;
typename ReductionDevice::Arguments reduction_args(
cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, problem_size).mn(),
problem_size.split_k_slices,
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<Conv2d::ElementAccumulator*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_d.data(),
ReductionStrideIndex(tensor_d.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_c.data(),
ReductionStrideIndex(tensor_c.stride()[Conv2d::UnderlyingKernel::kTensorCStrideIdx])
},
{alpha, beta}
);
status = reduction_op.initialize(reduction_args, nullptr);
status = reduction_op();
"""
op_type = attrs["op_type"]
has_bias = "bias" in op_type
use_split_k = "splitk" in attrs["cutlass_op_name"]
is_wgrad = "backward_weight" in op_type
is_dgrad = "conv2d_transpose" in op_type
has_residual_block = "residual" in op_type
no_bias_scaling = op_type not in [
"cutlass.conv2d_bias_sigmoid",
"cutlass.conv2d_bias_silu",
"cutlass.conv2d_bias_hardswish",
]

aux_map = {}

if (not has_bias or no_bias_scaling) and not has_residual_block:
aux_map["beta"] = "0"
else:
aux_map["beta"] = "1"

if has_residual_block:
aux_map["bias_decl"] = "void* ptr_bias = (void*)(${arg2}->data);\n"
aux_map["residual_decl"] = "void* ptr_residual = (void*)(${arg3}->data);"
aux_map["tensor_c"] = "ptr_residual"
aux_map["tensor_c_layout"] = "layout_C"
elif has_bias:
aux_map["bias_decl"] = "void* ptr_c_bias = (void*)(${arg2}->data);\n"
aux_map["residual_decl"] = ""
aux_map["tensor_c"] = "ptr_c_bias"
aux_map["tensor_c_layout"] = "cutlass::layout::TensorNHWC::Stride(0)"
else:
aux_map["bias_decl"] = ""
aux_map["residual_decl"] = ""
aux_map["tensor_c"] = "ptr_out"
aux_map["tensor_c_layout"] = "layout_C"

if has_bias and no_bias_scaling and not has_residual_block:
aux_map["alpha_beta"] = "alpha"
else:
aux_map["alpha_beta"] = "alpha, beta"

if has_residual_block:
aux_map["additional_args"] = ", static_cast<ElementOutput*>(ptr_bias), nullptr, 0, K"
else:
aux_map["additional_args"] = ""

if is_wgrad:
aux_map["A_shape"] = "output_shape"
aux_map["B_shape"] = "activation_shape"
aux_map["C_shape"] = "weight_shape"
elif is_dgrad:
aux_map["A_shape"] = "output_shape"
aux_map["B_shape"] = "weight_shape"
aux_map["C_shape"] = "activation_shape"
else:
aux_map["A_shape"] = "activation_shape"
aux_map["B_shape"] = "weight_shape"
aux_map["C_shape"] = "output_shape"

if use_split_k:
aux_map["ElementOutput"] = "EpilogueOutputOp::ElementOutput"
aux_map["tensor_c_arg"] = "{nullptr, TensorNHWC()}"
aux_map["tensor_d_arg"] = "{nullptr, TensorNHWC()}"
aux_map["split_k_reset"] = split_k_reset
aux_map["split_k_update"] = split_k_update
aux_map["split_k_reduction"] = split_k_reduction
else:
aux_map["ElementOutput"] = "Conv2d::ElementC"
aux_map["tensor_c_arg"] = "tensor_c"
aux_map["tensor_d_arg"] = "tensor_d"
aux_map["split_k_reset"] = aux_map["split_k_update"] = aux_map["split_k_reduction"] = ""

template = substitute_template(template, aux_map)

for i, arg in enumerate(func_args):
attrs["arg{}".format(i)] = arg

return substitute_template(template, attrs)
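
For illustration (not part of this diff), a minimal sketch of calling `instantiate_conv2d_template` with a hand-written attribute map; every value below is a made-up example, since the real `attrs` and `func_args` are assembled by the CUTLASS codegen and profiler:

from tvm.contrib.cutlass.conv2d_operation import instantiate_conv2d_template

# All values are illustrative; "cutlass_op_def" would normally hold the emitted kernel
# definition and "cutlass_op_name" the corresponding kernel type name.
attrs = {
    "op_type": "cutlass.conv2d_bias_relu",
    "cutlass_op_name": "cutlass_example_fprop",
    "cutlass_op_def": "/* kernel definition emitted by the conv2d operation emitter */",
    "N": "1", "H": "56", "W": "56", "C": "64",
    "K": "64", "R": "3", "S": "3", "P": "56", "Q": "56",
    "pad_h": "1", "pad_w": "1", "stride_h": "1", "stride_w": "1",
    "dilation_h": "1", "dilation_w": "1",
    "split_k_slices": "1", "split_k_mode": "kSerial",
}
func_args = ["arg0", "arg1", "arg2"]  # activation, weight, and bias buffer names
host_code = instantiate_conv2d_template(attrs, func_args)
print(host_code)  # C++ body that builds the Conv2dProblemSize and runs the kernel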
6 changes: 3 additions & 3 deletions python/tvm/contrib/cutlass/conv2d_profiler.py
@@ -35,15 +35,15 @@ def __init__(self):
cutlass::conv::implicit_gemm_tensor_c_size(kConvolutionalOperator, problem_size),
{
reinterpret_cast<ImplicitGemm::ElementC*> (workspace.get()),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_d.device_data(),
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
{
tensor_c.device_data(),
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx])
ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx])
},
{ElementComputeEpilogue(1), ElementComputeEpilogue(0)}
);