forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add unit test for Mixtral MoE layer (vllm-project#2677)
- Loading branch information
Showing
5 changed files
with
119 additions
and
55 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
"""Tests for the MOE layers. | ||
Run `pytest tests/kernels/test_moe.py`. | ||
""" | ||
|
||
import pytest | ||
import torch | ||
|
||
from transformers import MixtralConfig | ||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock | ||
|
||
from vllm.model_executor.layers.fused_moe import fused_moe | ||
from vllm.model_executor.layers.activation import SiluAndMul | ||
from vllm.model_executor.models.mixtral import MixtralMoE | ||
|
||
|
||
def torch_moe(a, w1, w2, topk_weight, topk_ids): | ||
B, D = a.shape | ||
a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) | ||
out = torch.zeros(B * topk_ids.shape[1], | ||
w2.shape[1], | ||
dtype=a.dtype, | ||
device=a.device) | ||
topk_ids = topk_ids.view(-1) | ||
topk_weight = topk_weight.view(-1) | ||
for i in range(w1.shape[0]): | ||
mask = topk_ids == i | ||
if mask.sum(): | ||
out[mask] = SiluAndMul()( | ||
a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) | ||
return (out.view(B, -1, w2.shape[1]) * | ||
topk_weight.view(B, -1, 1)).sum(dim=1) | ||
|
||
|
||
@pytest.mark.parametrize("m", [512, 222, 33, 1]) | ||
@pytest.mark.parametrize("n", [2048, 256, 1024]) | ||
@pytest.mark.parametrize("k", [128, 511, 1024]) | ||
@pytest.mark.parametrize("e", [8, 64]) | ||
@pytest.mark.parametrize("topk", [2, 6]) | ||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) | ||
def test_fused_moe( | ||
m: int, | ||
n: int, | ||
k: int, | ||
e: int, | ||
topk: int, | ||
dtype: torch.dtype, | ||
): | ||
a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 | ||
w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 | ||
w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 | ||
|
||
score = torch.randn((m, e), device='cuda', dtype=dtype) | ||
score = torch.softmax(score, dim=-1) | ||
topk_weight, topk_ids = torch.topk(score, topk) | ||
|
||
triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False) | ||
torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids) | ||
assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", | ||
[torch.float32, torch.float16, torch.bfloat16]) | ||
@torch.inference_mode() | ||
def test_mixtral_moe(dtype: torch.dtype): | ||
"Make sure our Mixtral MoE implementation agrees with the one from huggingface." | ||
|
||
# Instantiate our and huggingface's MoE blocks | ||
config = MixtralConfig() | ||
hf_moe = MixtralSparseMoeBlock(config).to(dtype).to("cuda") | ||
vllm_moe = MixtralMoE( | ||
num_experts=config.num_local_experts, | ||
top_k=config.num_experts_per_tok, | ||
hidden_size=config.hidden_size, | ||
intermediate_size=config.intermediate_size, | ||
params_dtype=dtype, | ||
tp_size=1, | ||
) | ||
|
||
# Load the weights | ||
vllm_moe.gate.linear_weights["weight"][:] = hf_moe.gate.weight.data | ||
for i in range(config.num_local_experts): | ||
weights = (hf_moe.experts[i].w1.weight.data, | ||
hf_moe.experts[i].w3.weight.data) | ||
vllm_moe.ws[i][:] = torch.cat(weights, dim=0) | ||
vllm_moe.w2s[i][:] = hf_moe.experts[i].w2.weight.data | ||
|
||
# Generate input batch of dimensions [batch_size, seq_len, hidden_dim] | ||
inputs = torch.randn((1, 64, config.hidden_size)).to(dtype).to("cuda") | ||
|
||
# Run forward passes for both MoE blocks | ||
hf_states, _ = hf_moe.forward(inputs) | ||
vllm_states = vllm_moe.forward(inputs) | ||
|
||
mixtral_moe_tol = { | ||
torch.float32: 1e-3, | ||
torch.float16: 1e-3, | ||
torch.bfloat16: 1e-2, | ||
} | ||
|
||
assert torch.allclose(hf_states, | ||
vllm_states, | ||
rtol=mixtral_moe_tol[dtype], | ||
atol=mixtral_moe_tol[dtype]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters