How to compare Mamba with FlashAttention-2 #27
Comments
We decided to leave those linear projections out because they are orthogonal to the main "sequence mixing mechanism" (attention vs. scan) that is of interest to benchmark. You're right that the comparisons become slightly harder to control (e.g., what model dimension is fair to use?), but we chose a setting that seemed reasonable to us. No matter what, the timings will only be off by a small constant factor with any other "reasonable" setting of dimensions, which is dwarfed by the linear vs. quadratic complexity.
We compared attention time (softmax(QK^T)V) vs. scan time, without the linear projections.
And what datatype did you use? When I try to run the scan in fp16, it always raises an error.
Q, K, V are bf16 for attention.
It works now, thank you!
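For readers trying to reproduce the setup described above, here is a minimal sketch of the projection-free, bf16 comparison: only the two sequence-mixing kernels are timed, with activations in bf16 and A, D, and delta_bias in fp32, following the comments in this thread. The shapes and the use of `triton.testing.do_bench` are assumptions, not the paper's official benchmark script.

```python
# Minimal sketch (assumptions, not the official benchmark): time only the core
# sequence-mixing kernels, with all linear projections left out on both sides.
import torch
import triton
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
from flash_attn import flash_attn_func

batch, length, d_inner, d_state = 1, 2048, 4096, 16  # assumed sizes
nheads, headdim = 32, 64                              # 32 * 64 = 2048 attention dim

# Selective-scan inputs: activations in bf16; A, D, delta_bias kept in fp32.
bf16 = dict(device="cuda", dtype=torch.bfloat16)
u = torch.randn(batch, d_inner, length, **bf16)
delta = torch.rand(batch, d_inner, length, **bf16)
B = torch.randn(batch, 1, d_state, length, **bf16)
C = torch.randn(batch, 1, d_state, length, **bf16)
z = torch.randn(batch, d_inner, length, **bf16)
A = -torch.rand(d_inner, d_state, device="cuda")      # negative real A, fp32
D = torch.ones(d_inner, device="cuda")
delta_bias = torch.randn(d_inner, device="cuda")

# Attention inputs: Q, K, V in bf16, no QKV or output projections.
q, k, v = (torch.randn(batch, length, nheads, headdim, **bf16) for _ in range(3))

scan_ms = triton.testing.do_bench(
    lambda: selective_scan_fn(u, delta, A, B, C, D, z, delta_bias, delta_softplus=True)
)
attn_ms = triton.testing.do_bench(lambda: flash_attn_func(q, k, v, causal=True))
print(f"selective scan fwd: {scan_ms:.3f} ms | flash-attn fwd: {attn_ms:.3f} ms")
```

Which kernel comes out ahead depends on d_state (N), sequence length, and whether the backward pass is included, which is exactly what the rest of the thread digs into.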
I wrote a simple script to compare these two components (scan, and FlashAttention-2 with causal masking) and tested it on an A100. As instructed, the input dim of the scan is 4096 and the input dim of flash attention is 2048 (32 heads * 64 head dim). However, the scan is much slower than FlashAttention-2 (fwd: scan 0.25 ms vs. flash2 0.14 ms; fwd+bwd: scan 1.25 ms vs. flash2 0.59 ms). Did I get any settings wrong?

```python
import torch
import time

test_bwd = False
batch, length, dim, d_state = 1, 2048, 2048, 16

from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn

u = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
A = torch.randn(dim * 2, d_state).to("cuda").requires_grad_(True)
B = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
C = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
D = torch.randn(dim * 2).to("cuda").requires_grad_(True)
z = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta_bias = torch.randn(dim * 2).to("cuda").requires_grad_(True)
doutssm = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16)

ssm = SelectiveScanFn.apply

# Warmup, then time 1000 iterations of the selective scan.
for i in range(10):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
print(time.time() - start)

from flash_attn import flash_attn_func

dim_head = 64
n_heads = dim // dim_head
q = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
k = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
v = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
dout = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16)

# Warmup, then time 1000 iterations of FlashAttention-2 (causal).
for i in range(10):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
print(time.time() - start)
```
Please format your code with triple backticks followed by "python". The appendix of the paper says that the dimension …
Sorry for the format issue. I've re-edited the code above. I also tested inputs with D=1024: for fwd it's scan 0.13 ms vs. flash 0.08 ms, and for fwd+bwd it's scan 0.71 ms vs. flash 0.35 ms.
Hi @tridao and @albertfgu, first of all, thank you for releasing the source code for both FlashAttention (v1 and v2) and Mamba, including the CUDA kernels! I too had this issue of not being able to reproduce the benchmarks, in particular against FlashAttention v2. I tried several settings (D=768, 1024, 2048), and for N/d_state=16 flash attention was significantly faster than the scan; only at N=4 do I start to see the curves reported in the paper. In particular, for N=16 the scan is about 2x slower. Following are the times in ms that I see. It would be immensely useful if you could spare some time to review the Mamba benchmark below or provide a few more details to reproduce the benchmark. Thanks @xiayuqing0622 for the starting code.
```python
def benchmark_mamba(batch, head, length, dim_head, d_state):
    from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn
    from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda
    from einops import rearrange, repeat

    d_model = dim_head * head
    expand = 2
    d_inner = d_model * expand
    device = "cuda"
    # S4D real initialization
    A = repeat(
        torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32
    x = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    z = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta_bias = torch.randn(d_inner).to("cuda").requires_grad_(True)
    A = -torch.exp(A_log.float())  # (d_inner, d_state)
    B = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    C = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    D = torch.ones(d_inner, device=device)  # Keep in fp32
    delta_softplus = True
    ms = triton.testing.do_bench(
        lambda: selective_scan_cuda.fwd(
            x, delta, A, B, C, D, z, delta_bias, delta_softplus
        ),
        warmup=100,
    )
    return ms
```

The full code is below, but please feel free to ignore the rest. Here is the code:

```python
import itertools
from math import sqrt

import pandas
import torch
from tqdm import tqdm
import triton

from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func


def get_inputs(B, H, L, E=64, ret_padding_mask=False, dtype=torch.float32):
    q = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    k = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    v = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    input_lengths = torch.randint(1, L, (B,), device=q.device).long()
    input_lengths[-1] = L
    padding_mask = torch.zeros((B, L), dtype=q.dtype, device=q.device)
    padding_mask[
        (
            torch.arange(padding_mask.shape[0], device=padding_mask.device),
            input_lengths - 1,
        )
    ] = 1
    padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
    if not ret_padding_mask:
        padding_mask = None
    return (q, k, v), padding_mask


def flash_attn_forward(queries, keys, values, padding_mask=None):
    qkv = torch.stack([queries, keys, values], dim=2)
    qkv = qkv.permute(0, 3, 2, 1, 4)
    B, T, _, H, D = qkv.shape
    scale = 1.0 / sqrt(D)
    if padding_mask is not None:
        # unpad_input expects True to correspond to valid indices and False to invalid
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, ~padding_mask)
        packed_res = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_q_lens,
            max_s,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = pad_input(packed_res, indices, B, T)
        res = res.transpose(1, 2)
    else:
        res = flash_attn_qkvpacked_func(
            qkv,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = res.transpose(1, 2)  # B x T x H x D -> B x H x T x D
    return res


def benchmark_flash(q, k, v, padding_mask):
    dim_E = q.shape[-1]
    H = q.shape[1]
    E = dim_E * H
    ms = triton.testing.do_bench(
        lambda: flash_attn_forward(q, k, v, padding_mask=padding_mask), warmup=100
    )
    return ms


if __name__ == "__main__":
    batch_sizes = [16]
    heads = [12, 16, 32]
    time_steps = [1000, 1600, 3200, 6400]
    get_padding_masks = [True, False]
    d_states = [2, 4, 8, 16]
    dtypes = [torch.bfloat16]
    E = 64
    results = []
    for B, H, L, pm, dtype in tqdm(
        itertools.product(batch_sizes, heads, time_steps, get_padding_masks, dtypes)
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )
        ms = benchmark_flash(q, k, v, padding_mask)
        results.append(
            {
                "name": "flash",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )
    for B, H, L, pm, d_state, dtype in tqdm(
        itertools.product(
            batch_sizes, heads, time_steps, get_padding_masks, d_states, dtypes
        )
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )
        ms = benchmark_mamba(B, H, L, E, d_state)
        results.append(
            {
                "name": f"mamba-{d_state}",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )
    df = pandas.DataFrame(results)
    piv = df.pivot(
        columns="name",
        values="ms",
        index=["dtype", "padding", "batch_size", "nheads", "seq_len"],
    )
    print(piv.sort_index().round(3))
```
Try …
@tridao |
Hi @apoorv2904, are you able to reproduce the results? If so, could you please share how you reproduced them?
In your paper, you mention that the Mamba scan is faster than FlashAttention-2.
Does that mean comparing mamba/mamba_ssm/ops/selective_scan_interface.py (line 14 in 0131c1e) against FlashAttention-2?
The inputs of these two modules are different, so is this comparison fair? Or does the preprocessing (computing q, k, v in flash attention; computing A, B, C, D, delta in the Mamba scan) need to be taken into account?
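For context on the "preprocessing" mentioned here: on both sides it consists of per-token linear projections, which the scan-vs-attention timings discussed above deliberately exclude. The sketch below is hypothetical; the layer names mirror a standard attention block and the Mamba block's in_proj / x_proj / dt_proj, the convolution and activations are omitted, and all dimensions are illustrative rather than the paper's settings.

```python
# Hypothetical sketch of the per-token projections that sit in front of each
# sequence-mixing kernel. Neither side's projections are part of what the
# scan-vs-attention benchmark times.
import torch
import torch.nn as nn

batch, length = 1, 2048
d_model, d_inner, d_state, dt_rank = 2048, 4096, 16, 128  # illustrative sizes

# Attention side: fused QKV projection producing the inputs to flash attention.
Wqkv = nn.Linear(d_model, 3 * d_model, bias=False)

# Mamba side: projections producing the inputs to the selective scan
# (conv1d and activations between in_proj and x_proj are omitted here).
in_proj = nn.Linear(d_model, 2 * d_inner, bias=False)           # -> x, z
x_proj = nn.Linear(d_inner, dt_rank + 2 * d_state, bias=False)  # -> dt (low rank), B, C
dt_proj = nn.Linear(dt_rank, d_inner, bias=True)                # -> delta

hidden = torch.randn(batch, length, d_model)
q, k, v = Wqkv(hidden).chunk(3, dim=-1)                 # cost ~ L * d_model^2
x, z = in_proj(hidden).chunk(2, dim=-1)                 # cost ~ L * d_model * d_inner
dt, Bmat, Cmat = x_proj(x).split([dt_rank, d_state, d_state], dim=-1)
delta = dt_proj(dt)

# Only softmax(QK^T)V vs. the selective scan itself is timed in the benchmark
# discussed above; every projection here is linear in sequence length.
```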