DeepseekMoE support with Fused MoE kernel #2453
Conversation
@zwd003 Thanks for fixing my last PR! But have you seen that there seems to be no speed boost after adopting expert parallelism?
@esmeetu Can you help test and review this PR?
Hi @zwd003, I refactored based on this PR. Please refer to #2467.
In my setting (8x A100 40G, tp=8, max_tokens=256, number of prompts = 256, max_batch_size = 256, average input tokens per prompt 168), this implementation is faster:

| code | enforce_eager | speed (it/s) |
|---|---|---|
| baseline (original code) | True | 1.87 |
| this (20240116) | True | 7.04 |
| this (fused_moe) | False | 10.73 |
I developed a fused MoE kernel that achieves faster speeds (different from #2293). The code is ready for review @esmeetu @zhuohan123.
@zwd003 Running on a T4 GPU with float16 is not working: RuntimeError: Internal Triton PTX codegen error:
Now it can run with fp16.
LGTM!
In my setting (TP=2, A100 40G PCIe, max_batch_size = 256, max_tokens=256, prompts=256 with average input length 166), llama2-13b costs 51s and deepseek16b-moe costs 25s. For smaller models, a smaller TP yields better acceleration compared to a dense model. See the benchmarks above for more results.
Pretty cool implementation! I do feel this PR will perform better compared to my implementation in #2293. Wondering if you have benchmarked that yet?
Hi, thank you very much for all the great work!
MoE requires more optimization to achieve the same effect as dense models (dense models are mostly dense matrix multiplication, and cuBLAS can almost reach the hardware's peak computation speed). For instance, the current mistral8x7b is also slower than a 14b dense model. Additionally, MoE models involve many more kernel computations, such as the gate (softmax) and topk. If these kernels are further fused together, there is potential for additional speed improvements.
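To make that point concrete, below is a minimal sketch of the unfused routing path being described: the gate GEMM, softmax, top-k, and renormalization each launch separate kernels today, which is what leaves room for further fusion. Shapes and the Mixtral-style renormalization are assumptions for illustration, not this PR's exact code.

```python
import torch
import torch.nn.functional as F

def route_tokens(hidden_states: torch.Tensor, gate_weight: torch.Tensor, top_k: int):
    # hidden_states: [num_tokens, hidden], gate_weight: [num_experts, hidden]
    router_logits = hidden_states @ gate_weight.t()                        # gate GEMM (separate kernel)
    routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float)  # softmax kernel
    topk_weights, topk_ids = torch.topk(routing_weights, top_k, dim=-1)    # topk kernel
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)   # renormalization kernel
    return topk_weights, topk_ids
```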
This looks really good!
This is also applicable to mistral8x7b. You only need to modify the parameter names in Mistral (packing parameters from different experts together; you can see ...
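For reference, a rough sketch of the "packing parameters from different experts together" step for a Mixtral-style expert list (HF Mixtral w1/w2/w3 naming assumed; this mirrors the test script later in this thread and is illustrative, not the PR's loader code):

```python
import torch

def pack_expert_weights(experts):
    # experts: iterable of Mixtral-style expert MLPs with w1 (gate), w3 (up), w2 (down) projections
    ws = torch.stack([
        torch.cat([e.w1.weight.data, e.w3.weight.data], dim=0)  # [2 * intermediate, hidden]
        for e in experts
    ])                                                           # [E, 2 * intermediate, hidden]
    w2s = torch.stack([e.w2.weight.data for e in experts])       # [E, hidden, intermediate]
    return ws, w2s
```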
@zwd003 I have a question related to quantization. How can we apply this optimization to quantized models? Do we need to dequantize weights before/during running the kernel to achieve the speedup?
It needs a new kernel supporting quantized matmul (we would need to rewrite the kernel in CUDA/C++, but the main idea stays the same: compute each block for the corresponding expert).
I do believe we have Triton kernels for both GPTQ and AWQ. Perhaps you would need to create separate quantized Triton kernels based on the linked kernels for quantized matmul. If this is of interest to DeepSeek, I could implement AWQ quantization of the DeepSeek-MoE model.
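To illustrate the idea above at the PyTorch level (not an actual kernel): each weight block belongs to exactly one expert, so a quantized variant would dequantize that expert's block and run the same block-wise matmul. The int8-with-per-channel-scales scheme below is an assumption for illustration; GPTQ/AWQ layouts differ.

```python
import torch

def quant_expert_matmul_reference(x, qweight, scales, expert_ids):
    # x: [M, K] activations; qweight: [E, N, K] int8; scales: [E, N]; expert_ids: [M]
    out = x.new_zeros(x.shape[0], qweight.shape[1])
    for e in expert_ids.unique().tolist():
        rows = (expert_ids == e).nonzero(as_tuple=True)[0]
        # Dequantize only this expert's weight block, then do the block-wise matmul.
        w = qweight[e].to(x.dtype) * scales[e].to(x.dtype).unsqueeze(-1)
        out[rows] = x[rows] @ w.t()
    return out
```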
```python
    intermediate_size: int,
    hidden_act: str,
    linear_method: Optional[LinearMethodBase] = None,
    reduce_results=True,
```
It would be a helpful way to remove ambiguity if we define another `DeepseekExpertMLP` without reducing results; then we can remove this `reduce_results` parameter. If we keep this, adding a type annotation for `reduce_results` looks nicer. Both choices are OK.
OK, you're right, I have fixed it.
Even with the current code, I found that it is indeed 1.6 times faster than the dense 7b model. The performance bottleneck in the test results from the table is not in the model's computation. Please see the new test results.
That's so great to hear, thanks!
Unfortunately, there seem to be some correctness issues with the kernel (this was prompted by the user feedback in #2542 (comment)). It happens in all kinds of configurations in smaller ways but can be reproduced in a pretty major way with the following diff:

```diff
diff --git a/tests/kernels/test_fused_moe.py b/tests/kernels/test_fused_moe.py
index f68e84f4f9..1b8e6321e6 100644
--- a/tests/kernels/test_fused_moe.py
+++ b/tests/kernels/test_fused_moe.py
@@ -22,11 +22,11 @@ def torch_moe(a, w1, w2, topk_weight, topk_ids):
             topk_weight.view(B, -1, 1)).sum(dim=1)
-@pytest.mark.parametrize("m", [512, 222, 33, 1])
-@pytest.mark.parametrize("n", [2048, 256, 1024])
-@pytest.mark.parametrize("k", [128, 511, 1024])
-@pytest.mark.parametrize("e", [8, 64])
-@pytest.mark.parametrize("topk", [2, 6])
+@pytest.mark.parametrize("m", [1])
+@pytest.mark.parametrize("n", [2048, 256, 1024, 8192])
+@pytest.mark.parametrize("k", [4096])
+@pytest.mark.parametrize("e", [8])
+@pytest.mark.parametrize("topk", [2])
 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
 def test_fused_moe(
     m: int,
@@ -46,4 +46,4 @@ def test_fused_moe(
     triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False)
     torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids)
-    assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0)
+    assert torch.allclose(triton_output, torch_output, atol=1e-3, rtol=0), torch.max(abs(triton_output - torch_output))
```

This gives an error of ...
This issue might be due to numerical precision; after switching to fp32, the difference is very small.

```python
# assert hidden_states.dtype in [torch.float16, torch.bfloat16]  # disable type assert
accumulator += tl.dot(a, b, allow_tf32=False)  # disable tf32 in the kernel
# accumulator = accumulator.to(compute_type)  # disable conversion to bf16 or fp16
```

`test_fused_moe(1, 8192, 4096, 8, 2, torch.float32)` outputs:
@zwd003 You are right, thanks a lot for checking. I only set the dtype to float32 OR used the ...
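For illustration of why `allow_tf32` matters here (this snippet is not from the PR): with TF32 enabled, fp32 matmuls on Ampere-class GPUs round the mantissa to 10 bits, so comparing against a full-precision reference shows differences well above the 1e-3 tolerance when K is on the order of 4096.

```python
import torch

# On a pre-Ampere GPU both paths are plain fp32 and the difference is ~0.
a = torch.randn(1, 4096, device="cuda")
b = torch.randn(4096, 8192, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
out_tf32 = a @ b          # may use TF32 tensor cores

torch.backends.cuda.matmul.allow_tf32 = False
out_fp32 = a @ b          # full fp32 accumulation

# The max absolute difference is typically far larger than 1e-3 at this K.
print((out_tf32 - out_fp32).abs().max())
```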
```cpp
size_t numel) {
const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;
__shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
```
nit: Actually, this part doesn't need to be static. The shared memory size can be configured dynamically at kernel launch time. However, I think we can fix this in a later PR.
When will this merge be pushed to the latest openai docker image? Thank you so much for the work!
In my tests, https://github.com/casper-hansen/AutoAWQ/blob/mixtral_fused/tests/test_fused_moe.py ...
Thanks @casper-hansen, I'm digging into this some more and I'm also planning to add a test to the repo about it :) In the test you posted (which is very nice btw), I see that all the ...
I figured out what is going on now, I think. There were two adaptations I needed to make so your script can be adapted to the Mixtral MoE: Load the gate, and also set ...

```python
import torch
torch.cuda.manual_seed(0)
torch.cuda.manual_seed_all(0)
import time
from vllm.model_executor.models.mixtral import MixtralMoE
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
config = MixtralConfig()
block = MixtralSparseMoeBlock(config).float().to("cuda")
fused = MixtralMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
)
fused.gate.linear_weights["weight"][:] = block.gate.weight.data
for i in range(config.num_local_experts):
fused.ws[i][:] = torch.cat((
block.experts[i].w1.weight.data,
block.experts[i].w3.weight.data,
), dim=0).to("cuda")
fused.w2s[i][:] = block.experts[i].w2.weight.data
def _run_profile(fn, inputs, rounds=2):
start_time = time.perf_counter()
torch.cuda.synchronize()
for _ in range(rounds):
states, router_logits = fn(inputs)
torch.cuda.synchronize()
end_time = time.perf_counter()
return (end_time - start_time) / rounds, states, router_logits
# [batch_size, seq_len, hidden_dim]
inputs = torch.randn((1, 64, config.hidden_size)).to("cuda")
block_time, states_block, router_block = _run_profile(block.forward, inputs)
fused_time, states_fused, router_fused = _run_profile(fused.forward, inputs)
print(block_time, fused_time, block_time / fused_time)
print("states_fused", states_fused)
print("states_block", states_block)
print("diff1", (states_fused - states_block).mean().abs())
print("diff2", (states_fused - states_block).abs().max()) And this is the diff to the repo (mostly to make sure the MoE layer can run in the same process and also make sure it doesn't use lower precision tensor core arithmetic): diff --git a/vllm/model_executor/layers/fused_moe.py b/vllm/model_executor/layers/fused_moe.py
index 998062d82d..99d8de7ccb 100644
--- a/vllm/model_executor/layers/fused_moe.py
+++ b/vllm/model_executor/layers/fused_moe.py
@@ -105,7 +105,7 @@ def fused_moe_kernel(
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
# We accumulate along the K dimension.
- accumulator += tl.dot(a, b)
+ accumulator += tl.dot(a, b, allow_tf32=False)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
@@ -235,7 +235,7 @@ def fused_moe(hidden_states: torch.Tensor,
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
- assert hidden_states.dtype in [torch.float16, torch.bfloat16]
+ assert hidden_states.dtype in [torch.float16, torch.bfloat16, torch.float32]
M, _ = hidden_states.shape
E, N, _ = w1.shape
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index f36c35fd27..480de2b8bf 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -72,7 +72,7 @@ class MixtralMoE(nn.Module):
params_dtype: Optional[torch.dtype] = None,
):
super().__init__()
- tp_size = get_tensor_model_parallel_world_size()
+ tp_size = 1 # get_tensor_model_parallel_world_size()
self.num_total_experts = num_experts
self.top_k = top_k
self.hidden_size = hidden_size
@@ -93,13 +93,15 @@ class MixtralMoE(nn.Module):
2 * self.intermediate_size,
self.hidden_size,
device="cuda",
- dtype=self.params_dtype))
+ dtype=self.params_dtype),
+ requires_grad=False)
self.w2s = nn.Parameter(
torch.empty(self.num_total_experts,
self.hidden_size,
self.intermediate_size,
device="cuda",
- dtype=self.params_dtype))
+ dtype=self.params_dtype),
+ requires_grad=False)
set_weight_attrs(self.ws, {
"weight_loader": self.weight_loader,
@@ -139,13 +141,13 @@ class MixtralMoE(nn.Module):
self.w2s,
routing_weights,
selected_experts,
- inplace=True)
+ inplace=False)
- final_hidden_states = tensor_model_parallel_all_reduce(
- final_hidden_states)
+ # final_hidden_states = tensor_model_parallel_all_reduce(
+ # final_hidden_states)
return final_hidden_states.view(batch_size, sequence_length,
- hidden_size)
+ hidden_size), None
class MixtralAttention(nn.Module):
@@ -160,7 +162,7 @@ class MixtralAttention(nn.Module):
sliding_window: Optional[int] = None) -> None:
super().__init__()
self.hidden_size = hidden_size
- tp_size = get_tensor_model_parallel_world_size()
+ tp_size = 1 # get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
```

With those modifications, the results are very accurate even if I set the number of rounds to a high value; for example, for rounds = 100, I'm getting ...
I'll convert this to a test that can be committed into the repo next! Thanks for looking into this, I'm similarly interested in making sure the model quality is as high as possible :)
@casper-hansen Unit test added in #2677.
Co-authored-by: roy <[email protected]>
I'm wondering if this commit really applies expert parallelism?
Adding support for DeepseekMoE as described here.
This work was partly done by @esmeetu and DeepSeek-AI.
We have fixed some bugs in @esmeetu's code and added support for expert parallelism and a fused MoE kernel.
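For readers skimming the thread, the sketch below is an unfused PyTorch reference of what `fused_moe(a, w1, w2, topk_weight, topk_ids)` computes, looping over experts instead of using the Triton kernel; it is not the test code below. It assumes the SwiGLU layout used in this PR, where `w1` packs the gate and up projections.

```python
import torch
import torch.nn.functional as F

def unfused_moe(a, w1, w2, topk_weight, topk_ids):
    # a: [M, K] tokens; w1: [E, 2N, K] packed gate/up proj; w2: [E, K, N] down proj
    # topk_weight, topk_ids: [M, top_k] routing weights and expert indices
    out = torch.zeros_like(a)
    for e in range(w1.shape[0]):
        token_idx, slot_idx = (topk_ids == e).nonzero(as_tuple=True)
        if token_idx.numel() == 0:
            continue
        h = a[token_idx] @ w1[e].t()             # [m_e, 2N]
        gate, up = h.chunk(2, dim=-1)
        h = (F.silu(gate) * up) @ w2[e].t()      # SwiGLU, then down projection -> [m_e, K]
        out[token_idx] += topk_weight[token_idx, slot_idx].unsqueeze(-1).to(h.dtype) * h
    return out
```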
Test code:
Output:
Update 2024.01.19
Performance Benchmarking for Fused MoE Improvements
This PR introduces significant performance enhancements when compared to the current method in Mistral and other baseline methods. Below is a summary of the benchmarks conducted:
Benchmark Details:
Results Summary:
*Updated on 2024-01-20: Added comparison with PR #2293.
Updated on 2024.01.21: Implemented the `align_block_size` function in C++ to achieve a 10% performance improvement. Now deepseek-moe-16b can achieve speeds almost identical to the 7b dense model.
Updated on 2024.01.22: I found that it has greatly exceeded the speed of llama2-7b, with deepseekmoe-16b at 16587.82 tokens/s and llama2-7b at 10978.67 tokens/s. The speed bottleneck in the results from the table above (8s vs 9s) is not due to the model's computation speed. The following code was used to test this:
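The original test snippet is not reproduced here; a rough sketch of that kind of offline throughput measurement with vLLM's `LLM` API (model name, prompts, and sampling settings below are placeholders, not the exact settings used) could look like this:

```python
import time
from vllm import LLM, SamplingParams

# Placeholder settings; the actual benchmark above used its own prompts and config.
llm = LLM(model="deepseek-ai/deepseek-moe-16b-base", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.0, max_tokens=256)
prompts = ["Hello, my name is"] * 256

start = time.perf_counter()
outputs = llm.generate(prompts, sampling_params)
elapsed = time.perf_counter() - start

generated_tokens = sum(len(o.outputs[0].token_ids) for o in outputs)
print(f"{generated_tokens / elapsed:.2f} generated tokens/s")
```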
Future Works
I believe it is possible to achieve higher performance and surpass the speed of the 7b-dense model; we might need to do the following things:
This work may be done by the community in the future (in another PR).