[Submodule] Turning flash-attention integration into 3rd party submod (pytorch#146372)

Summary:
Pull Request resolved: pytorch#146372

Pull Request resolved: pytorch#144120

# Summary

### Sticky points

CUDA-graph RNG handling has changed and now deviates from the original implementation. We are left with a dangling 'offset' value and confusing naming due to BC.
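
As context for that sticky point, here is a minimal, hypothetical sketch (not code from this PR) of the shape the renamed schema implies: the seed and offset that used to come back as separate `philox_seed` / `philox_offset` tensors are presumably carried in a single `rng_state` tensor, while a placeholder second output (`unused`) preserves the old output arity — that placeholder is the dangling 'offset' slot.

```cpp
// Hypothetical helper, not the code in this PR: illustrates the
// "pack seed+offset into one rng_state tensor, keep a placeholder second
// output for BC" shape implied by the philox_seed/philox_offset ->
// rng_state/unused rename in native_functions.yaml below.
#include <ATen/ATen.h>
#include <cstdint>
#include <tuple>

std::tuple<at::Tensor, at::Tensor> pack_rng_state(uint64_t seed, uint64_t offset) {
  // Both RNG values live in one 2-element CPU tensor.
  at::Tensor rng_state = at::empty({2}, at::kLong);
  auto acc = rng_state.accessor<int64_t, 1>();
  acc[0] = static_cast<int64_t>(seed);
  acc[1] = static_cast<int64_t>(offset);
  // Kept only so callers that used to unpack (seed, offset) still see two tensors.
  at::Tensor unused = at::empty({0}, at::kLong);
  return std::make_tuple(std::move(rng_state), std::move(unused));
}
```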

## Dependencies
- Flash PR: Dao-AILab/flash-attention#1419

### Other Points
- The BC linter complains about losing generate.py and its functions, which is not a real BC surface
cc albanD

imported-using-ghimport

Test Plan:
Imported from OSS

Building in dev
`buck build @//mode/dev-nosan -c fbcode.nvcc_arch=h100a  //caffe2:ATen-cu --show-full-output    `

Running `nm` on the built .so, I do see that the flash symbols are correctly named:
```
0000000001c3dfb0 t pytorch_flash::run_mha_bwd(pytorch_flash::Flash_bwd_params&, CUstream_st*)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c36080 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c360e0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#2}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
0000000001c35fc0 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#6}::operator()() const
0000000001c36020 t pytorch_flash::run_mha_fwd(pytorch_flash::Flash_fwd_params&, CUstream_st*, bool)::$_0::operator()() const::{lambda()#1}::operator()() const::{lambda()#1}::operator()() const::{lambda()#7}::operator()() const
```
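
The `pytorch_flash::` prefix on those symbols comes from the namespace wrapping the submodule integration relies on: the vendored flash-attention sources put their declarations inside a configurable `FLASH_NAMESPACE` macro (see the `namespace_config.h` include added in attention.cu below). The following is a simplified, illustrative sketch — `which_namespace()` is a hypothetical stand-in for the real entry points, and whether the default is set in the header or via a compile definition is not shown in this diff:

```cpp
// Simplified sketch of the namespace_config.h pattern (illustrative only).
#include <cstdio>

#ifndef FLASH_NAMESPACE
#define FLASH_NAMESPACE pytorch_flash  // assumed default for this sketch
#endif

namespace FLASH_NAMESPACE {
// The real headers declare mha_fwd / mha_varlen_fwd / mha_bwd here, so every
// flash symbol lands under pytorch_flash:: and cannot collide with a user's
// own copy of flash-attention.
inline const char* which_namespace() { return "pytorch_flash"; }
}  // namespace FLASH_NAMESPACE

int main() {
  // FLASH_NAMESPACE::which_namespace() resolves to pytorch_flash::which_namespace().
  std::printf("%s\n", FLASH_NAMESPACE::which_namespace());
  return 0;
}
```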

Reviewed By: vkuzo

Differential Revision: D68502879

Pulled By: drisspg
drisspg authored and facebook-github-bot committed Feb 13, 2025
1 parent b0553ce commit 09fb4b9
Showing 75 changed files with 165 additions and 5,744 deletions.
5 changes: 0 additions & 5 deletions CMakeLists.txt
@@ -872,11 +872,6 @@ cmake_dependent_option(
"USE_CUDA OR USE_ROCM;NOT MSVC"
OFF)

-# We are currenlty not using alibi attention for Flash So we disable this
-# feature by default We dont currently document this feature because we don't
-# Suspect users building from source will need this
-add_definitions(-DFLASHATTENTION_DISABLE_ALIBI)
-
# CAVEAT: Again, Flash Attention2 will error while building for sm52 while Mem
# Eff Attention won't
cmake_dependent_option(
9 changes: 6 additions & 3 deletions aten/src/ATen/CMakeLists.txt
@@ -164,9 +164,12 @@ file(GLOB native_quantized_cudnn_hip_cpp "native/quantized/cudnn/hip/*.cpp")
file(GLOB native_utils_cpp "native/utils/*.cpp")

# flash_attention sources
-file(GLOB flash_attention_cuda_cu "native/transformers/cuda/flash_attn/*.cu")
-file(GLOB flash_attention_cuda_kernels_cu "native/transformers/cuda/flash_attn/kernels/*.cu")
-file(GLOB flash_attention_cuda_cpp "native/transformers/cuda/flash_attn/*.cpp")
+file(GLOB flash_attention_cuda_kernels_cu ${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cu)
+# Flash attention C++ sources
+file(GLOB flash_attention_cuda_cpp
+"${PROJECT_SOURCE_DIR}/third_party/flash-attention/csrc/flash_attn/src/*.cpp"
+"native/transformers/cuda/flash_attn/flash_api.cpp"
+)

# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
6 changes: 3 additions & 3 deletions aten/src/ATen/native/native_functions.yaml
@@ -14862,7 +14862,7 @@
MPS: _scaled_dot_product_attention_math_mps
tags: nondeterministic_seeded

-- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
dispatch:
CUDA: _scaled_dot_product_flash_attention_cuda
NestedTensorCUDA: _scaled_dot_product_flash_attention_nestedtensor_cuda
@@ -14919,13 +14919,13 @@
CUDA: _scaled_dot_product_cudnn_attention_backward_cuda
tags: nondeterministic_seeded

-- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)
+- func: _flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)
variants: function
dispatch:
CUDA: _flash_attention_forward
tags: nondeterministic_seeded

-- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
+- func: _flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)
device_check: NoCheck
variants: function
dispatch:
10 changes: 8 additions & 2 deletions aten/src/ATen/native/transformers/cuda/attention.cu
@@ -70,6 +70,9 @@
#ifdef USE_FLASH_ATTENTION
// FlashAttention Specific Imports
#include <ATen/native/transformers/cuda/flash_attn/flash_api.h>
+#if !defined(__HIP_PLATFORM_AMD__)
+#include <namespace_config.h>
+#endif
#endif
#ifdef USE_MEM_EFF_ATTENTION
#ifndef USE_ROCM
@@ -916,6 +919,7 @@ _flash_attention_forward(
std::optional<Tensor> seqused_k = _seqused_k;
std::optional<at::Tensor> block_table = std::nullopt; // we are not using the block table yet
std::optional<Tensor> alibi_slopes = _alibi_slopes;
+const float softcap = 0.0;

const int non_null_window_left = window_size_left.has_value() ? window_size_left.value() : -1;
const int non_null_window_right = window_size_right.has_value() ? window_size_right.value() : -1;
@@ -939,7 +943,7 @@
philox_seed,
philox_offset,
debug_attn_mask) =
-pytorch_flash::mha_varlen_fwd(
+FLASH_NAMESPACE::mha_varlen_fwd(
query,
key,
value,
@@ -957,6 +961,7 @@
is_causal,
non_null_window_left,
non_null_window_right,
+softcap,
return_debug_mask,
std::nullopt /*gen_*/);
} else {
@@ -969,7 +974,7 @@
philox_seed,
philox_offset,
debug_attn_mask) =
-pytorch_flash::mha_fwd(
+FLASH_NAMESPACE::mha_fwd(
query,
key,
value,
@@ -980,6 +985,7 @@
is_causal,
non_null_window_left,
non_null_window_right,
+softcap,
return_debug_mask, /*return_softmax (this is used for testing)*/
std::nullopt);
}
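A note on the new `softcap` argument threaded through these call sites: my reading of the upstream flash-attention semantics (an assumption, not code from this PR) is that a positive softcap squashes attention scores with tanh, while `0.0` disables the capping — which is why a constant `0.0` is passed here. A minimal sketch of that behavior:

```cpp
#include <cmath>

// Sketch of the assumed softcap semantics: softcap > 0 maps each attention
// score into (-softcap, softcap) via tanh; softcap == 0.0, as passed by the
// PyTorch call sites above, leaves the score untouched.
inline float apply_softcap(float score, float softcap) {
  return softcap > 0.0f ? softcap * std::tanh(score / softcap) : score;
}
```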
7 changes: 5 additions & 2 deletions aten/src/ATen/native/transformers/cuda/attention_backward.cu
@@ -94,6 +94,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(

// Currently unused args:
std::optional<at::Tensor> alibi_slopes{std::nullopt};
+const float softcap = 0.0;

bool determinisitic{false};
auto& ctx = at::globalContext();
@@ -111,7 +112,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
// in order to determine whether we are using varlen or dense forward
if (cumulative_sequence_length_q.defined()) {
// Varlen forward
-auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_varlen_bwd(
+auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_varlen_bwd(
contiguous_grad_out,
query,
key,
@@ -132,13 +133,14 @@
is_causal,
non_null_window_left,
non_null_window_right,
+softcap,
determinisitic,
philox_seed,
philox_offset);
return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue));
} else {
// Dense forward
-auto [dQuery, dKey, dValue, dSoftmax] = pytorch_flash::mha_bwd(
+auto [dQuery, dKey, dValue, dSoftmax] = FLASH_NAMESPACE::mha_bwd(
contiguous_grad_out,
query,
key,
Expand All @@ -154,6 +156,7 @@ std::tuple<Tensor, Tensor, Tensor> _flash_attention_backward(
is_causal,
non_null_window_left,
non_null_window_right,
+softcap,
determinisitic,
philox_seed,
philox_offset);
74 changes: 0 additions & 74 deletions aten/src/ATen/native/transformers/cuda/flash_attn/alibi.h

This file was deleted.

46 changes: 0 additions & 46 deletions aten/src/ATen/native/transformers/cuda/flash_attn/block_info.h

This file was deleted.

96 changes: 0 additions & 96 deletions aten/src/ATen/native/transformers/cuda/flash_attn/dropout.h

This file was deleted.
