[Unity] Cutlass attention with dynamic sequence length #15028

Merged · 2 commits · Jun 6, 2023
17 changes: 16 additions & 1 deletion python/tvm/contrib/cutlass/attention_operation.py
@@ -96,9 +96,20 @@ def instantiate_attention_template(attrs):
typename Attention::Params p;
p.logsumexp_ptr = nullptr;
p.output_ptr = reinterpret_cast<T *>(out0->data);

p.output_accum_ptr = nullptr;
uint64_t accumulator_buf_size = ${output_size} * sizeof(Attention::output_accum_t);
bool accumulator_buf_allocated = false;
if (Attention::kNeedsOutputAccumulatorBuffer) {
p.output_accum_ptr = static_cast<float*>(${workspace}->data);
if (accumulator_buf_size <= ${workspace}->shape[0]) {
p.output_accum_ptr = static_cast<float*>(${workspace}->data);
} else {
accumulator_buf_allocated = true;
cudaMalloc(
&p.output_accum_ptr,
accumulator_buf_size
);
}
}

p.num_heads = ${num_heads}; // N
@@ -129,6 +140,10 @@ def instantiate_attention_template(attrs):

CHECK(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);

if (accumulator_buf_allocated) {
cudaFree(p.output_accum_ptr);
}
"""

template = substitute_template(
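Note that in this hunk `${output_size}` is no longer guaranteed to be a compile-time integer: with a dynamic sequence length it is substituted as a C expression that is only evaluated at kernel launch. Below is a rough, hypothetical illustration of that substitution using Python's `string.Template` (the codegen uses its own `substitute_template` helper); the "inp0" name and axis index are placeholders, not taken from this PR.

```python
from string import Template

# Hypothetical stand-in for the generated line that sizes the accumulator buffer.
line = Template(
    "uint64_t accumulator_buf_size = ${output_size} * sizeof(Attention::output_accum_t);"
)

# Static shapes: the Python side can fold the product into a plain integer.
print(line.substitute(output_size=32 * 8 * 16 * 8))

# Dynamic sequence length: the product is emitted as C code that reads the
# runtime shape of an argument, so the comparison against the workspace size
# (and the cudaMalloc fallback above) happens at kernel launch time.
print(line.substitute(output_size="32 * inp0->shape[1] * 16 * 8"))
```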
37 changes: 20 additions & 17 deletions python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -679,10 +679,27 @@ def get_batch_on_arg(arg_name, arg_shape):
elif "attention" in func_name:
headers.append("kernel_forward.h")
data_type = dtype_map[annotations["arg0_dtype"]]

attrs["qkv_layout"] = annotations["qkv_layout"]
if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
attrs["num_queries"] = s = get_dim(annotations["num_queries"], func_args[0], 1)
attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1], 1)
if len(func_args) > 4: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
if len(func_args) > 5: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()

attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["num_queries"] = s = annotations["num_queries"]
attrs["num_keys"] = annotations["num_keys"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
@@ -701,7 +718,7 @@ def get_batch_on_arg(arg_name, arg_shape):
attrs["kQueriesPerBlock"] = 64
attrs["kKeysPerBlock"] = 64
attrs["kSingleValueIteration"] = True
attrs["output_size"] = b * s * n * h_v
attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
attrs["scale"] = (
float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
)
@@ -712,24 +729,10 @@ def get_batch_on_arg(arg_name, arg_shape):
), "Cutlass may generate nan occasionally when scale == 0.0"
attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
attrs["kSupportsDropout"] = False
attrs["qkv_layout"] = annotations["qkv_layout"]

for arg in func_args:
if "workspace" in arg:
attrs["workspace"] = arg

if attrs["qkv_layout"] == "default":
attrs["query"] = func_args[0]
attrs["key"] = func_args[1]
attrs["value"] = func_args[2]
if len(func_args) > 4: # +1 for workspace, the last arg
attrs["bias"] = func_args[3]
elif attrs["qkv_layout"] == "qkv_stacked":
attrs["qkv"] = func_args[0]
if len(func_args) > 5: # +1 for workspace, the last arg
attrs["bias"] = func_args[4]
else:
raise NotImplementedError()
if "bias" in attrs:
attrs["kSupportsBias"] = True
if len(annotations["bias_shape"]) == 4:
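`get_dim`, used above for `num_queries` and `num_keys`, is not defined in this diff. The following is only a sketch of the kind of helper it could be (the signature and behavior here are assumptions): pass static dimensions through unchanged, and for symbolic dimensions emit a C expression that reads the size from the argument at runtime, which is what makes the f-string form of `output_size` above meaningful.

```python
from tvm import tir

def get_dim(annotated_dim, arg_name, axis):
    """Sketch only: static dims pass through; symbolic dims become C shape reads."""
    if isinstance(annotated_dim, (int, tir.IntImm)):
        return annotated_dim
    return f"{arg_name}->shape[{axis}]"

# Example: a symbolic num_queries turns into e.g. "inp0->shape[1]" in the
# generated kernel source ("inp0" is an illustrative argument name).
print(get_dim(tir.Var("s", "int64"), "inp0", 1))
```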
4 changes: 4 additions & 0 deletions python/tvm/relax/backend/contrib/cutlass.py
@@ -412,6 +412,10 @@ def visit_function_(self, f):
out_size_1d = _shape_1d(f.ret_struct_info.shape)
# This needs to be in sync with the actual value that the kernel expects.
workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)):
# Temporary workaround for dynamic shape workload. Will be removed when
# workspace for dynamic shape workload is implemented.
workspace_size_bytes = 8
Member:

How hard would it be to support workspace alloc for dynamic attention?

Contributor Author:

The approach I have in mind is to make the max_workspace_size_ of WorkspaceProvider a PrimExpr instead of an IntImm, consisting of a series of max expressions that compute the actual maximum workspace size from the symbolic shape variables.

I tried doing this in the most straightforward way and got "This IR is not well formed: Symbolic Var a presents in different functions in the same Module.". Apparently it needs more sophisticated analysis to replace the shape variables with valid ones as the workspace size is propagated from a composite function to its caller. Also, I am not sure if the arg of alloc_tensor can be a non-constant expr too.
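For illustration, a minimal sketch of the idea described above, not what this PR ended up doing: keep the provider's requirement as a PrimExpr and fold each function's workspace need into a running tvm.tir.max, so the total stays symbolic in the shape variables.

```python
from tvm import tir

def combine_workspace_sizes(byte_counts):
    """byte_counts: per-function workspace requirements (ints or PrimExprs)."""
    total = tir.IntImm("int64", 0)
    for size in byte_counts:
        total = tir.max(total, size)
    return total

s = tir.Var("s", "int64")  # dynamic sequence length
requirements = [32 * s * 16 * 8 * 2, 4096]  # one symbolic, one static
# Prints a nested max expression over the symbolic and static requirements.
print(combine_workspace_sizes(requirements))
```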

Contributor:

Just out of curiosity, what is the rationale behind workspace_size_bytes = 8?

> The approach I have in mind is to make the max_workspace_size_ of WorkspaceProvider a PrimExpr instead of an IntImm, consisting of a series of max expressions that compute the actual maximum workspace size from the symbolic shape variables.

Would you elaborate a little more? Did you try a more sophisticated approach than using an annotation like R.func_attr({"tir_var_upper_bound": {"m": 2048, "n": 2048}})?

> I am not sure if the arg of alloc_tensor can be a non-constant expr too.

alloc_tensor can take a symbolic shape. See https://github.com/apache/tvm/blob/unity/tests/python/relax/test_transform_static_plan_block_memory.py#L958
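As a rough sketch of how an upper-bound annotation could feed into a static workspace size (an assumption for illustration, not code from TVM or this PR): substitute each symbolic var with its annotated bound and simplify the resulting byte count.

```python
import tvm
from tvm import tir

def upper_bound_bytes(shape, dtype_bytes, upper_bounds):
    """shape: mix of ints and tir.Var; upper_bounds: {var name: int bound}."""
    total = tir.IntImm("int64", dtype_bytes)
    for dim in shape:
        if isinstance(dim, tir.Var):
            dim = tir.IntImm("int64", upper_bounds[dim.name])
        total = total * dim
    return tvm.arith.Analyzer().simplify(total)

s = tir.Var("s", "int64")
# B=32, S bounded by 2048, N=16, H_v=8, float16 output -> a concrete byte count.
print(upper_bound_bytes((32, s, 16, 8), 2, {"s": 2048}))
```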

Member (@masahi, Jun 6, 2023):

> workspace_size_bytes = 8

It is just a temporary workaround; it will cause cudaMalloc to always be called.

return f.with_attr("WorkspaceSize", workspace_size_bytes)

return f
74 changes: 57 additions & 17 deletions tests/python/relax/test_codegen_cutlass.py
@@ -25,9 +25,9 @@
from tvm.contrib.pickle_memoize import memoize
from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
from tvm.relax.testing import get_relax_matmul_module
from tvm.script import tir as T
from tvm.script import ir as I
from tvm.script import relax as R
from tvm.script import tir as T
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder

@@ -169,9 +169,16 @@ def get_relax_conv2d_module(
return tvm.IRModule({"main": func})


def _to_concrete_shape(symbolic_shape, var_table):
def _to_concrete_shape(symbolic_shape, var_table=None):
if var_table is None:
var_table = {}

result = []
for dim in symbolic_shape:
if isinstance(dim, tuple):
result.append(_to_concrete_shape(dim, var_table))
continue

if not isinstance(dim, tvm.tir.expr.Var):
result.append(dim)
continue
@@ -543,6 +550,7 @@ def attention_dtype(request):
@pytest.fixture(
params=[
# B, S, N, H
(32, (_vars["a"], 8), 16, (8, 8)),
(32, (8, 8), 16, (8, 8)),
(4, (16, 8), 32, (8, 8)), # s != s_kv
(4, (16, 8), 32, (8, 16)), # h != h_v
@@ -554,9 +562,9 @@ def attention_size(request):
return request.param


def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
dtype = str(q.dtype)

def get_relax_attention_module(
q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None, causal_mask=None
):
from tvm.script.ir_builder import IRBuilder
from tvm.script.ir_builder import relax as relax_builder
from tvm.script.ir_builder import tir as T
@@ -567,13 +575,15 @@ def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
with IRBuilder() as builder:
with relax_builder.function():
R.func_name("main")
q = R.arg("q", R.Tensor(q.shape, dtype))
k = R.arg("k", R.Tensor(k.shape, dtype))
v = R.arg("v", R.Tensor(v.shape, dtype))
if bias is not None:
bias = R.arg("bias", R.Tensor(bias.shape, dtype))
q = R.arg("q", R.Tensor(q_shape, dtype))
k = R.arg("k", R.Tensor(k_shape, dtype))
v = R.arg("v", R.Tensor(v_shape, dtype))
bias = None
if bias_shape is not None and bias_shape != "none":
bias = R.arg("bias", R.Tensor(bias_shape, dtype))

with R.dataflow() as frame:
result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal))
result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask))
R.output(result)

R.func_ret_value(frame.output_vars[0])
@@ -620,11 +630,16 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal,

def test_attention_offload(attention_size, attention_dtype):
b, (s, s_kv), n, (h, h_v) = attention_size
concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
q, k, v, _, ref = get_numpy_attention_ref(
b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
b, concrete_s, concrete_s_kv, n, h, h_v, "none", "none", "none", attention_dtype
)

mod = get_relax_attention_module(q, k, v)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -649,11 +664,19 @@ def attention_bias_size(request):

def test_attention_bias_offload(attention_bias_size):
b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s, s_kv, bias_shape))

q, k, v, bias, ref = get_numpy_attention_ref(
b, s, s_kv, n, h, h_v, bias_shape, "none", "none", "float32"
b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none", "none", "float32"
)

mod = get_relax_attention_module(q, k, v, bias)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
)
out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)

tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -681,7 +704,13 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
)

mod = get_relax_attention_module(q, k, v, bias, attention_scale)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else:
@@ -712,7 +741,18 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
)

mod = get_relax_attention_module(q, k, v, bias, None, attention_causal)
q_shape = (b, s, n, h)
k_shape = (b, s_kv, n, h)
v_shape = (b, s_kv, n, h_v)

mod = get_relax_attention_module(
q_shape,
k_shape,
v_shape,
dtype="float32",
bias_shape=bias_shape,
causal_mask=attention_causal,
)
if bias is None:
out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
else: