
[Bug] Failure to Dispatch Head Dimension 80 in sglang with Specific Configurations #1109

Closed
5 tasks done
hxer7963 opened this issue Aug 15, 2024 · 6 comments


@hxer7963
Contributor

hxer7963 commented Aug 15, 2024

Checklist

  • 1. I have searched related issues but cannot get the expected help.
  • 2. The bug has not been fixed in the latest version.
  • 3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  • 4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/sglang/discussions/new/choose Otherwise, it will be closed.
  • 5. Please use English, otherwise it will be closed.

Describe the bug

Issue Description:

When running sglang with a model whose attention head dimension (head_dim) is 80, the following exceptions occur under different configurations:

With CUDA graph enabled (the default, disable_cuda_graph=False):

Exception: Capture cuda graph failed: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 80
The complete stack information is as follows:
python -m sglang.launch_server --model-path /mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat --trust-remote-code 
server_args=ServerArgs(model_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', trust_remote_code=True, context_length=None, quantization=None, served_model_name='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', chat_template=None, host='127.0.0.1', port=30000, additional_ports=[30001, 30002, 30003, 30004], mem_fraction_static=0.88, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=None, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=334330058, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_disk_cache=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, attention_reduce_in_fp32=False, efficient_weight_load=False, nccl_init_addr=None, nnodes=1, node_rank=None)
[gpu=0] Init nccl begin.
[gpu=0] Load weight begin. avail mem=78.94 GB
Loading pt checkpoint shards:   0% Completed | 0/9 [00:00<?, ?it/s]
...
Loading pt checkpoint shards: 100% Completed | 9/9 [00:21<00:00,  2.35s/it]

[gpu=0] Load weight end. type=XverseMoEForCausalLM, dtype=torch.bfloat16, avail mem=30.65 GB
[gpu=0] Memory pool end. avail mem=9.28 GB
[gpu=0] Capture cuda graph begin. This can take up to several minutes.
Initialization failed. controller_init_state: Traceback (most recent call last):
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 344, in init_cuda_graphs
    self.cuda_graph_runner.capture(batch_size_list)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 148, in capture
    ) = self.capture_one_batch_size(bs, forward)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/cuda_graph_runner.py", line 183, in capture_one_batch_size
    update_flashinfer_indices(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/forward_batch_info.py", line 284, in update_flashinfer_indices
    flashinfer_decode_wrapper.begin_forward(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/flashinfer/decode.py", line 539, in begin_forward
    self._wrapper.begin_forward(
RuntimeError: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 80

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/controller_single.py", line 150, in start_controller_process
    controller = ControllerSingle(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/controller_single.py", line 84, in __init__
    self.tp_server = ModelTpServer(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 100, in __init__
    self.model_runner = ModelRunner(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 139, in __init__
    self.init_cuda_graphs()
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 346, in init_cuda_graphs
    raise Exception(
Exception: Capture cuda graph failed: BatchDecodeWithPagedKVCachePyTorchWrapper::BeginForward(at::Tensor, at::Tensor, at::Tensor, at::Tensor, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, unsigned int, float, at::Tensor, at::Tensor)::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 80
Possible solutions:
1. disable torch compile by not using --enable-torch-compile
2. disable cuda graph by --disable-cuda-graph
3. set --mem-fraction-static to a smaller value

With disable_cuda_graph set to True:

RuntimeError: BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, bool, unsigned int, bool, int, float, float, float, float, bool)::<lambda()>::<lambda()>::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 80
The complete stack information is as follows:
python -m sglang.launch_server --model-path /mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat --trust-remote-code --disable-cuda-graph
server_args=ServerArgs(model_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', trust_remote_code=True, context_length=None, quantization=None, served_model_name='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', chat_template=None, host='127.0.0.1', port=30000, additional_ports=[30001, 30002, 30003, 30004], mem_fraction_static=0.88, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=None, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=978056517, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=False, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=True, disable_disk_cache=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, attention_reduce_in_fp32=False, efficient_weight_load=False, nccl_init_addr=None, nnodes=1, node_rank=None)
[gpu=0] Init nccl begin.
[gpu=0] Load weight begin. avail mem=78.94 GB
Loading pt checkpoint shards:   0% Completed | 0/9 [00:00<?, ?it/s]
...
Loading pt checkpoint shards: 100% Completed | 9/9 [00:17<00:00,  1.90s/it]

[gpu=0] Load weight end. type=XverseMoEForCausalLM, dtype=torch.bfloat16, avail mem=30.65 GB
[gpu=0] Memory pool end. avail mem=9.28 GB
[gpu=0] max_total_num_tokens=79301, max_prefill_tokens=16384, max_running_requests=4955, context_len=8192
INFO:     Started server process [54724]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
INFO:     127.0.0.1:53136 - "GET /get_model_info HTTP/1.1" 200 OK
[gpu=0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 0, cache hit rate: 0.00%, #running-req: 0, #queue-req: 0
Exception in ModelTpServer:
Traceback (most recent call last):
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 222, in exposed_step
    self.forward_step()
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 238, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 452, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 397, in forward
    return self.forward_extend(batch)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 373, in forward_extend
    return self.model.forward(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 389, in forward
    hidden_states = self.model(input_ids, positions, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 358, in forward
    hidden_states, residual = layer(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 308, in forward
    hidden_states = self.self_attn(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 251, in forward
    attn_output = self.attn(q, k, v, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/layers/radix_attention.py", line 177, in forward
    return self.extend_forward(q, k, v, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/layers/radix_attention.py", line 119, in extend_forward_flashinfer
    o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/flashinfer/prefill.py", line 914, in forward
    out = self._wrapper.forward(
RuntimeError: BatchPrefillWithPagedKVCachePyTorchWrapper::Forward(at::Tensor, at::Tensor, std::optional<at::Tensor>, std::optional<at::Tensor>, std::optional<at::Tensor>, at::Tensor, at::Tensor, at::Tensor, bool, unsigned int, bool, int, float, float, float, float, bool)::<lambda()>::<lambda()>::<lambda()>::<lambda()>::<lambda()> failed to dispatch head_dim 80
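Both FlashInfer errors above (decode BeginForward and prefill Forward) have the same shape: the compiled wrapper only instantiates kernels for a fixed set of head dimensions and raises when the requested one is missing. The minimal Python sketch below illustrates that dispatch pattern. The names `COMPILED_HEAD_DIMS`, `_make_kernel` and `dispatch` are hypothetical, and the set of head dims is a placeholder, not FlashInfer's actual list.

```python
# Hypothetical illustration of compile-time head_dim dispatch.
# COMPILED_HEAD_DIMS is a placeholder, NOT FlashInfer's real list.
COMPILED_HEAD_DIMS = {64, 128, 256}


def _make_kernel(head_dim: int):
    # Stand-in for a kernel specialized (templated) on head_dim.
    def kernel(q_row, k_row):
        assert len(q_row) == len(k_row) == head_dim
        return sum(a * b for a, b in zip(q_row, k_row))

    return kernel


# Kernels exist only for head dims that were instantiated ahead of time.
KERNELS = {d: _make_kernel(d) for d in COMPILED_HEAD_DIMS}


def dispatch(head_dim: int):
    # Any head_dim without a precompiled kernel cannot be served.
    try:
        return KERNELS[head_dim]
    except KeyError:
        raise RuntimeError(f"failed to dispatch head_dim {head_dim}") from None
```

With this sketch, `dispatch(128)` returns a kernel, while `dispatch(80)` raises `RuntimeError: failed to dispatch head_dim 80`, mirroring the messages in the logs.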

With disable_flashinfer set to True:

AssertionError: assert Lq in {16, 32, 64, 128, 256, 576}, where Lq = 80.
The complete stack information is as follows:
python -m sglang.launch_server --model-path /mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat --trust-remote-code --disable-flashinfer
server_args=ServerArgs(model_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_path='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', trust_remote_code=True, context_length=None, quantization=None, served_model_name='/mnt/llm_dataset/willhe/ckpt/XVERSE-MoE-A4.2B-Chat', chat_template=None, host='127.0.0.1', port=30000, additional_ports=[30001, 30002, 30003, 30004], mem_fraction_static=0.88, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=None, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=8981036, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=True, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_disk_cache=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, attention_reduce_in_fp32=False, efficient_weight_load=False, nccl_init_addr=None, nnodes=1, node_rank=None)
[gpu=0] Init nccl begin.
[gpu=0] Load weight begin. avail mem=78.94 GB
Loading pt checkpoint shards:   0% Completed | 0/9 [00:00<?, ?it/s]
...
Loading pt checkpoint shards: 100% Completed | 9/9 [00:17<00:00,  1.92s/it]

[gpu=0] Load weight end. type=XverseMoEForCausalLM, dtype=torch.bfloat16, avail mem=30.65 GB
[gpu=0] Memory pool end. avail mem=9.28 GB
[gpu=0] max_total_num_tokens=79301, max_prefill_tokens=16384, max_running_requests=4955, context_len=8192
INFO:     Started server process [55058]
INFO:     Waiting for application startup.
INFO:     Application startup complete.
INFO:     Uvicorn running on http://127.0.0.1:30000 (Press CTRL+C to quit)
INFO:     127.0.0.1:60976 - "GET /get_model_info HTTP/1.1" 200 OK
[gpu=0] Prefill batch. #new-seq: 1, #new-token: 8, #cached-token: 0, cache hit rate: 0.00%, #running-req: 0, #queue-req: 0
Exception in ModelTpServer:
Traceback (most recent call last):
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 222, in exposed_step
    self.forward_step()
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 238, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/managers/tp_worker.py", line 452, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 397, in forward
    return self.forward_extend(batch)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/model_executor/model_runner.py", line 373, in forward_extend
    return self.model.forward(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 389, in forward
    hidden_states = self.model(input_ids, positions, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 358, in forward
    hidden_states, residual = layer(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 308, in forward
    hidden_states = self.self_attn(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/models/xverse_moe.py", line 251, in forward
    attn_output = self.attn(q, k, v, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/layers/radix_attention.py", line 177, in forward
    return self.extend_forward(q, k, v, input_metadata)
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/layers/radix_attention.py", line 69, in extend_forward_triton
    extend_attention_fwd(
  File "/mnt/llm_dataset/willhe/miniconda3/envs/sglang/lib/python3.10/site-packages/sglang/srt/layers/extend_attention.py", line 267, in extend_attention_fwd
    assert Lq in {16, 32, 64, 128, 256, 576}
AssertionError
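The Triton fallback fails on an explicit whitelist in `extend_attention_fwd`, but only once the first prefill batch arrives. A minimal sketch of a pre-flight check that would surface the problem at startup instead; `check_triton_extend_head_dim` is a hypothetical helper, and the set is copied from the assertion in the traceback above (sglang 0.2.12), so other versions may accept different values.

```python
# Head dims accepted by extend_attention_fwd, per the assertion in the
# sglang 0.2.12 traceback above. Other versions may accept a different set.
TRITON_EXTEND_HEAD_DIMS = frozenset({16, 32, 64, 128, 256, 576})


def check_triton_extend_head_dim(head_dim: int) -> None:
    """Hypothetical pre-flight check: fail fast with a readable message."""
    if head_dim not in TRITON_EXTEND_HEAD_DIMS:
        supported = ", ".join(str(d) for d in sorted(TRITON_EXTEND_HEAD_DIMS))
        raise ValueError(
            f"head_dim={head_dim} is not supported by the Triton extend "
            f"attention kernel (supported: {supported})"
        )
```

Calling `check_triton_extend_head_dim(80)` while the model is loaded would report the unsupported head_dim before the server starts accepting requests.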

Reproduction

Run command:

python -m sglang.launch_server --model-path xverse/XVERSE-MoE-A4.2B-Chat --trust-remote-code --disable-flashinfer

Model:

xverse/XVERSE-MoE-A4.2B-Chat

Steps to Reproduce:

I am currently adding support for XVERSE models in a local branch, and the server fails either at startup (CUDA graph capture) or on the first prefill request. Although the issue cannot be reproduced directly from the main branch, possible solutions may be inferred from the error messages and the configuration file.

Below is the config.json file; the per-head dimension is hidden_size / num_attention_heads = 2560 / 32 = 80.

{
  "architectures": [
    "XverseMoEForCausalLM"
  ],
  "auto_map": {
    "AutoConfig": "configuration_xverse.XverseConfig",
    "AutoModelForCausalLM": "modeling_xverse.XverseForCausalLM"
  },
  "pad_token_id": 1,
  "bos_token_id": 2,
  "eos_token_id": 3,
  "hidden_act": "silu",
  "hidden_size": 2560,
  "initializer_range": 0.02,
  "intermediate_size": 1728,
  "max_position_embeddings": 8192,
  "max_tokenizer_truncation": 6144,
  "model_type": "xverse",
  "num_attention_heads": 32,
  "num_hidden_layers": 28,
  "rms_norm_eps": 1e-06,
  "tie_word_embeddings": false,
  "rope_theta": 500000,
  "moe_top_k": 6,
  "num_experts": 64,
  "num_shared_experts": 2,
  "output_router_logits": false,
  "router_aux_loss_coef": 0.01,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 100534,
  "_attn_implementation": "eager"
}
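As a quick check, the head dimension can be derived from the config above. In this sketch only the two fields it needs are inlined so the snippet runs standalone; in practice you would `json.load` the full config.json.

```python
import json

# Only the two fields needed here, copied from the config.json above.
config = json.loads('{"hidden_size": 2560, "num_attention_heads": 32}')

hidden_size = config["hidden_size"]
num_heads = config["num_attention_heads"]

# Each attention head gets an equal slice of the hidden dimension.
assert hidden_size % num_heads == 0
head_dim = hidden_size // num_heads

# A power of 2 has exactly one bit set, so n & (n - 1) == 0.
is_pow2 = (head_dim & (head_dim - 1)) == 0

print(head_dim, is_pow2)  # 80 False
```

The result shows why this model is unusual: 80 is neither in the Triton whitelist nor a power of 2.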

Environment

Python: 3.10.0 (default, Mar  3 2022, 09:58:08) [GCC 7.5.0]
CUDA available: True
GPU 0: NVIDIA A800-SXM4-80GB
GPU 0 Compute Capability: 8.0
CUDA_HOME: /usr/local/cuda
NVCC: Cuda compilation tools, release 12.4, V12.4.131
CUDA Driver Version: 470.182.03
PyTorch: 2.3.1+cu121
sglang: 0.2.12
flashinfer: 0.1.5+cu121torch2.3
triton: 2.3.1
transformers: 4.43.3
requests: 2.32.3
tqdm: 4.66.4
numpy: 1.26.4
aiohttp: 3.9.5
fastapi: 0.111.1
hf_transfer: 0.1.8
huggingface_hub: 0.24.2
interegular: 0.3.3
packaging: 24.1
PIL: 10.4.0
psutil: 6.0.0
pydantic: 2.8.2
uvicorn: 0.30.3
uvloop: 0.19.0
zmq: 26.0.3
vllm: 0.5.3.post1
multipart: 0.0.9
openai: 1.37.1
anthropic: 0.31.2
NVIDIA Topology: 
        GPU0    CPU Affinity    NUMA Affinity
GPU0     X      116-231 1

Legend:

  X    = Self
  SYS  = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
  NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
  PHB  = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
  PXB  = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
  PIX  = Connection traversing at most a single PCIe bridge
  NV#  = Connection traversing a bonded set of # NVLinks

ulimit soft: 1048576
@hxer7963
Contributor Author

By the way, I'm waiting for the head_dim=80 dispatch failure to be fixed; after that, I'll merge the model support from my local branch into the upstream main branch.

@zhyncs
Member

zhyncs commented Aug 17, 2024

Consider adding support in either FlashInfer or the Triton kernel. For quick verification, I suggest modifying the Triton kernel.

@ByronHsu
Collaborator

ByronHsu commented Sep 1, 2024

⚡ byhsu/fix-border ~/sglang PYTHONPATH="/teamspace/studios/this_studio/sglang/python" python -m sglang.launch_server --model-path xverse/XVERSE-MoE-A4.2B-Chat --trust-remote-code --disable-flashinfer
[07:58:52] server_args=ServerArgs(model_path='xverse/XVERSE-MoE-A4.2B-Chat', tokenizer_path='xverse/XVERSE-MoE-A4.2B-Chat', tokenizer_mode='auto', skip_tokenizer_init=False, load_format='auto', dtype='auto', kv_cache_dtype='auto', trust_remote_code=True, context_length=None, quantization=None, served_model_name='xverse/XVERSE-MoE-A4.2B-Chat', chat_template=None, is_embedding=False, host='127.0.0.1', port=30000, additional_ports=[30001, 30002, 30003, 30004], mem_fraction_static=0.88, max_running_requests=None, max_num_reqs=None, max_total_tokens=None, chunked_prefill_size=8192, max_prefill_tokens=16384, schedule_policy='lpm', schedule_conservativeness=1.0, tp_size=1, stream_interval=1, random_seed=262099981, log_level='info', log_level_http=None, log_requests=False, show_time_cost=False, api_key=None, file_storage_pth='SGLang_storage', dp_size=1, load_balance_method='round_robin', disable_flashinfer=True, disable_flashinfer_sampling=False, disable_radix_cache=False, disable_regex_jump_forward=False, disable_cuda_graph=False, disable_cuda_graph_padding=False, disable_disk_cache=False, disable_custom_all_reduce=False, enable_mixed_chunk=False, enable_torch_compile=False, enable_p2p_check=False, enable_mla=False, triton_attention_reduce_in_fp32=False, nccl_init_addr=None, nnodes=1, node_rank=None)
Traceback (most recent call last):
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/teamspace/studios/this_studio/sglang/python/sglang/launch_server.py", line 19, in <module>
    raise e
  File "/teamspace/studios/this_studio/sglang/python/sglang/launch_server.py", line 17, in <module>
    launch_server(server_args)
  File "/teamspace/studios/this_studio/sglang/python/sglang/srt/server.py", line 331, in launch_server
    tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
  File "/teamspace/studios/this_studio/sglang/python/sglang/srt/managers/tokenizer_manager.py", line 128, in __init__
    self.tokenizer = get_tokenizer(
  File "/teamspace/studios/this_studio/sglang/python/sglang/srt/hf_transformers_utils.py", line 129, in get_tokenizer
    tokenizer = AutoTokenizer.from_pretrained(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/models/auto/tokenization_auto.py", line 897, in from_pretrained
    return tokenizer_class.from_pretrained(pretrained_model_name_or_path, *inputs, **kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2271, in from_pretrained
    return cls._from_pretrained(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/tokenization_utils_base.py", line 2505, in _from_pretrained
    tokenizer = cls(*init_inputs, **init_kwargs)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/transformers/tokenization_utils_fast.py", line 115, in __init__
    fast_tokenizer = TokenizerFast.from_file(fast_tokenizer_file)
Exception: data did not match any variant of untagged enum PyPreTokenizerTypeWrapper at line 78 column 3

I hit a different error and am not sure what is wrong on my side.

@ByronHsu
Collaborator

ByronHsu commented Sep 1, 2024

I am fixing the kernel for the non-power-of-2 (non-2^n) head_dim case. Happy to resolve this together, but I am not able to reproduce the issue.
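A common way to make a block-based attention kernel handle a non-power-of-2 head_dim is to round the block width up to the next power of 2 and mask the padded lanes. The pure-Python sketch below illustrates only that general idea (in Triton it is typically written with `triton.next_power_of_2` and a `mask=` argument on `tl.load`); it is not the actual patch, whose details are not shown in this thread.

```python
def next_power_of_2(n: int) -> int:
    # Smallest power of 2 >= n, for n >= 1 (e.g. 80 -> 128, 64 -> 64).
    return 1 << (n - 1).bit_length()


def masked_dot(q, k, head_dim: int) -> float:
    """Dot product computed over a power-of-2 block, ignoring padded lanes."""
    block = next_power_of_2(head_dim)
    total = 0.0
    for i in range(block):
        in_bounds = i < head_dim  # lanes >= head_dim are padding
        # A masked load yields 0 for padded lanes instead of reading
        # out of bounds, so they contribute nothing to the sum.
        q_i = q[i] if in_bounds else 0.0
        k_i = k[i] if in_bounds else 0.0
        total += q_i * k_i
    return total
```

For head_dim=80 the loop runs over a 128-wide block, and the result matches an unpadded dot product over the 80 real lanes.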

@janimo
Contributor

janimo commented Sep 1, 2024

@ByronHsu you can try with the existing stablelm model code (which already supports stablelm-2) and the stabilityai/stablelm-3b-4e1t model name. The latter also has a head dim of 80. Your fix makes it work.

@ByronHsu
Copy link
Collaborator

@zhyncs we can close this now

@zhyncs zhyncs closed this as completed Sep 11, 2024

4 participants