forked from vllm-project/vllm
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add Bamba Model (vllm-project#10909)
Signed-off-by: Yu Chin Fabian Lim <[email protected]> Signed-off-by: Tyler Michael Smith <[email protected]> Co-authored-by: Tyler Michael Smith <[email protected]> Signed-off-by: Felix Marty <[email protected]>
- Loading branch information
1 parent
89ebc6b
commit a7cb3cc
Showing
17 changed files
with
3,706 additions
and
112 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import unittest | ||
from typing import Tuple | ||
|
||
import pytest | ||
import torch | ||
|
||
from tests.utils import multi_gpu_test | ||
from vllm.distributed.parallel_state import (init_distributed_environment, | ||
initialize_model_parallel) | ||
from vllm.model_executor.layers.mamba.mamba_mixer2 import Mixer2RMSNormGated | ||
from vllm.platforms import current_platform | ||
from vllm.utils import update_environment_variables | ||
|
||
|
||
@multi_gpu_test(num_gpus=2) | ||
@pytest.mark.parametrize("batch_size", [8]) | ||
@pytest.mark.parametrize("seq_len", [128]) | ||
@pytest.mark.parametrize( | ||
"hidden_size_n_groups", | ||
[ | ||
(64, 1), | ||
(64, 2), | ||
(64, 4), # hidden_size be divisible by num_gpus | ||
(100, 5), # and n_groups must divide hidden_size | ||
]) | ||
@pytest.mark.parametrize("dtype", [torch.float16]) | ||
def test_mixer2_gated_norm_multi_gpu( | ||
batch_size: int, | ||
seq_len: int, | ||
hidden_size_n_groups: Tuple[int, int], | ||
dtype: torch.dtype, | ||
device: str = 'cuda', | ||
): | ||
hidden_size, n_groups = hidden_size_n_groups | ||
num_processes = 2 | ||
|
||
def run_torch_spawn(fn, nprocs): | ||
# need to use torch.mp.spawn otherwise will have problems with | ||
# torch.distributed and cuda | ||
torch.multiprocessing.spawn(fn, | ||
args=( | ||
num_processes, | ||
batch_size, | ||
seq_len, | ||
hidden_size, | ||
n_groups, | ||
dtype, | ||
device, | ||
), | ||
nprocs=nprocs) | ||
|
||
run_torch_spawn(mixer2_gated_norm_tensor_parallel, 2) | ||
|
||
|
||
def mixer2_gated_norm_tensor_parallel( | ||
local_rank: int, | ||
world_size: int, | ||
batch_size: int, | ||
seq_len: int, | ||
hidden_size: int, | ||
n_groups: int, | ||
dtype: torch.dtype, | ||
device: str, | ||
): | ||
current_platform.seed_everything(0) | ||
|
||
device = torch.device(f"cuda:{local_rank}") | ||
torch.cuda.set_device(device) | ||
torch.set_default_device(device) | ||
torch.set_default_dtype(dtype) | ||
|
||
update_environment_variables({ | ||
'RANK': str(local_rank), | ||
'LOCAL_RANK': str(local_rank), | ||
'WORLD_SIZE': str(world_size), | ||
'MASTER_ADDR': 'localhost', | ||
'MASTER_PORT': '12345', | ||
}) | ||
|
||
# initialize distributed | ||
init_distributed_environment() | ||
initialize_model_parallel(tensor_model_parallel_size=world_size) | ||
|
||
# create random weights an inputs | ||
weight = torch.rand((hidden_size, ), dtype=dtype, device=device) | ||
hidden_states = torch.randn(batch_size, seq_len, hidden_size) | ||
gate_states = torch.randn(batch_size, seq_len, hidden_size) | ||
|
||
# create gated-norm with TP | ||
mixer = Mixer2RMSNormGated( | ||
full_hidden_size=hidden_size, | ||
full_n_groups=n_groups, | ||
) | ||
mixer.weight.weight_loader(mixer.weight, weight) # load | ||
|
||
# create gated-norm without TP to compute reference | ||
# - utilize mock patching to disable TP when | ||
with (unittest.mock.patch( | ||
"vllm.model_executor.layers.mamba.mamba_mixer2." | ||
"get_tensor_model_parallel_world_size", | ||
return_value=1), | ||
unittest.mock.patch( | ||
"vllm.model_executor.layers.mamba.mamba_mixer2." | ||
"get_tensor_model_parallel_rank", | ||
return_value=0)): | ||
mixer_single_gpu = Mixer2RMSNormGated( | ||
full_hidden_size=hidden_size, | ||
full_n_groups=n_groups, | ||
) | ||
# assign weight to single-gpu mixer | ||
mixer_single_gpu.weight.data = weight | ||
|
||
# generate and compare | ||
N = hidden_size // world_size | ||
output = mixer( | ||
hidden_states[..., local_rank * N:(local_rank + 1) * N], | ||
gate_states[..., local_rank * N:(local_rank + 1) * N], | ||
) | ||
ref_output = mixer_single_gpu(hidden_states, gate_states) | ||
torch.allclose(output, | ||
ref_output[..., local_rank * N:(local_rank + 1) * N], | ||
atol=1e-3, | ||
rtol=1e-3) |
Oops, something went wrong.