About the speed test #8
Comments
Hello, thank you for taking an interest in my work! I just re-ran the script I used to generate the graph showing the training speed over d_state (A100 80GB). Here is the code I use for the benchmark:
The main difference between your benchmark and mine is that I benchmark a whole Mamba block, while you benchmark only a part of it, the selective scan. There are some other computations taking place in the Mamba block, like the linear projections and the 1D convolution.
Thank you for your quick reply. In my experiments (VMamba), selective_scan_fwd and bwd take 80%+ of the training time, so your observations are very valuable to me.
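(One way to check that kind of per-kernel breakdown, sketched here with illustrative sizes and not taken from either benchmark script, is PyTorch's built-in profiler, which reports CUDA time per kernel:)

import torch
from torch.profiler import profile, ProfilerActivity
from mamba_ssm import Mamba

batch, length, dim, d_state = 3, 4096, 192, 16
model = Mamba(d_model=dim, d_state=d_state, d_conv=4, expand=2).to("cuda")
x = torch.randn(batch, length, dim, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    y = model(x)
    torch.norm(y).backward()

# Per-kernel CUDA time; look for the selective_scan_fwd / bwd entries
# to see what fraction of the whole block they account for.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))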
I think I have got the answer, at least partially:

# test_mamba_a(16) # 34.6ms in 4090
# test_mamba_a(64) # 34ms in 4090
# test_mamba_a(128) # 45.3ms in 4090
# test_mamba_(16) # 17.4ms in 4090
# test_mamba_(64) # 28ms in 4090
# test_mamba_(128) # 41.8ms in 4090

# x = torch.randn(batch, length, dim).to("cuda")
x = torch.randn(batch, length, dim, device=torch.device("cuda"))

Transferring data from host memory to the device is very time-consuming, so I think the time gap between the different d_state values may have been hidden by that operation in your experiments. Looking forward to your experiments on the A100!
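(A minimal sketch of that comparison, with illustrative sizes and explicit torch.cuda.synchronize() calls so that asynchronous kernel launches do not distort the measurement:)

import time
import torch

batch, length, dim = 3, 4096, 192

def time_creation(make_tensor, n=100):
    # Warm up once, then time n repetitions with the GPU synchronized.
    make_tensor()
    torch.cuda.synchronize()
    t0 = time.time()
    for _ in range(n):
        make_tensor()
    torch.cuda.synchronize()
    return (time.time() - t0) / n * 1000  # milliseconds per call

cpu_then_copy = lambda: torch.randn(batch, length, dim).to("cuda")
direct_on_gpu = lambda: torch.randn(batch, length, dim, device="cuda")

print(f"randn on CPU + .to('cuda'): {time_creation(cpu_then_copy):.2f} ms")
print(f"randn directly on GPU:      {time_creation(direct_on_gpu):.2f} ms")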
Hello, indeed, when creating the tensor directly on the GPU the timings are different.
That is much more logical considering what happens in the code, thank you for pointing it out! However, considering what happens in practice, it's usual to move data from host to device before feeding it to the model, no?
So in practice, I guess raising d_state ...
Yes, you are right. ...
start_time = time.time()
N = 500
data = torch.randn(batch, length, dim)
for _ in range(N):
    x = data.to('cuda', non_blocking=True)
    y_cuda = model(x)
    loss = torch.norm(y_cuda)
...
... results in:

# test_mamba2_(16) # 18.54ms in 4090
# test_mamba2_(64) # 30ms in 4090
# test_mamba2_(128) # 44.4ms in 4090

and for a bigger batch size (bs=48, resulting in a bigger gridDim with a limited number of Streaming Multiprocessors):

...
start_time = time.time()
# N = 100
data = torch.randn(batch, length, dim) # batch=48
for _ in range(N):
    x = data.to('cuda', non_blocking=True)
    y_cuda = model(x)
    loss = torch.norm(y_cuda)
...
... results in:

# test_mamba2_(16, 48, 100) # 348ms in 4090; 21G
# test_mamba2_(64, 48, 100) # 638ms in 4090; 21G
# test_mamba2_(128, 48, 100) # 1028ms in 4090; 22G

I think the results are more obvious now 😂
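(For reference, the number of Streaming Multiprocessors that caps how many of those thread blocks can run concurrently can be read directly from PyTorch; the snippet below only inspects the device and is not part of the benchmark above.)

import torch

# A larger batch launches a bigger grid; once the grid exceeds what the SMs
# can execute concurrently, the extra work from a larger d_state shows up in
# the runtime instead of being hidden by idle hardware.
props = torch.cuda.get_device_properties(0)
print(props.name, "SM count:", props.multi_processor_count)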
Yes, it's becoming clearer!
... which is different from what you observe! But increasing the batch size makes the times differ for the different d_state values. I tried increasing the batch size while lowering other parameters (like length and dim), and once again the ...
I will update the performance numbers in the repo. Thank you!
I profiled the script using NVIDIA Nsight Systems. From my observation, copying data from host to device only takes a small amount of time; it seems that torch.randn() on the CPU is what takes long. I compared

x = torch.ones(batch, length, dim).to("cuda")
# x = torch.randn(batch, length, dim).to("cuda")

and the variants created directly on the GPU. Results on an A100:

# x = torch.ones(batch, length, dim).to("cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=4.19
B=3, L=4096, d_model=192, d_state=64, time (ms)=6.51
B=3, L=4096, d_model=192, d_state=128, time (ms)=10.46

# x = torch.ones(batch, length, dim, device="cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=3.15
B=3, L=4096, d_model=192, d_state=64, time (ms)=5.77
B=3, L=4096, d_model=192, d_state=128, time (ms)=9.16

# x = torch.randn(batch, length, dim).to("cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=13.2
B=3, L=4096, d_model=192, d_state=64, time (ms)=13.2
B=3, L=4096, d_model=192, d_state=128, time (ms)=13.2

# x = torch.randn(batch, length, dim, device="cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=3.16
B=3, L=4096, d_model=192, d_state=64, time (ms)=5.74
B=3, L=4096, d_model=192, d_state=128, time (ms)=9.15
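(When profiling with Nsight Systems as described above, NVTX ranges make it easy to attribute each kernel to a phase: random generation, copy, forward, backward. A minimal sketch, reusing the same Mamba setup as the scripts below:)

import torch
from mamba_ssm import Mamba

batch, length, dim, d_state = 3, 4096, 192, 16
model = Mamba(d_model=dim, d_state=d_state, d_conv=4, expand=2).to("cuda")

for _ in range(10):
    torch.cuda.nvtx.range_push("randn_cpu")
    x_cpu = torch.randn(batch, length, dim)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("copy_h2d")
    x = x_cpu.to("cuda")
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("forward")
    y = model(x)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("backward")
    torch.norm(y).backward()
    torch.cuda.nvtx.range_pop()

Running this under nsys profile -t cuda,nvtx python your_script.py then shows each phase as a named range on the timeline.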
@smallscientist1 Thank you for your observation. It seems that generating the random data on the CPU is the slow part, and that may explain why the speed looks the same for different d_state values. To confirm the hypothesis above, we need to record GPU time and CPU time separately. By the way, have you ever compared the following two ways?

1. x = torch.randn(...); benchmark(x = x.to('cuda'));
2. benchmark(x = torch.randn(device=torch.device('cuda')));
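(A minimal sketch of recording the two separately, with illustrative sizes: torch.cuda.Event measures the time spent on the CUDA stream, while time.perf_counter measures wall-clock time that also includes the CPU-side torch.randn.)

import time
import torch
from mamba_ssm import Mamba

batch, length, dim, d_state = 3, 4096, 192, 16
model = Mamba(d_model=dim, d_state=d_state, d_conv=4, expand=2).to("cuda")
start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)

t0 = time.perf_counter()
x = torch.randn(batch, length, dim).to("cuda")  # CPU randn + host-to-device copy

start_evt.record()
y = model(x)                                    # GPU work only
torch.norm(y).backward()
end_evt.record()
torch.cuda.synchronize()

wall_ms = (time.perf_counter() - t0) * 1000
gpu_ms = start_evt.elapsed_time(end_evt)
print(f"wall clock (incl. randn + copy): {wall_ms:.2f} ms, GPU fwd+bwd: {gpu_ms:.2f} ms")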
@MzeroMiko Thanks for your suggestions. I used two scripts to give a more obvious conclusion.

Script 1 (compute torch.randn only once):

import torch
import time
from mamba_ssm import Mamba
batch, length, dim = 3, 4096, 192
d_state = 16
torch.manual_seed(1)
model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
optim = torch.optim.AdamW(model.parameters(), lr=3e-3)
start_time = time.time()
N = 500
x_ = torch.randn(batch, length, dim)
for _ in range(N):
    x = x_.to("cuda")
    y_cuda = model(x)
    loss = torch.norm(y_cuda)
    optim.zero_grad()
    loss.backward()
    optim.step()
end_time = time.time()
res = (end_time-start_time)/N
print(f"B={batch}, L={length}, d_model={dim}, d_state={model.d_state}, time (ms)={res*1000}") script 2(compute torch.randn in loop) import torch
import time
from mamba_ssm import Mamba
batch, length, dim = 3, 4096, 192
d_state = 16
torch.manual_seed(1)
model = Mamba(
    d_model=dim,  # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=4,  # Local convolution width
    expand=2,  # Block expansion factor
).to("cuda")
optim = torch.optim.AdamW(model.parameters(), lr=3e-3)
start_time = time.time()
N = 500
for _ in range(N):
    x = torch.randn(batch, length, dim).to("cuda")
    y_cuda = model(x)
    loss = torch.norm(y_cuda)
    optim.zero_grad()
    loss.backward()
    optim.step()
end_time = time.time()
res = (end_time-start_time)/N
print(f"B={batch}, L={length}, d_model={dim}, d_state={model.d_state}, time (ms)={res*1000}") result # script 1
B=3, L=4096, d_model=192, d_state=16, time (ms)= 3.7
# script 2
B=3, L=4096, d_model=192, d_state=16, time (ms)= 13.1 Script1 only compute torch.randn once and benchmarks the memcpy and mamba kernel while script2 benchmarks the torch.randn on cpu, memcpy host to device and mamba kernel. The result is 3.7 vs 13.1. |
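(If the host-to-device copy itself should also be taken off the critical path, as in a real training loop, pinning the host buffer makes the copy truly asynchronous; a small sketch under that assumption:)

import torch

batch, length, dim = 3, 4096, 192

# Pinned (page-locked) host memory allows the copy to overlap with GPU work,
# so the transfer of the next batch can be hidden behind the current step.
data = torch.randn(batch, length, dim).pin_memory()
x = data.to("cuda", non_blocking=True)

In a real training loop the same effect is usually obtained with DataLoader(..., pin_memory=True), so that input batches arrive already pinned.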
Thank you for sharing your fantastic work.
We noticed from the figure that as the dimension of d_state rises, Mamba's time does not rise. However, we found that the code at selective_scan_fwd_kernel.cuh#163 shows a for loop over state_idx that reads from HBM into shared memory.
Then I tested the speed again and found that as d_state rises, Mamba's time rises linearly, which is consistent with the code.
So what did I miss?
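(To make the concern concrete, here is a naive PyTorch reference of the selective scan recurrence, not the fused CUDA kernel, and omitting the D skip connection and the gating branch for brevity; it shows where d_state enters the per-step work.)

import torch

def naive_selective_scan(u, delta, A, B, C):
    # u, delta: (batch, d_inner, length); A: (d_inner, d_state);
    # B, C: (batch, d_state, length).
    # Every time step updates a hidden state of shape (d_inner, d_state),
    # so both compute and memory traffic grow linearly with d_state.
    batch, d_inner, length = u.shape
    d_state = A.shape[1]
    h = torch.zeros(batch, d_inner, d_state, device=u.device, dtype=u.dtype)
    ys = []
    for t in range(length):
        dA = torch.exp(delta[:, :, t, None] * A)                          # (batch, d_inner, d_state)
        dBu = delta[:, :, t, None] * B[:, None, :, t] * u[:, :, t, None]  # (batch, d_inner, d_state)
        h = dA * h + dBu
        ys.append(torch.einsum("bds,bs->bd", h, C[:, :, t]))
    return torch.stack(ys, dim=-1)                                        # (batch, d_inner, length)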