Commit v0.1.2

xiayuqing0622 committed Dec 13, 2024
1 parent 0ccf429 commit 0878e3f
Showing 3 changed files with 18 additions and 8 deletions.
20 changes: 14 additions & 6 deletions csrc/flash_attn/flash_api.cpp
@@ -447,7 +447,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     } else {
         v_padded = v;
     }
+    // // Otherwise the kernel will be launched from cuda:0 device
+    // // Cast to char to avoid compiler warning about narrowing
+    // at::cuda::CUDAGuard device_guard{(char)q.get_device()};

+    // auto opts = q.options();
     at::Tensor out;
     if (out_.has_value()) {
         out = out_.value();
@@ -459,11 +463,11 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
             out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2);
         }
         if (v_head_size_og % 8 != 0) {
-            out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options());
+            // out = torch::empty({batch_size, num_heads, seqlen_q, v_head_size_og}, q.options()).transpose(1, 2);
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
     } else {
-        out = torch::empty({batch_size, seqlen_q, num_heads, v_head_size_og}, q.options());
+        out = torch::empty({batch_size, num_heads, seqlen_q, v_head_size_og}, q.options()).transpose(1, 2);
         if (v_head_size_og % 8 != 0) {
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
@@ -623,7 +627,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
     CHECK_CONTIGUOUS(cu_seqlens_k);

     const auto sizes = q.sizes();
-    const int v_head_size_og = v.sizes()[2];
+    const int v_head_size_og = paged_KV ? v.sizes()[3] : v.sizes()[2];
     const int batch_size = cu_seqlens_q.numel() - 1;
     int num_heads = sizes[1];
     const int head_size_og = sizes[2];
@@ -710,14 +714,19 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
         TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
         CHECK_SHAPE(out, sizes[0], sizes[1], v_head_size_og);
         if (seqlenq_ngroups_swapped) {
-            out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, head_size_og});
+            out = out.reshape({batch_size, num_heads_maxkv, ngroups, v_head_size_og}).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, v_head_size_og});
         }
         if (v_head_size_og % 8 != 0) {
-            out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
+            // out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
     } else {
+        if (seqlenq_ngroups_swapped) {
+            out = torch::empty({batch_size, num_heads_maxkv, ngroups, v_head_size_og}, q.options()).transpose(1, 2).reshape({batch_size * ngroups, num_heads_maxkv, v_head_size_og});
+        }
+        else {
             out = torch::empty({total_q, num_heads, v_head_size_og}, q.options());
+        }
         if (v_head_size_og % 8 != 0) {
             out = torch::nn::functional::pad(out, torch::nn::functional::PadFuncOptions({0, 8 - v_head_size_og % 8}));
         }
@@ -1024,7 +1033,6 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     }

     Flash_bwd_params params;
-
     set_params_dgrad(params,
                      batch_size,
                      seqlen_q, seqlen_k,
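Taken together, the flash_api.cpp changes adjust how the output buffer for a separate value head size is handled: the varlen path now reads the value head size from v.sizes()[3] when a paged KV cache is passed (presumably because the paged layout carries an extra leading block dimension), the output is allocated in (batch, heads, seqlen, head_dim) order and then transposed, and a head size that is not a multiple of 8 is right-padded before the kernel runs. Below is a minimal PyTorch sketch of that padding and allocation pattern; the helper name and the example sizes are illustrative, not taken from the commit.

import torch
import torch.nn.functional as F

def pad_head_dim_to_multiple_of_8(x: torch.Tensor) -> torch.Tensor:
    # Mirrors the PadFuncOptions({0, 8 - v_head_size_og % 8}) calls in the diff:
    # zero-pad only the last (head) dimension up to the next multiple of 8.
    d = x.shape[-1]
    if d % 8 != 0:
        x = F.pad(x, (0, 8 - d % 8))
    return x

# Example sizes (illustrative only).
batch, seqlen_q, num_heads, v_head_dim = 2, 128, 8, 100

# Allocate in (batch, heads, seqlen, head_dim) order and transpose to
# (batch, seqlen, heads, head_dim), as the new torch::empty(...).transpose(1, 2) does.
out = torch.empty(batch, num_heads, seqlen_q, v_head_dim).transpose(1, 2)
out = pad_head_dim_to_multiple_of_8(out)
print(out.shape)  # torch.Size([2, 128, 8, 104]) -- 100 padded up to 104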
2 changes: 1 addition & 1 deletion flex_head_fa/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.1.1" # flash attn __version__ = "2.6.3"
+__version__ = "0.1.2" # flash attn __version__ = "2.6.3"

 from flex_head_fa.flash_attn_interface import (
     flash_attn_func,
4 changes: 3 additions & 1 deletion setup.py
@@ -455,6 +455,7 @@ def get_wheel_url():
     wheel_filename = f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"

     wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
+    print(wheel_url)
     return wheel_url, wheel_filename


@@ -484,7 +485,8 @@ def run(self):
             impl_tag, abi_tag, plat_tag = self.get_tag()
             archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"

-            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            # wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            wheel_path = os.path.join(self.dist_dir, wheel_filename)
             print("Raw wheel path", wheel_path)
             os.rename(wheel_filename, wheel_path)
         except (urllib.error.HTTPError, urllib.error.URLError):
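On the packaging side, the setup.py changes make the prebuilt-wheel path easier to follow: get_wheel_url() now prints the URL it constructs, and the downloaded wheel is moved to a dist path built from its own filename rather than from archive_basename. Here is a small sketch of how that filename and URL are put together; the base-URL template and all concrete version values are placeholders, not values read from the repository.

import os

# Placeholders standing in for the constants defined at the top of setup.py.
PACKAGE_NAME = "flex_head_fa"
BASE_WHEEL_URL = "https://example.com/releases/download/{tag_name}/{wheel_name}"

# Example build environment (illustrative values only).
flash_version = "0.1.2"
cuda_version = "123"
torch_version = "2.4"
cxx11_abi = "FALSE"
python_version = "cp310"
platform_name = "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{flash_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{flash_version}", wheel_name=wheel_filename)
print(wheel_url)  # the same kind of debug print the commit adds

# The bdist_wheel override then keeps the wheel's own name inside the dist dir:
dist_dir = "dist"  # stands in for self.dist_dir
wheel_path = os.path.join(dist_dir, wheel_filename)
# os.rename(wheel_filename, wheel_path)  # done in setup.py after the download succeeds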
