diff --git a/python/tvm/contrib/cutlass/attention_operation.py b/python/tvm/contrib/cutlass/attention_operation.py
index 8a96e70fe478..55c1ccd6163d 100644
--- a/python/tvm/contrib/cutlass/attention_operation.py
+++ b/python/tvm/contrib/cutlass/attention_operation.py
@@ -96,9 +96,20 @@ def instantiate_attention_template(attrs):

   typename Attention::Params p;
   p.logsumexp_ptr = nullptr;
   p.output_ptr = reinterpret_cast<T*>(out0->data);
+  p.output_accum_ptr = nullptr;
+  uint64_t accumulator_buf_size = ${output_size} * sizeof(Attention::output_accum_t);
+  bool accumulator_buf_allocated = false;
   if (Attention::kNeedsOutputAccumulatorBuffer) {
-    p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+    if (accumulator_buf_size <= ${workspace}->shape[0]) {
+      p.output_accum_ptr = static_cast<float*>(${workspace}->data);
+    } else {
+      accumulator_buf_allocated = true;
+      cudaMalloc(
+          &p.output_accum_ptr,
+          accumulator_buf_size
+      );
+    }
   }

   p.num_heads = ${num_heads}; // N
@@ -129,6 +140,10 @@ def instantiate_attention_template(attrs):

   CHECK(Attention::check_supported(p));
   kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes>>>(p);
+
+  if (accumulator_buf_allocated) {
+    cudaFree(p.output_accum_ptr);
+  }
 """

     template = substitute_template(
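Note on the template change above: with dynamic shapes, `${output_size}` now expands to a runtime C expression, so the accumulator buffer can no longer always be sized at compile time. The generated code reuses the preallocated workspace only when it is large enough and otherwise falls back to a per-call `cudaMalloc`/`cudaFree`. A minimal Python sketch of that decision logic; `plan_accum_buffer` is an invented name for illustration, not TVM API:

def plan_accum_buffer(required_bytes, workspace_bytes):
    """Mirror of the generated C++ above: reuse the workspace when it fits,
    otherwise fall back to a per-call cudaMalloc that must be cudaFree'd
    after the kernel launch."""
    if required_bytes <= workspace_bytes:
        return "workspace", False  # no allocation, nothing to free
    return "cudaMalloc", True      # allocated, freed after launch

# A statically shaped workload gets an exactly sized workspace, so no
# allocation happens; a dynamic workload gets the 8-byte placeholder
# workspace (see cutlass.py below) and always takes the fallback path.
assert plan_accum_buffer(1 << 20, 1 << 20) == ("workspace", False)
assert plan_accum_buffer(1 << 20, 8) == ("cudaMalloc", True)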
diff --git a/python/tvm/contrib/cutlass/gen_tensor_op.py b/python/tvm/contrib/cutlass/gen_tensor_op.py
index be3bb289cfaa..48e285998ad7 100644
--- a/python/tvm/contrib/cutlass/gen_tensor_op.py
+++ b/python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -679,10 +679,27 @@ def get_batch_on_arg(arg_name, arg_shape):
     elif "attention" in func_name:
         headers.append("kernel_forward.h")
         data_type = dtype_map[annotations["arg0_dtype"]]
+
+        attrs["qkv_layout"] = annotations["qkv_layout"]
+        if attrs["qkv_layout"] == "default":
+            attrs["query"] = func_args[0]
+            attrs["key"] = func_args[1]
+            attrs["value"] = func_args[2]
+            attrs["num_queries"] = s = get_dim(annotations["num_queries"], func_args[0], 1)
+            attrs["num_keys"] = get_dim(annotations["num_keys"], func_args[1], 1)
+            if len(func_args) > 4:  # +1 for workspace, the last arg
+                attrs["bias"] = func_args[3]
+        elif attrs["qkv_layout"] == "qkv_stacked":
+            attrs["qkv"] = func_args[0]
+            attrs["num_queries"] = s = annotations["num_queries"]
+            attrs["num_keys"] = annotations["num_keys"]
+            if len(func_args) > 5:  # +1 for workspace, the last arg
+                attrs["bias"] = func_args[4]
+        else:
+            raise NotImplementedError()
+
         attrs["data_type"] = DataTypeTag[data_type]
         attrs["num_batches"] = b = annotations["num_batches"]
-        attrs["num_queries"] = s = annotations["num_queries"]
-        attrs["num_keys"] = annotations["num_keys"]
         attrs["num_heads"] = n = annotations["num_heads"]
         attrs["head_dim"] = h = annotations["head_dim"]
         attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
@@ -701,7 +718,7 @@ def get_batch_on_arg(arg_name, arg_shape):
         attrs["kQueriesPerBlock"] = 64
         attrs["kKeysPerBlock"] = 64
         attrs["kSingleValueIteration"] = True
-        attrs["output_size"] = b * s * n * h_v
+        attrs["output_size"] = f"{b} * {s} * {n} * {h_v}"
         attrs["scale"] = (
             float(1 / math.sqrt(h.value)) if annotations["scale"] is None else annotations["scale"]
         )
@@ -712,24 +729,10 @@ def get_batch_on_arg(arg_name, arg_shape):
         ), "Cutlass may generate nan occasionally when scale == 0.0"
         attrs["arch"] = "cutlass::arch::Sm{}".format(annotations["arch"])
         attrs["kSupportsDropout"] = False
-        attrs["qkv_layout"] = annotations["qkv_layout"]
         for arg in func_args:
             if "workspace" in arg:
                 attrs["workspace"] = arg
-
-        if attrs["qkv_layout"] == "default":
-            attrs["query"] = func_args[0]
-            attrs["key"] = func_args[1]
-            attrs["value"] = func_args[2]
-            if len(func_args) > 4:  # +1 for workspace, the last arg
-                attrs["bias"] = func_args[3]
-        elif attrs["qkv_layout"] == "qkv_stacked":
-            attrs["qkv"] = func_args[0]
-            if len(func_args) > 5:  # +1 for workspace, the last arg
-                attrs["bias"] = func_args[4]
-        else:
-            raise NotImplementedError()

         if "bias" in attrs:
             attrs["kSupportsBias"] = True
             if len(annotations["bias_shape"]) == 4:
diff --git a/python/tvm/relax/backend/contrib/cutlass.py b/python/tvm/relax/backend/contrib/cutlass.py
index dffd7c401cbe..bdd230f1893a 100644
--- a/python/tvm/relax/backend/contrib/cutlass.py
+++ b/python/tvm/relax/backend/contrib/cutlass.py
@@ -412,6 +412,10 @@ def visit_function_(self, f):
             out_size_1d = _shape_1d(f.ret_struct_info.shape)
             # This needs to be in sync with the actual value that the kernel expects.
             workspace_size_bytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
+            if not isinstance(workspace_size_bytes, (int, tvm.tir.expr.IntImm)):
+                # Temporary workaround for dynamic shape workloads. Will be removed when
+                # workspace computation for dynamic shapes is implemented.
+                workspace_size_bytes = 8
             return f.with_attr("WorkspaceSize", workspace_size_bytes)

         return f
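The two Python changes above are the compile-time half of the scheme: `gen_tensor_op.py` now emits `output_size` as an expression string so symbolic dims survive into the generated kernel, and the Relax partitioning pass keeps an exact `WorkspaceSize` only when the byte count is a compile-time constant. A rough sketch of that guard, assuming only `tvm.tir`; the function name is hypothetical:

import tvm

def workspace_size_or_placeholder(out_size_1d, out_dtype):
    """Sketch of the guard added in cutlass.py above (not the TVM function).
    Static shapes yield an int/IntImm and the exact byte count; a symbolic
    product is a PrimExpr, so a placeholder is attached and the kernel's
    cudaMalloc fallback sizes the real buffer at run time."""
    nbytes = out_size_1d * {"float16": 2, "float32": 4}[out_dtype]
    if isinstance(nbytes, (int, tvm.tir.IntImm)):
        return nbytes
    return 8  # placeholder; real buffer comes from cudaMalloc in the kernel

assert workspace_size_or_placeholder(1024, "float16") == 2048
s = tvm.tir.Var("s", "int64")
assert workspace_size_or_placeholder(s * 64, "float32") == 8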
diff --git a/tests/python/relax/test_codegen_cutlass.py b/tests/python/relax/test_codegen_cutlass.py
index 8a1675ad3593..7bbdd630a565 100644
--- a/tests/python/relax/test_codegen_cutlass.py
+++ b/tests/python/relax/test_codegen_cutlass.py
@@ -25,9 +25,9 @@
 from tvm.contrib.pickle_memoize import memoize
 from tvm.relax.backend.contrib.cutlass import partition_for_cutlass
 from tvm.relax.testing import get_relax_matmul_module
-from tvm.script import tir as T
 from tvm.script import ir as I
 from tvm.script import relax as R
+from tvm.script import tir as T
 from tvm.script.ir_builder import IRBuilder
 from tvm.script.ir_builder import relax as relax_builder
@@ -169,9 +169,16 @@ def get_relax_conv2d_module(
     return tvm.IRModule({"main": func})


-def _to_concrete_shape(symbolic_shape, var_table):
+def _to_concrete_shape(symbolic_shape, var_table=None):
+    if var_table is None:
+        var_table = {}
+
     result = []
     for dim in symbolic_shape:
+        if isinstance(dim, tuple):
+            result.append(_to_concrete_shape(dim, var_table))
+            continue
+
         if not isinstance(dim, tvm.tir.expr.Var):
             result.append(dim)
             continue
@@ -543,6 +550,7 @@ def attention_dtype(request):
 @pytest.fixture(
     params=[
         # B, S, N, H
+        (32, (_vars["a"], 8), 16, (8, 8)),
         (32, (8, 8), 16, (8, 8)),
         (4, (16, 8), 32, (8, 8)),  # s != s_kv
         (4, (16, 8), 32, (8, 16)),  # h != h_v
@@ -554,9 +562,9 @@ def attention_size(request):
     return request.param


-def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
-    dtype = str(q.dtype)
-
+def get_relax_attention_module(
+    q_shape, k_shape, v_shape, *, dtype, bias_shape=None, qk_scale=None, causal_mask=None
+):
     from tvm.script.ir_builder import IRBuilder
     from tvm.script.ir_builder import relax as relax_builder
     from tvm.script.ir_builder import tir as T
@@ -567,13 +575,15 @@ def get_relax_attention_module(q, k, v, bias=None, qk_scale=None, causal=None):
     with IRBuilder() as builder:
         with relax_builder.function():
             R.func_name("main")
-            q = R.arg("q", R.Tensor(q.shape, dtype))
-            k = R.arg("k", R.Tensor(k.shape, dtype))
-            v = R.arg("v", R.Tensor(v.shape, dtype))
-            if bias is not None:
-                bias = R.arg("bias", R.Tensor(bias.shape, dtype))
+            q = R.arg("q", R.Tensor(q_shape, dtype))
+            k = R.arg("k", R.Tensor(k_shape, dtype))
+            v = R.arg("v", R.Tensor(v_shape, dtype))
+            bias = None
+            if bias_shape is not None and bias_shape != "none":
+                bias = R.arg("bias", R.Tensor(bias_shape, dtype))
+
             with R.dataflow() as frame:
-                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal))
+                result = R.emit(R.nn.attention(q, k, v, bias, qk_scale, causal_mask))
                 R.output(result)

             R.func_ret_value(frame.output_vars[0])
@@ -620,11 +630,16 @@ def get_numpy_attention_ref(b, s, s_kv, n, h, h_v, bias_shape, qk_scale, causal,

 def test_attention_offload(attention_size, attention_dtype):
     b, (s, s_kv), n, (h, h_v) = attention_size
+    concrete_s, concrete_s_kv = _to_concrete_shape((s, s_kv))
     q, k, v, _, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, "none", "none", "none", attention_dtype
+        b, concrete_s, concrete_s_kv, n, h, h_v, "none", "none", "none", attention_dtype
     )

-    mod = get_relax_attention_module(q, k, v)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(q_shape, k_shape, v_shape, dtype=attention_dtype)
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)

     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -649,11 +664,19 @@ def attention_bias_size(request):

 def test_attention_bias_offload(attention_bias_size):
     b, (s, s_kv), n, (h, h_v), bias_shape = attention_bias_size
+    concrete_s, concrete_s_kv, concrete_bias_shape = _to_concrete_shape((s, s_kv, bias_shape))
+
     q, k, v, bias, ref = get_numpy_attention_ref(
-        b, s, s_kv, n, h, h_v, bias_shape, "none", "none", "float32"
+        b, concrete_s, concrete_s_kv, n, h, h_v, concrete_bias_shape, "none", "none", "float32"
     )

-    mod = get_relax_attention_module(q, k, v, bias)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape, k_shape, v_shape, bias_shape=bias_shape, dtype="float32"
+    )
     out = get_result_with_relax_cutlass_offload(mod, q, k, v, bias, num_final_bindings=3)
     tvm.testing.assert_allclose(out, ref, rtol=1e-2, atol=1e-2)
@@ -681,7 +704,13 @@ def test_attention_scale_offload(attention_scale_size, attention_scale):
         b, s, s_kv, n, h, h_v, bias_shape, attention_scale, "none", "float32"
     )

-    mod = get_relax_attention_module(q, k, v, bias, attention_scale)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape, k_shape, v_shape, dtype="float32", bias_shape=bias_shape, qk_scale=attention_scale
+    )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
     else:
@@ -712,7 +741,18 @@ def test_attention_causal_offload(attention_causal_size, attention_causal):
         b, s, s_kv, n, h, h_v, bias_shape, "none", attention_causal, "float32"
     )

-    mod = get_relax_attention_module(q, k, v, bias, None, attention_causal)
+    q_shape = (b, s, n, h)
+    k_shape = (b, s_kv, n, h)
+    v_shape = (b, s_kv, n, h_v)
+
+    mod = get_relax_attention_module(
+        q_shape,
+        k_shape,
+        v_shape,
+        dtype="float32",
+        bias_shape=bias_shape,
+        causal_mask=attention_causal,
+    )
     if bias is None:
         out = get_result_with_relax_cutlass_offload(mod, q, k, v, num_final_bindings=3)
     else:
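Taken together, the test changes separate module construction (symbolic shapes) from input generation (concrete shapes). A sketch of how the new dynamic fixture entry flows through the helpers above, with a plain `tvm.tir.Var` standing in for the test file's `_vars["a"]`:

import tvm

a = tvm.tir.Var("a", "int64")  # stand-in for _vars["a"] in the test file

# The new fixture case: batch 32, symbolic query length, 8 keys, 16 heads.
b, (s, s_kv), n, (h, h_v) = 32, (a, 8), 16, (8, 8)

q_shape = (b, s, n, h)       # (32, a, 16, 8), stays symbolic in the module
k_shape = (b, s_kv, n, h)    # (32, 8, 16, 8)
v_shape = (b, s_kv, n, h_v)  # (32, 8, 16, 8)

# R.Tensor accepts tir.Var dims, so get_relax_attention_module builds a
# shape-generic module compiled once for any value of `a`, while
# _to_concrete_shape draws concrete dims to build the numpy reference inputs.
print(q_shape, k_shape, v_shape)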