Add unit tests for flash attention jax. Updated the README.
nshepperd committed Feb 7, 2024
1 parent 88b5057 commit a04714d
Showing 19 changed files with 182 additions and 6,205 deletions.
10 changes: 10 additions & 0 deletions README.md
@@ -18,6 +18,16 @@ To install: TODO

Interface: `src/flash_attn_jax/flash.py`

```py
from flash_attn_jax import flash_mha

flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1))
```

Accepts `q`, `k`, `v` with shape `[n, l, h, d]` and returns output of the same shape. `softmax_scale` is the
multiplier applied to the attention scores before the softmax, defaulting to `1/sqrt(d)`. Set `window_size`
to a pair of positive values for sliding window attention (the default `(-1, -1)` leaves attention unrestricted).
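
A minimal usage sketch based on the interface above (the concrete shapes, window size, and the choice of `float16` are illustrative assumptions, not requirements stated in this excerpt):

```py
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

# q, k, v: [batch n, sequence length l, heads h, head dim d]
n, l, h, d = 2, 1024, 8, 64
q, k, v = (jax.random.normal(key, (n, l, h, d), dtype=jnp.float16)
           for key in jax.random.split(jax.random.PRNGKey(0), 3))

out = flash_mha(q, k, v)                                 # full attention, scale defaults to 1/sqrt(d)
out_causal = flash_mha(q, k, v, is_causal=True)          # causal masking
out_local = flash_mha(q, k, v, window_size=(128, 128))   # sliding window attention
```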

FlashAttention-2 currently supports:
1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing
   GPUs (T4, RTX 2080) is coming soon; please use FlashAttention 1.x for Turing
27 changes: 14 additions & 13 deletions csrc/flash_attn/mha_bwd.cpp
@@ -180,7 +180,7 @@ mha_bwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opaque_l
const int seqlen_q = args.l;
const int num_heads = args.h;
const int head_size_og = args.d; //dout.size(3);
const int head_size = args.d; //sizes[3];
const int head_size = args.d + (8 - head_size_og%8) % 8; //sizes[3];
const int seqlen_k = args.l_k;
const int num_heads_k = args.h_k; //k.size(2);
CHECK(batch_size > 0, "batch size must be positive");
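
The replacement line above rounds `head_size_og` up to the next multiple of 8, matching the padding now done on the JAX side. A quick sketch of that arithmetic, written in Python for brevity:

```py
def round_up_to_multiple_of_8(d):
    # same rounding as `args.d + (8 - head_size_og % 8) % 8`
    return d + (8 - d % 8) % 8

assert [round_up_to_multiple_of_8(d) for d in (8, 12, 17, 24)] == [8, 16, 24, 24]
```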
@@ -237,12 +237,13 @@ mha_bwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opaque_l
// }

// at::Tensor dout_padded;
if (head_size_og % 8 != 0) {
CHECK(false, "can't pad");
// dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
// dout_padded = dout;
}

// if (head_size_og % 8 != 0) {
// CHECK(false, "can't pad");
// // dout_padded = torch::nn::functional::pad(dout, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// } else {
// // dout_padded = dout;
// }

// bool loop = seqlen_k > blocksize_c;
// TODO: change later, for now set to true for simplicity
@@ -365,12 +366,12 @@ mha_bwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opaque_l
// at::sum_out(dk, at::reshape(dk_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
// at::sum_out(dv, at::reshape(dv_expanded, {batch_size, seqlen_k, num_heads_k, num_heads / num_heads_k, head_size}), {3});
}
if (head_size_og % 8 != 0) {
CHECK(false, "can't slice");
// dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
// dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
// dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
}
// if (head_size_og % 8 != 0) {
// CHECK(false, "can't slice");
// // dq = dq.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
// // dk = dk.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
// // dv = dv.index({"...", torch::indexing::Slice(torch::indexing::None, head_size_og)});
// }

// return { dq, dk, dv, softmax_d };
}
22 changes: 11 additions & 11 deletions csrc/flash_attn/mha_fwd.cpp
@@ -114,17 +114,17 @@ void mha_fwd(cudaStream_t stream, void **buffers, const char* opaque, size_t opa
// CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);

// at::Tensor q_padded, k_padded, v_padded;
void *q_padded, *k_padded, *v_padded;
if (head_size_og % 8 != 0) {
CHECK(false, "can't pad");
// q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
} else {
q_padded = q;
k_padded = k;
v_padded = v;
}
void *q_padded=q, *k_padded=k, *v_padded=v;
// if (head_size_og % 8 != 0) {
// // CHECK(false, "can't pad");
// // q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// // k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// // v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
// } else {
// q_padded = q;
// k_padded = k;
// v_padded = v;
// }

// at::Tensor out;
// if (out_.has_value()) {
78 changes: 53 additions & 25 deletions src/flash_attn_jax/flash.py
@@ -26,11 +26,39 @@
_flash_mha_bwd_p.def_impl(partial(xla.apply_primitive, _flash_mha_bwd_p))


def flash_mha_fwd(q, k, v, softmax_scale=None):
return _flash_mha_fwd_p.bind(q, k, v, softmax_scale=softmax_scale)

def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale=None):
return _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale)
def flash_mha_fwd(q, k, v, softmax_scale, is_causal, window_size):
d = q.shape[-1]
assert len(q.shape) == 4
assert d == k.shape[-1]
assert d == v.shape[-1]
if d % 8 != 0:
# We need padding.
padding = [(0,0),(0,0),(0,0),(0, 8 - d%8)]
q = jnp.pad(q, padding)
k = jnp.pad(k, padding)
v = jnp.pad(v, padding)
out, lse = _flash_mha_fwd_p.bind(q, k, v, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
if d % 8 != 0:
out = out[..., :d]
return out, lse

def flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale, is_causal, window_size):
d = q.shape[-1]
assert len(q.shape) == 4
assert d == k.shape[-1]
assert d == v.shape[-1]
if d % 8 != 0:
# We need padding.
padding = [(0,0),(0,0),(0,0),(0, 8 - d%8)]
q = jnp.pad(q, padding)
k = jnp.pad(k, padding)
v = jnp.pad(v, padding)
out = jnp.pad(out, padding)
dout = jnp.pad(dout, padding)
dq, dk, dv = _flash_mha_bwd_p.bind(dout, q, k, v, out, lse, softmax_scale=softmax_scale, d_og=d, is_causal=is_causal, window_size=window_size)
if d % 8 != 0:
return dq[...,:d], dk[...,:d], dv[...,:d]
return dq, dk, dv
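
The commit's new unit tests are not shown in this excerpt, but a natural check of the padding path above is to compare `flash_mha` against a plain `jnp` reference for a head dimension that is not a multiple of 8. A hedged sketch of such a test; the function names, shapes, and tolerance are assumptions, not the repo's actual test code:

```py
import jax
import jax.numpy as jnp
from flash_attn_jax import flash_mha

def ref_mha(q, k, v):
    # plain O(l^2) attention over [n, l, h, d], matching flash_mha's layout and default scale
    scale = 1.0 / jnp.sqrt(q.shape[-1])
    s = jnp.einsum('nqhd,nkhd->nhqk', q, k) * scale
    p = jax.nn.softmax(s, axis=-1)
    return jnp.einsum('nhqk,nkhd->nqhd', p, v)

def test_uneven_head_dim():
    n, l, h, d = 1, 128, 4, 20          # d % 8 != 0 exercises the padding path
    q, k, v = (jax.random.normal(key, (n, l, h, d), dtype=jnp.float16)
               for key in jax.random.split(jax.random.PRNGKey(0), 3))
    out = flash_mha(q, k, v)
    ref = ref_mha(q.astype(jnp.float32), k.astype(jnp.float32), v.astype(jnp.float32))
    assert jnp.allclose(out.astype(jnp.float32), ref, atol=2e-2)
```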

# ==== CUDA lowerings ====

@@ -43,7 +71,7 @@ def row_major(shape):
return range(len(shape)-1, -1, -1)
return [row_major(shape) for shape in shapes]

def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None):
def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None, d_og=None, is_causal=False, window_size=None):
# print(type(q), dir(q), q.type)
q_type = ir.RankedTensorType(q.type)
q_shape = q_type.shape
Expand All @@ -68,11 +96,11 @@ def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None):
opaque = flash_api.make_flash_mha_fwd_args(
0.0, # p_dropout
softmax_scale,
False, # is_causal
-1, # window_size_left
-1, # window_size_right
is_causal, # is_causal
window_size[0], # window_size_left
window_size[1], # window_size_right
False, # return_softmax
n, l, h, d,
n, l, h, d_og or d,
lk, hk,
flash_api.BF16 if type(out_element_type) == ir.BF16Type else flash_api.FP16,
0)
@@ -92,7 +120,7 @@ def _flash_mha_fwd_cuda_lowering(ctx, q, k, v, softmax_scale=None):
platform="gpu",
)

def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None):
def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
dout_type = ir.RankedTensorType(dout.type).element_type
q_type = ir.RankedTensorType(q.type).element_type
k_type = ir.RankedTensorType(k.type).element_type
@@ -129,11 +157,11 @@ def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=Non
opaque = flash_api.make_flash_mha_bwd_args(
0.0, # p_dropout
softmax_scale,
False, # is_causal
-1, # window_size_left
-1, # window_size_right
is_causal, # is_causal
window_size[0], # window_size_left
window_size[1], # window_size_right
False, # deterministic
n, lq, hq, d,
n, lq, hq, d_og or d,
lk, hk,
flash_api.BF16 if type(q_type) == ir.BF16Type else flash_api.FP16,
0)
@@ -155,7 +183,7 @@ def _flash_mha_bwd_cuda_lowering(ctx, dout, q, k, v, out, lse, softmax_scale=Non

# ==== Abstract evaluation rules ====

def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None):
def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
q_dtype = dtypes.canonicalize_dtype(q.dtype)
k_dtype = dtypes.canonicalize_dtype(k.dtype)
v_dtype = dtypes.canonicalize_dtype(v.dtype)
@@ -169,7 +197,7 @@ def _flash_mha_fwd_abstract(q, k, v, softmax_scale=None):
_flash_mha_fwd_p.def_abstract_eval(_flash_mha_fwd_abstract)


def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None):
def _flash_mha_bwd_abstract(dout, q, k, v, out, lse, softmax_scale=None, d_og=None, is_causal=None, window_size=None):
dout_dtype = dtypes.canonicalize_dtype(dout.dtype)
q_dtype = dtypes.canonicalize_dtype(q.dtype)
k_dtype = dtypes.canonicalize_dtype(k.dtype)
@@ -227,19 +255,19 @@ def custom_vjp(cls, nondiff_argnums=()):
# gets placed at the front of the argument list in bwd.
@partial(custom_vjp, nondiff_argnums=(3,))
class _flash_mha_vjp:
def base(q,k,v,softmax_scale):
return flash_mha_fwd(q,k,v, softmax_scale=softmax_scale)[0]
def fwd(q,k,v,softmax_scale):
out, lse = flash_mha_fwd(q,k,v, softmax_scale=softmax_scale)
def base(q,k,v,config):
return flash_mha_fwd(q,k,v, **config)[0]
def fwd(q,k,v,config):
out, lse = flash_mha_fwd(q,k,v, **config)
return out, (q,k,v,out,lse)
def bwd(softmax_scale, pack, dout):
def bwd(config, pack, dout):
(q,k,v,out,lse) = pack
dq, dk, dv = flash_mha_bwd(dout, q, k, v, out, lse, softmax_scale=softmax_scale)
dq, dk, dv = flash_mha_bwd(dout, q, k, v, out, lse, **config)
return (dq,dk,dv)
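
For context on the comment above about the static argument being "placed at the front of the argument list in bwd": this mirrors how the stock `jax.custom_vjp` treats `nondiff_argnums`, which is presumably why the three static options are bundled into a single `config` dict here. A small self-contained sketch using `jax.custom_vjp` directly (independent of this repo's `custom_vjp` helper):

```py
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(0,))
def scale_by(c, x):
    return c * x

def scale_by_fwd(c, x):
    return c * x, None                 # no residuals needed for this toy example

def scale_by_bwd(c, residuals, g):
    # the non-differentiable argument `c` arrives first, ahead of residuals and cotangents
    return (c * g,)                    # one cotangent per differentiable input

scale_by.defvjp(scale_by_fwd, scale_by_bwd)

print(jax.grad(scale_by, argnums=1)(3.0, 2.0))  # 3.0
```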

# ==== Frontend ====

def flash_mha(q,k,v,softmax_scale=None):
def flash_mha(q,k,v,softmax_scale=None, is_causal=False, window_size=(-1,-1)):
"""Flash attention.
softmax_scale defaults to 1/sqrt(d) and must be a python float if
@@ -249,5 +277,5 @@ def flash_mha(q,k,v,softmax_scale=None):
if softmax_scale is None:
softmax_scale = 1/math.sqrt(q.shape[-1])
assert type(softmax_scale) is float
o = _flash_mha_vjp(q,k,v,softmax_scale=softmax_scale)
o = _flash_mha_vjp(q,k,v,dict(softmax_scale=softmax_scale, is_causal=is_causal, window_size=window_size))
return o
