Commit
wip
masahi committed Sep 26, 2023
1 parent ae657b7 commit 5d01fd1
Showing 4 changed files with 28 additions and 19 deletions.
26 changes: 13 additions & 13 deletions python/tvm/contrib/cutlass/attention_operation.py
@@ -169,10 +169,10 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
-int q_row_stride = q_head_stride * ${num_heads};
-int k_row_stride = k_head_stride * ${num_heads};
-int v_row_stride = v_head_stride * ${num_heads};
-int o_row_stride = o_head_stride * ${num_heads};
+int q_row_stride = q_head_stride * ${num_q_heads};
+int k_row_stride = k_head_stride * ${num_kv_heads};
+int v_row_stride = v_head_stride * ${num_kv_heads};
+int o_row_stride = o_head_stride * ${num_q_heads};
int q_batch_stride = q_row_stride * ${num_queries};
int k_batch_stride = k_row_stride * ${num_keys};
int v_batch_stride = v_row_stride * ${num_keys};
@@ -190,8 +190,8 @@ def instantiate_flash_attention_template(attrs):
${num_batches},
${num_queries},
${num_keys},
-${num_heads},
-${num_heads},
+${num_q_heads},
+${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
@@ -215,9 +215,9 @@ def instantiate_flash_attention_template(attrs):
int k_head_stride = ${head_dim};
int v_head_stride = ${head_dim};
int o_head_stride = ${head_dim};
-int row_stride = q_head_stride * ${num_heads} +
-k_head_stride * ${num_heads} +
-v_head_stride * ${num_heads};
+int row_stride = q_head_stride * ${num_q_heads} +
+k_head_stride * ${num_kv_heads} +
+v_head_stride * ${num_kv_heads};
int q_row_stride = row_stride;
int k_row_stride = row_stride;
int v_row_stride = row_stride;
@@ -234,14 +234,14 @@ def instantiate_flash_attention_template(attrs):
flash_attn::flash_attention_forward(
static_cast<const cutlass::half_t*>(${qkv}->data),
-static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads},
-static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_heads} * 2,
+static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * ${num_q_heads},
+static_cast<const cutlass::half_t*>(${qkv}->data) + ${head_dim} * (${num_q_heads} + ${num_kv_heads}),
static_cast<cutlass::half_t*>(out0->data),
${num_batches},
${num_queries},
${num_keys},
-${num_heads},
-${num_heads},
+${num_q_heads},
+${num_kv_heads},
${head_dim},
q_batch_stride,
k_batch_stride,
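Not part of the patch: a minimal Python sketch of the stride and offset arithmetic the two templates above emit, to make the grouped-query change concrete. The function names and example numbers below are ours; only the formulas mirror the substituted template parameters (num_q_heads, num_kv_heads, head_dim, num_queries, num_keys).

def flash_gqa_strides(num_q_heads, num_kv_heads, num_queries, num_keys, head_dim):
    # Row strides: a row holds every head of that tensor, so Q and O rows
    # scale with the query head count while K and V rows scale with the
    # (possibly smaller) key/value head count.
    q_row_stride = head_dim * num_q_heads
    k_row_stride = head_dim * num_kv_heads
    v_row_stride = head_dim * num_kv_heads
    o_row_stride = head_dim * num_q_heads
    # Batch strides: rows per sequence times the row stride.
    return {
        "q_row_stride": q_row_stride,
        "k_row_stride": k_row_stride,
        "v_row_stride": v_row_stride,
        "o_row_stride": o_row_stride,
        "q_batch_stride": q_row_stride * num_queries,
        "k_batch_stride": k_row_stride * num_keys,
        "v_batch_stride": v_row_stride * num_keys,
    }


def packed_qkv_layout(num_q_heads, num_kv_heads, head_dim):
    # Stacked-QKV variant: Q, K and V share one row stride, and the K and V
    # pointers are offset past the query heads (and then the key heads),
    # matching the pointer arithmetic in the second template.
    row_stride = head_dim * (num_q_heads + 2 * num_kv_heads)
    q_offset = 0
    k_offset = head_dim * num_q_heads
    v_offset = head_dim * (num_q_heads + num_kv_heads)
    return row_stride, q_offset, k_offset, v_offset

For example, with num_q_heads = 32, num_kv_heads = 8 and head_dim = 128, the K and V row strides become 1024 elements against 4096 for Q, which is exactly the asymmetry the template now encodes.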
7 changes: 4 additions & 3 deletions python/tvm/contrib/cutlass/build.py
@@ -909,8 +909,8 @@ def handle_attention(self, f, op_type):

out_shape = signature["ret_shape"]
out_dtype = signature["ret_dtype"]
-num_batches, num_queries, num_heads, head_dim = q_shape
-_, num_keys, _, _ = k_shape
+num_batches, num_queries, num_q_heads, head_dim = q_shape
+_, num_keys, num_kv_heads, _ = k_shape
_, _, _, head_dim_value = v_shape
scale = op_attrs.scale

@@ -931,7 +931,8 @@ def handle_attention(self, f, op_type):
"num_batches": num_batches,
"num_queries": num_queries,
"num_keys": num_keys,
"num_heads": num_heads,
"num_q_heads": num_q_heads,
"num_kv_heads": num_kv_heads,
"head_dim": head_dim,
"head_dim_value": head_dim_value,
"scale": scale,
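For illustration only (not in the patch): how the q/k/v shapes unpack into the annotation fields that handle_attention now records, assuming the (batch, seq_len, heads, head_dim) layout the unpacking above implies. The shapes are made up, and only the keys touched by this change are shown.

# 32 query heads sharing 8 key/value heads, i.e. grouped-query attention.
q_shape = (2, 128, 32, 128)  # (num_batches, num_queries, num_q_heads, head_dim)
k_shape = (2, 256, 8, 128)   # (_, num_keys, num_kv_heads, _)
v_shape = (2, 256, 8, 128)   # last axis is head_dim_value

num_batches, num_queries, num_q_heads, head_dim = q_shape
_, num_keys, num_kv_heads, _ = k_shape
_, _, _, head_dim_value = v_shape

annotations = {
    "num_batches": num_batches,
    "num_queries": num_queries,
    "num_keys": num_keys,
    "num_q_heads": num_q_heads,
    "num_kv_heads": num_kv_heads,
    "head_dim": head_dim,
    "head_dim_value": head_dim_value,
}
# GQA typically groups query heads evenly over the key/value heads.
assert num_q_heads % num_kv_heads == 0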
9 changes: 8 additions & 1 deletion python/tvm/contrib/cutlass/gen_tensor_op.py
@@ -745,7 +745,6 @@ def get_batch_on_arg(arg_name, arg_shape):

attrs["data_type"] = DataTypeTag[data_type]
attrs["num_batches"] = b = annotations["num_batches"]
attrs["num_heads"] = n = annotations["num_heads"]
attrs["head_dim"] = h = annotations["head_dim"]
attrs["head_dim_value"] = h_v = annotations["head_dim_value"]
attrs["kMaxK"] = max(int(attrs["head_dim"]), int(attrs["head_dim_value"]))
@@ -766,13 +765,21 @@ def get_batch_on_arg(arg_name, arg_shape):
and int(annotations["arch"]) >= 80
)

+print(int(attrs["head_dim"]) <= 256, int(attrs["head_dim"]) % 8 == 0, int(attrs["head_dim"]) == int(attrs["head_dim_value"]),int(annotations["arch"]) >= 80, annotations["ret_dtype"] == "float16", "bias" not in attrs, int(annotations["arch"]) >= 80)


if use_flash:
headers.append("flash.h")
attrs["is_causal"] = int(annotations["custom_mask_type"]) > 0
attrs["num_q_heads"] = annotations["num_q_heads"]
attrs["num_kv_heads"] = annotations["num_kv_heads"]
code = instantiate_flash_attention_template(attrs)
else:
headers.append("kernel_forward.h")

+assert annotations["num_q_heads"] == annotations["num_kv_heads"]
+attrs["num_heads"] = n = annotations["num_q_heads"]

data_type_size = DataTypeSize[data_type]
if (data_type_size * h // 8) % 16 == 0 and (data_type_size * h_v // 8) % 16 == 0:
attrs["kIsAligned"] = True
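As an aside, not in the patch: a standalone sketch of the dispatch the change above implements. The flash path forwards the separate query and key/value head counts, while the existing kernel_forward.h fallback still assumes one key/value head per query head. The function name and its arguments are hypothetical stand-ins for the values computed earlier in gen_tensor_op.py.

def select_attention_kernel(annotations, use_flash):
    attrs = {}
    if use_flash:
        # Flash attention supports grouped-query attention directly, so the
        # two head counts are passed through unchanged.
        attrs["num_q_heads"] = annotations["num_q_heads"]
        attrs["num_kv_heads"] = annotations["num_kv_heads"]
        return "flash.h", attrs
    # The non-flash kernel has no GQA support, so it only accepts the case
    # where the query and key/value head counts match.
    assert annotations["num_q_heads"] == annotations["num_kv_heads"]
    attrs["num_heads"] = annotations["num_q_heads"]
    return "kernel_forward.h", attrs

One consequence worth noting: a grouped-query workload that fails the use_flash checks (for instance head_dim above 256 or a non-float16 output) hits the assertion instead of falling back.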
5 changes: 3 additions & 2 deletions tests/python/relax/test_codegen_cutlass.py
@@ -1968,7 +1968,7 @@ def rewrite_attention(f):

def callback(_, matchings):
return R.nn.attention(
-matchings[Q], matchings[K], matchings[V], causal_mask="BottomRight"
+matchings[Q], matchings[K], matchings[V],
)

return rewrite_call(pattern, callback, f)
@@ -2007,7 +2007,8 @@ def main(

Module["main"] = rewrite_attention(Module["main"])
mod = partition_for_cutlass(Module)
-print(mod)
+codegen_pass = relax.transform.RunCodegen({"cutlass": {"sm": 80, "find_first_valid": True}})
+print(codegen_pass(mod))

if __name__ == "__main__":
# tvm.testing.main()
