Cherry-pick key PRs from unity to support LLaMA/Vicuna optimization #254

Closed
wants to merge 4 commits into from
17 changes: 16 additions & 1 deletion python/tvm/contrib/cutlass/attention_operation.py
@@ -96,9 +96,20 @@ def instantiate_attention_template(attrs):
typename Attention::Params p;
p.logsumexp_ptr = nullptr;
p.output_ptr = reinterpret_cast<T *>(out0->data);

p.output_accum_ptr = nullptr;
uint64_t accumulator_buf_size = ${output_size} * sizeof(Attention::output_accum_t);
bool accumulator_buf_allocated = false;
if (Attention::kNeedsOutputAccumulatorBuffer) {
p.output_accum_ptr = static_cast<float*>(${workspace}->data);
if (accumulator_buf_size <= ${workspace}->shape[0]) {
p.output_accum_ptr = static_cast<float*>(${workspace}->data);
} else {
accumulator_buf_allocated = true;
cudaMalloc(
&p.output_accum_ptr,
accumulator_buf_size
);
}
}

p.num_heads = ${num_heads}; // N
@@ -129,6 +140,10 @@ def instantiate_attention_template(attrs):

CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);

if (accumulator_buf_allocated) {
cudaFree(p.output_accum_ptr);
}
"""

template = substitute_template(
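The hunk above changes how the attention kernel obtains its output accumulator: it reuses the pre-allocated workspace tensor when that tensor is large enough, and otherwise falls back to a temporary cudaMalloc that is freed right after the kernel launch. The short Python sketch below (ad-hoc names, not TVM APIs) mirrors that run-time decision and shows why the 8-byte placeholder workspace added later in python/tvm/relax/backend/contrib/cutlass.py always takes the fallback path for dynamic shapes.

# Illustration only; the names below are ad hoc, not TVM APIs.
def pick_accum_buffer(output_size_elems, workspace_bytes, elem_bytes=4):
    """Mirror the generated kernel's choice of accumulator storage."""
    needed = output_size_elems * elem_bytes  # output_accum_t is float
    if needed <= workspace_bytes:
        return "workspace", needed   # reuse the workspace tensor passed as the last argument
    return "cudaMalloc", needed      # temporary buffer, freed after the kernel launch

# With a dynamic-shape workload the backend only reserves an 8-byte placeholder
# workspace, so the cudaMalloc branch is taken:
assert pick_accum_buffer(32 * 128 * 16 * 8, 8)[0] == "cudaMalloc"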
1 change: 1 addition & 0 deletions python/tvm/contrib/cutlass/build.py
@@ -68,6 +68,7 @@ def _get_cutlass_compile_options(sm, threads, use_fast_math=False):
"-Xcompiler=-fPIC",
"-Xcompiler=-Wconversion",
"-Xcompiler=-fno-strict-aliasing",
"-Xcompiler=-fvisibility=hidden",
"-O3",
"-std=c++17",
f"-I{cutlass_include}",
37 changes: 20 additions & 17 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -679,10 +679,27 @@ def get_batch_on_arg(arg_name, arg_shape):
elif "attention" in func_name:
headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]

attrs["qkv_layout"] = annotations["qkv_layout"]
if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
attrs["num_queries"] = s = get_dim(annotations["num_queries"], func_args[0], 1)
attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1], 1)
if len(func_args) > 4: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
if len(func_args) > 5: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()

attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
@@ -701,7 +718,7 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = b * s * n * h_v
attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
@@ -712,24 +729,10 @@ def get_batch_on_arg(arg_name, arg_shape):
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
attrs["qkv_layout"] = annotations["qkv_layout"]

for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg

if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
if len(func_args) > 4: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
if len(func_args) > 5: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
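Two details of this hunk support symbolic sequence lengths. First, num_queries/num_keys for the default layout now go through get_dim, which can fall back to reading the dimension from the argument tensor at run time. Second, output_size is emitted as a string expression instead of a precomputed integer, so the product is evaluated inside the generated CUDA source rather than at codegen time. Below is a minimal sketch of the second point, using Python's string.Template as a stand-in for substitute_template (an assumption based on the ${...} placeholder syntax):

from string import Template

cuda_line = Template(
    "uint64_t accumulator_buf_size = ${output_size} * sizeof(Attention::output_accum_t);"
)

# Static shapes: every factor is a literal, so the product could also be folded in Python.
print(cuda_line.substitute(output_size="32 * 128 * 16 * 8"))

# Dynamic shapes: one factor is only known at run time (e.g. read from the query
# tensor's shape inside the generated C++), so it cannot be folded at codegen time.
print(cuda_line.substitute(output_size="32 * q->shape[1] * 16 * 8"))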
4 changes: 4 additions & 0 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -411,6 +411,10 @@ def visit_function_(self, f):
out_size_1d = _shape_1d(f.ret_struct_info.shape)
# This needs to be in sync with the actual value that the kernel expects.
workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)):
# Temporary workaround for dynamic shape workload. Will be removed when
# workspace for dynamic shape workload is implemented.
workspace_size_bytes = 8
return f.with_attr("WorkspaceSize", workspace_size_bytes)

return f
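The new isinstance guard exists because, with a symbolic output shape, out_size_1d is a TIR expression rather than a Python int, so the computed byte count cannot serve as a static "WorkspaceSize" attribute; the pass temporarily records 8 bytes and relies on the cudaMalloc fallback in the attention template. A small sketch of the situation, assuming a symbolic sequence length s:

import tvm
from tvm import tir

s = tir.Var("s", "int64")
out_size_1d = 32 * s * 16 * 8            # a TIR expression once a symbolic dim is involved
workspace_size_bytes = out_size_1d * 2   # float16 output -> still a TIR expression, not an int
print(isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)))  # False -> fall back to 8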
16 changes: 11 additions & 5 deletions src/relax/transform/allocate_workspace.cc
@@ -125,11 +125,17 @@ class WorkspaceProvider : ExprMutator {
builder_->GetContextIRModule()->Remove(GetRef<GlobalVar>(gvar));
}

auto gvar = mod_->GetGlobalVar("main");
auto func = Downcast<Function>(mod_->Lookup(gvar));
auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->attrs);
builder_->UpdateFunction(gvar, new_func);
for (const auto& [gvar, f] : mod_->functions) {
workspace_var_main_ = Var();
if (!f->IsInstance<relax::FunctionNode>() || f->GetAttr<String>(attr::kCodegen) ||
f->GetAttr<String>(attr::kComposite)) {
continue;
}
auto func = Downcast<Function>(mod_->Lookup(gvar));
auto new_func = Function(func->params, VisitExpr(func->body), func->ret_struct_info,
func->is_pure, func->attrs);
builder_->UpdateFunction(gvar, new_func);
}
return builder_->GetContextIRModule();
}

9 changes: 3 additions & 6 deletions src/runtime/contrib/cublas/cublas_json_runtime.cc
@@ -52,9 +52,7 @@ class CublasJSONRuntime : public JSONRuntimeBase {
const char* type_key() const override { return "cublas_json"; } // May be overridden

void Run() override {
// TODO(masahi): Reuse the same handle across different subgraphs
cublasLtHandle_t handle;
cublasLtCreate(&handle);
auto* entry_ptr = tvm::contrib::CuBlasLtThreadEntry::ThreadLocal();

for (size_t i = 0; i < nodes_.size(); ++i) {
const auto& node = nodes_[i];
@@ -88,11 +86,10 @@ class CublasJSONRuntime : public JSONRuntimeBase {

auto [a_ptr, b_ptr, bias_ptr] = get_inputs(node, epilogue != CUBLASLT_EPILOGUE_DEFAULT);

tvm::contrib::CallCublasLt(handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa, transb,
epilogue);
tvm::contrib::CallCublasLt(entry_ptr->handle, a_ptr, b_ptr, bias_ptr, out_ptr, transa,
transb, epilogue);
}
}
cublasLtDestroy(handle);
}

private:
13 changes: 13 additions & 0 deletions src/runtime/contrib/cublas/cublas_utils.cc
@@ -48,5 +48,18 @@ CuBlasThreadEntry* CuBlasThreadEntry::ThreadLocal() {
return retval;
}

CuBlasLtThreadEntry::CuBlasLtThreadEntry() { CHECK_CUBLAS_ERROR(cublasLtCreate(&handle)); }

CuBlasLtThreadEntry::~CuBlasLtThreadEntry() {
if (handle) {
cublasLtDestroy(handle);
handle = nullptr;
}
}

typedef dmlc::ThreadLocalStore<CuBlasLtThreadEntry> CuBlasLtThreadStore;

CuBlasLtThreadEntry* CuBlasLtThreadEntry::ThreadLocal() { return CuBlasLtThreadStore::Get(); }

} // namespace contrib
} // namespace tvm
7 changes: 7 additions & 0 deletions src/runtime/contrib/cublas/cublas_utils.h
@@ -77,6 +77,13 @@ struct CuBlasThreadEntry {
static CuBlasThreadEntry* ThreadLocal();
}; // CuBlasThreadEntry

struct CuBlasLtThreadEntry {
CuBlasLtThreadEntry();
~CuBlasLtThreadEntry();
cublasLtHandle_t handle{nullptr};
static CuBlasLtThreadEntry* ThreadLocal();
}; // CuBlasLtThreadEntry

inline cudaDataType_t GetCudaDataType(DLDataType type) {
if (type.code == kDLInt) {
switch (type.bits) {
74 changes: 57 additions & 17 deletions tests/python/relax/test_codegen_cutlass.py
@@ -25,9 +25,9 @@
from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import tir as T
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder

@@ -169,9 +169,16 @@ def get_relax_conv2d_module(
return tvm.IRModule({"main": func})


def _to_concrete_shape(symbolic_shape, var_table):
def _to_concrete_shape(symbolic_shape, var_table=None):
if var_table is None:
var_table = {}

result = []
for dim in symbolic_shape:
if isinstance(dim, tuple):
result.append(_to_concrete_shape(dim, var_table))
continue

if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue
@@ -543,6 +550,7 @@ def attention_dtype(request):
@pytest.fixture(
params=[
# B, S, N, H
(32, (_vars["a"], 8), 16, (8, 8)),
(32, (8, 8), 16, (8, 8)),
(4, (16, 8), 32, (8, 8)), # s != s_kv
(4, (16, 8), 32, (8, 16)), # h != h_v
@@ -554,9 +562,9 @@ def attention_size(request):
return request.param


def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
dtype = str(q.dtype)

def get_relax_attention_module(
q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None, causal_mask=None
):
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
from tvm.script.ir_builder import tir as T
@@ -567,13 +575,15 @@ def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
q = R.arg("q", R.Tensor(q.shape, dtype))
k = R.arg("k", R.Tensor(k.shape, dtype))
v = R.arg("v", R.Tensor(v.shape, dtype))
if bias is not None:
bias = R.arg("bias", R.Tensor(bias.shape, dtype))
q = R.arg("q", R.Tensor(q_shape, dtype))
k = R.arg("k", R.Tensor(k_shape, dtype))
v = R.arg("v", R.Tensor(v_shape, dtype))
bias = None
if bias_shape is not None and bias_shape != "none":
bias = R.arg("bias", R.Tensor(bias_shape, dtype))

with R.dataflow() as frame:
result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal))
result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask))
R.output(result)

R.func_ret_value(frame.output_vars[0])
@@ -620,11 +630,16 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal,

def test_attention_offload(attention_size, attention_dtype):
b, (s, s_kv), n, (h, h_v) = attention_size
concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
q, k, v, _, ref = get_numpy_attention_ref(
b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
b, concrete_s, concrete_s_kv, n, h, h_v, "none", "none", "none", attention_dtype
)

mod = get_relax_attention_module(q, k, v)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -649,11 +664,19 @@ def attention_bias_size(request):

def test_attention_bias_offload(attention_bias_size):
b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s, s_kv, bias_shape))

q, k, v, bias, ref = get_numpy_attention_ref(
b, s, s_kv, n, h, h_v, bias_shape, "none", "none", "float32"
b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none", "none", "float32"
)

mod = get_relax_attention_module(q, k, v, bias)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
)
out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -681,7 +704,13 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
)

mod = get_relax_attention_module(q, k, v, bias, attention_scale)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else:
@@ -712,7 +741,18 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
)

mod = get_relax_attention_module(q, k, v, bias, None, attention_causal)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape,
k_shape,
v_shape,
dtype="float32",
bias_shape=bias_shape,
causal_mask=attention_causal,
)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else:
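The new (32, (_vars["a"], 8), 16, (8, 8)) case exercises the dynamic-sequence-length path end to end: the Relax module is built with the tir.Var still in the query shape, while _to_concrete_shape substitutes a concrete value so the numpy inputs and reference can be materialised. A minimal sketch of that split, assuming _vars["a"] is a tir.Var as elsewhere in this test file and picking 17 as an arbitrary concrete value:

import numpy as np
from tvm import tir

a = tir.Var("a", "int64")                          # stands in for _vars["a"]
b, (s, s_kv), n, (h, h_v) = 32, (a, 8), 16, (8, 8)

# The symbolic shape goes into the Relax module ...
q_shape = (b, s, n, h)

# ... while the numpy reference needs concrete numbers.
concrete_s = 17 if isinstance(s, tir.Var) else s
q_np = np.random.randn(b, concrete_s, n, h).astype("float16")
print(q_shape, q_np.shape)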