From 8e6bdf851c4aa6619baa584fc450af748720319d Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 9 Sep 2024 01:30:24 -0700 Subject: [PATCH 1/5] [triton] Support head_dim not 2^n in triton extend and decode attention (#1281) --- python/sglang/srt/layers/decode_attention.py | 50 ++++++++++++------ python/sglang/srt/layers/extend_attention.py | 51 +++++++++++++------ python/sglang/srt/layers/prefill_attention.py | 20 +++++--- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index dc92a65480c..9c9822b8528 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -60,6 +60,7 @@ def _fwd_kernel_stage1( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -97,7 +98,7 @@ def _fwd_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[:, None] < cur_batch_end_index, + mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) @@ -128,6 +129,7 @@ def _fwd_kernel_stage2( kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -170,14 +172,16 @@ def _fwd_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) acc = acc * old_scale + tl.sum(p[:, None] * v, 0) e_max = n_e_max acc = acc / e_sum off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d out_ptrs = Out + off_o - tl.store(out_ptrs, acc) + tl.store(out_ptrs, acc, mask=(offs_d < Lv)) def _decode_att_m_fwd( @@ -196,7 +200,7 @@ def _decode_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -208,6 +212,8 @@ def _decode_att_m_fwd( else: num_warps = 2 + BLOCK_DMODEL = triton.next_power_of_2(Lk) + _fwd_kernel_stage1[grid]( q, k_buffer, @@ -224,11 +230,12 @@ def _decode_att_m_fwd( k_buffer.stride(1), att_out.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd( num_warps = 1 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_kernel_stage2[grid]( logics, v_buffer, @@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd( o.stride(1), req_to_tokens.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=3, + Lv=Lv, ) @@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1( BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1( block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to( - REDUCE_TRITON_TYPE - ) + q = tl.load( + Q + offs_q + start_mark, mask=(mask_h[:, 
None]) & (offs_d[None, :] < Lk) + ).to(REDUCE_TRITON_TYPE) offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, @@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[None, :] < cur_batch_end_index, + mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) qk = tl.dot(q, k) @@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) p = p.to(v.dtype) acc = acc * old_scale[:, None] + tl.dot(p, v) e_max = n_e_max @@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2( acc = acc / e_sum[:, None] off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=mask_h[:, None]) + tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( @@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256, 576} + assert Lk in {16, 32, 64, 96, 128, 256, 576} if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lk + BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd( logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd( num_warps = 8 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_grouped_kernel_stage2[grid]( logics, v_buffer, @@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd( req_to_tokens.stride(0), kv_group_num=kv_group_num, q_head_num=head_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=1, + Lv=Lv, ) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 6c7686971e0..8880622854d 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -15,7 +15,7 @@ """ Memory-efficient attention for prefill. -It supporst page size = 1 and prefill with KV cache (i.e. extend). +It supports page size = 1 and prefill with KV cache (i.e. extend). 
""" import torch @@ -67,6 +67,8 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -86,13 +88,18 @@ def _fwd_kernel( offs_m = tl.arange(0, BLOCK_M) mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + offs_q = ( (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] ) - q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) @@ -125,7 +132,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_kh + offs_d[:, None] ) - k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: @@ -157,7 +166,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_vh + offs_dv[None, :] ) - v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -176,7 +187,9 @@ def _fwd_kernel( + cur_kv_head * stride_kh + offs_d[:, None] ) - k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: @@ -214,7 +227,9 @@ def _fwd_kernel( + cur_kv_head * stride_vh + offs_dv[None, :] ) - v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -226,7 +241,9 @@ def _fwd_kernel( + cur_head * stride_oh + offs_dv[None, :] ) - tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) + tl.store( + O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :] + ) def extend_attention_fwd( @@ -261,16 +278,18 @@ def extend_attention_fwd( ) assert Lq == Lk and Lv == Lo - assert Lq in {16, 32, 64, 128, 256, 576} - assert Lv in {16, 32, 64, 128, 256, 512} + + # TODO: is the assertion necessary? 
+ assert Lq in {16, 32, 64, 96, 128, 256, 576} + assert Lv in {16, 32, 64, 96, 128, 256, 512} if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lq + BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 - BLOCK_DV = Lv + BLOCK_DV = triton.next_power_of_2(Lv) if CUDA_CAPABILITY[0] >= 9: if Lq <= 256: @@ -330,6 +349,8 @@ def extend_attention_fwd( num_warps=num_warps, num_stages=num_stages, logit_cap=logit_cap, + Lq=Lq, + Lv=Lv, ) @@ -373,10 +394,7 @@ def redundant_attention( pt += cur_seq_len_extend -def test(): - torch.manual_seed(0) - - B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128 +def test_once(B, N_CTX, H_Q, H_KV, D): dtype = torch.float16 b_seq_len_prefix = torch.randint( @@ -473,4 +491,5 @@ def test(): if __name__ == "__main__": - test() + test_once(19, 12331, 12, 4, 128) + test_once(19, 12331, 12, 4, 96) diff --git a/python/sglang/srt/layers/prefill_attention.py b/python/sglang/srt/layers/prefill_attention.py index 99343a4df7c..fbf9976fbc5 100644 --- a/python/sglang/srt/layers/prefill_attention.py +++ b/python/sglang/srt/layers/prefill_attention.py @@ -48,6 +48,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -72,7 +73,11 @@ def _fwd_kernel( off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0 + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -89,7 +94,7 @@ def _fwd_kernel( # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), other=0.0, ) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) @@ -118,7 +123,7 @@ def _fwd_kernel( # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), other=0.0, ) @@ -134,7 +139,9 @@ def _fwd_kernel( + offs_d[None, :] ) out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) + ) def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): @@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): o.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=triton.next_power_of_2(Lk), BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, + Lk=Lk, ) From 662ecd93680c8195eda799cb9a497f93efdc521a Mon Sep 17 00:00:00 2001 From: Kaichen Zhang - NTU Date: Mon, 9 Sep 2024 17:07:34 +0800 Subject: [PATCH 2/5] [Feat] Add modalities for vision server when handling pixel values for llava (#1346) --- 
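Usage sketch (illustrative, not part of the commit): the snippet below shows how a client could tag image content parts with the new "modalities" field introduced by this patch, mirroring the request shape used in the updated tests. It assumes an SGLang vision server is already running and reachable over the OpenAI-compatible API, and that the openai Python client is installed; the base URL, API key, and model name are placeholders rather than values defined by this patch, while the image URLs are the ones used in the test files.

    import openai  # assumption: openai>=1.x client is installed

    # Assumption: an SGLang vision server is already serving at this address.
    client = openai.Client(api_key="EMPTY", base_url="http://localhost:30000/v1")

    response = client.chat.completions.create(
        model="default",  # placeholder model name
        messages=[
            {
                "role": "user",
                "content": [
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
                        },
                        # New in this patch: each image part can declare its modality
                        # ("image", "multi-images", or "video"); default is "image".
                        "modalities": "multi-images",
                    },
                    {
                        "type": "image_url",
                        "image_url": {
                            "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png"
                        },
                        "modalities": "multi-images",
                    },
                    {"type": "text", "text": "Describe the two images."},
                ],
            }
        ],
        temperature=0,
    )
    print(response.choices[0].message.content)
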
.../llava_onevision/http_llava_onevision_test.py | 3 +++ python/sglang/srt/conversation.py | 3 +++ python/sglang/srt/managers/io_struct.py | 4 ++++ python/sglang/srt/managers/schedule_batch.py | 1 + python/sglang/srt/managers/tokenizer_manager.py | 5 +++++ python/sglang/srt/managers/tp_worker.py | 2 ++ .../sglang/srt/model_executor/forward_batch_info.py | 2 ++ python/sglang/srt/models/llava.py | 11 +++++++++-- python/sglang/srt/openai_api/adapter.py | 7 +++++++ python/sglang/srt/openai_api/protocol.py | 1 + test/srt/test_vision_openai_server.py | 3 +++ 11 files changed, 40 insertions(+), 2 deletions(-) diff --git a/examples/runtime/llava_onevision/http_llava_onevision_test.py b/examples/runtime/llava_onevision/http_llava_onevision_test.py index 0c93d2ce2b2..2c7c2bd38be 100644 --- a/examples/runtime/llava_onevision/http_llava_onevision_test.py +++ b/examples/runtime/llava_onevision/http_llava_onevision_test.py @@ -93,12 +93,14 @@ def multi_image_stream_request_test(client): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -218,6 +220,7 @@ def prepare_video_messages(video_path): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index dbc376d9593..9a1227218b7 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -71,6 +71,7 @@ class Conversation: # Stop criteria (the default one is EOS token) stop_str: Union[str, List[str]] = None image_data: Optional[List[str]] = None + modalities: Optional[List[str]] = None def get_prompt(self) -> str: """Get the prompt for generation.""" @@ -379,6 +380,7 @@ def generate_chat_conv( sep2=conv.sep2, stop_str=conv.stop_str, image_data=[], + modalities=[], ) if isinstance(request.messages, str): @@ -408,6 +410,7 @@ def generate_chat_conv( for content in message.content: if content.type == "image_url": num_image_url += 1 + conv.modalities.append(content.modalities) if num_image_url > 1: image_token = "" else: diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 5b91ff62e9d..8e53df33555 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -50,6 +50,8 @@ class GenerateReqInput: return_text_in_logprobs: bool = False # Whether to stream output. 
stream: bool = False + # The modalities of the image data [image, multi-images, video] + modalities: Optional[List[str]] = None def post_init(self): if (self.text is None and self.input_ids is None) or ( @@ -177,6 +179,8 @@ class TokenizedGenerateReqInput: top_logprobs_num: int # Whether to stream output stream: bool + # Modalities of the input images + modalites: Optional[List[str]] = None @dataclass diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c80cf2e2723..f126cc9f3ae 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -130,6 +130,7 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.image_sizes = None self.image_offsets = None self.pad_value = None + self.modalities = None # Prefix info self.extend_input_len = 0 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6af82064152..d0cfed08cd1 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -188,6 +188,7 @@ async def _handle_single_request( pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data if not_use_index else obj.image_data[index] ) + modalities = obj.modalities return_logprob = ( obj.return_logprob if not_use_index else obj.return_logprob[index] ) @@ -243,6 +244,7 @@ async def _handle_single_request( pixel_values, image_hashes, image_sizes = await self._get_pixel_values( obj.image_data[0] ) + modalities = obj.modalities return_logprob = obj.return_logprob[0] logprob_start_len = obj.logprob_start_len[0] top_logprobs_num = obj.top_logprobs_num[0] @@ -263,6 +265,7 @@ async def _handle_single_request( logprob_start_len, top_logprobs_num, obj.stream, + modalities, ) else: # is embedding tokenized_obj = TokenizedEmbeddingReqInput( @@ -346,6 +349,7 @@ async def _handle_batch_request( pixel_values, image_hashes, image_sizes = ( await self._get_pixel_values(obj.image_data[index]) ) + modalities = obj.modalities tokenized_obj = TokenizedGenerateReqInput( rid, @@ -359,6 +363,7 @@ async def _handle_batch_request( obj.logprob_start_len[index], obj.top_logprobs_num[index], obj.stream, + modalities, ) else: tokenized_obj = TokenizedEmbeddingReqInput( diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index c2c0e6c2d14..7bb9c433565 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -358,6 +358,8 @@ def handle_generate_request( req.pixel_values, req.image_sizes, ) + # Only when pixel values is not None we have modalities + req.modalities = recv_req.modalites req.return_logprob = recv_req.return_logprob req.logprob_start_len = recv_req.logprob_start_len req.top_logprobs_num = recv_req.top_logprobs_num diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index a443b113d44..75f9136d398 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -78,6 +78,7 @@ class InputMetadata: pixel_values: List[torch.Tensor] = None image_sizes: List[List[List[int]]] = None image_offsets: List[List[int]] = None + modalities: List[List[str]] = None # Trition attention backend triton_max_seq_len: int = 0 @@ -96,6 +97,7 @@ def init_multimuldal_info(self, batch: ScheduleBatch): self.pixel_values = [r.pixel_values for r in reqs] 
self.image_sizes = [r.image_sizes for r in reqs] self.image_offsets = [r.image_offsets for r in reqs] + self.modalities = [r.modalities for r in reqs] def compute_positions(self, batch: ScheduleBatch): position_ids_offsets = batch.position_ids_offsets diff --git a/python/sglang/srt/models/llava.py b/python/sglang/srt/models/llava.py index 2e3c9ceba1a..62041a89553 100644 --- a/python/sglang/srt/models/llava.py +++ b/python/sglang/srt/models/llava.py @@ -138,6 +138,12 @@ def forward( ) -> torch.Tensor: if input_metadata.forward_mode == ForwardMode.EXTEND: bs = input_metadata.batch_size + # Got List[List[str]] extend it to List[str] + # The length of the List should be equal to batch size + modalities_list = [] + for modalities in input_metadata.modalities: + if modalities is not None: + modalities_list.extend(modalities) # Embed text inputs input_embeds = self.language_model.model.embed_tokens(input_ids) @@ -179,7 +185,7 @@ def forward( new_image_features = [] height = width = self.num_patches_per_side for image_idx, image_feature in enumerate(image_features): - if len(image_sizes[image_idx]) == 1: + if modalities_list[image_idx] == 1: image_aspect_ratio = ( self.config.image_aspect_ratio ) # single image @@ -191,6 +197,7 @@ def forward( if ( image_feature.shape[0] > 1 and "anyres" in image_aspect_ratio + and modalities_list[image_idx] == "image" ): base_image_feature = image_feature[0] image_feature = image_feature[1:] @@ -290,7 +297,7 @@ def forward( ) image_feature = image_feature.unsqueeze(0) else: - if image_feature.shape[0] > 16: # video + if modalities_list[image_idx] == "video": # video # 2x2 pooling num_of_frames = image_feature.shape[0] image_feature = image_feature.view( diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index cd7526b0d93..f1195aff7c6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -832,6 +832,7 @@ def v1_chat_generate_request( return_logprobs = [] logprob_start_lens = [] top_logprobs_nums = [] + modalities_list = [] # NOTE: with openai API, the prompt's logprobs are always not computed @@ -864,10 +865,12 @@ def v1_chat_generate_request( ) stop = request.stop image_data = None + modalities = [] else: conv = generate_chat_conv(request, chat_template_name) prompt = conv.get_prompt() image_data = conv.image_data + modalities = conv.modalities stop = conv.stop_str or [] if request.stop: if isinstance(request.stop, str): @@ -880,6 +883,7 @@ def v1_chat_generate_request( prompt_ids = request.messages stop = request.stop image_data = None + modalities = [] input_ids.append(prompt_ids) return_logprobs.append(request.logprobs) logprob_start_lens.append(-1) @@ -901,6 +905,7 @@ def v1_chat_generate_request( } ) image_data_list.append(image_data) + modalities_list.extend(modalities) if len(all_requests) == 1: input_ids = input_ids[0] if isinstance(input_ids, str): @@ -912,6 +917,7 @@ def v1_chat_generate_request( return_logprobs = return_logprobs[0] logprob_start_lens = logprob_start_lens[0] top_logprobs_nums = top_logprobs_nums[0] + modalities_list = modalities_list[:1] else: if isinstance(input_ids[0], str): prompt_kwargs = {"text": input_ids} @@ -928,6 +934,7 @@ def v1_chat_generate_request( stream=all_requests[0].stream, return_text_in_logprobs=True, rid=request_ids, + modalities=modalities_list, ) if len(all_requests) == 1: return adapted_request, all_requests[0] diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 
8073df7952e..5525cd88275 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -213,6 +213,7 @@ class ChatCompletionMessageContentImageURL(BaseModel): class ChatCompletionMessageContentImagePart(BaseModel): type: Literal["image_url"] image_url: ChatCompletionMessageContentImageURL + modalities: Optional[Literal["image", "multi-images", "video"]] = "image" ChatCompletionMessageContentPart = Union[ diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 4f764c09cd8..727f5774cad 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -140,12 +140,14 @@ def test_mult_images_chat_completion(self): "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/test/lang/example_image.png" }, + "modalities": "multi-images", }, { "type": "image_url", "image_url": { "url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png" }, + "modalities": "multi-images", }, { "type": "text", @@ -192,6 +194,7 @@ def prepare_video_messages(self, video_path): frame_format = { "type": "image_url", "image_url": {"url": "data:image/jpeg;base64,{}"}, + "modalities": "video", } for base64_frame in base64_frames: From c9b75917d577523ba1c1c581c638b9d2e94b777b Mon Sep 17 00:00:00 2001 From: Kai-Hsun Chen Date: Mon, 9 Sep 2024 02:14:25 -0700 Subject: [PATCH 3/5] [server] Passing `model_override_args` to `launch_server` via the CLI. (#1298) Signed-off-by: Kai-Hsun Chen --- benchmark/blog_v0_2/405b_sglang.sh | 2 +- python/sglang/bench_latency.py | 1 + python/sglang/launch_server.py | 12 ++++----- python/sglang/launch_server_llavavid.py | 11 +++----- python/sglang/srt/server_args.py | 34 +++++++++++++++++++++++++ test/srt/run_suite.py | 1 + test/srt/test_server_args.py | 24 +++++++++++++++++ test/srt/test_serving_latency.py | 2 +- 8 files changed, 71 insertions(+), 16 deletions(-) create mode 100644 test/srt/test_server_args.py diff --git a/benchmark/blog_v0_2/405b_sglang.sh b/benchmark/blog_v0_2/405b_sglang.sh index eae5e22060a..d31f8daf8eb 100644 --- a/benchmark/blog_v0_2/405b_sglang.sh +++ b/benchmark/blog_v0_2/405b_sglang.sh @@ -6,7 +6,7 @@ # wget https://huggingface.co/neuralmagic/Meta-Llama-3.1-8B-Instruct-quantized.w8a8/resolve/main/tokenizer_config.json # Launch sglang -# python -m sglang.launch_server --model ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 +# python -m sglang.launch_server --model-path ~/llama-3.1-405b-fp8-dummy/ --load-format dummy --tp 8 --quant fp8 --disable-radix --mem-frac 0.87 # offline python3 -m sglang.bench_serving --backend sglang --dataset-name random --num-prompt 3000 --random-input 1024 --random-output 1024 > sglang_log11 diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index 9006b7150aa..6113495776d 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -480,6 +480,7 @@ def main(server_args, bench_args): if __name__ == "__main__": + # TODO(kevin85421): Make the parser setup unit testable. 
parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) BenchArgs.add_cli_args(parser) diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 1df64e848c8..06aa140d9bc 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -1,20 +1,18 @@ """Launch the inference server.""" -import argparse import os +import sys from sglang.srt.server import launch_server -from sglang.srt.server_args import ServerArgs +from sglang.srt.server_args import prepare_server_args from sglang.srt.utils import kill_child_process if __name__ == "__main__": - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) + server_args = prepare_server_args(sys.argv[1:]) + model_override_args = server_args.json_model_override_args try: - launch_server(server_args) + launch_server(server_args, model_override_args=model_override_args) except Exception as e: raise e finally: diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index 43eefef4efa..6b8d151ee1d 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,14 +1,11 @@ """Launch the inference server for Llava-video model.""" -import argparse +import sys -from sglang.srt.server import ServerArgs, launch_server +from sglang.srt.server import launch_server, prepare_server_args if __name__ == "__main__": - parser = argparse.ArgumentParser() - ServerArgs.add_cli_args(parser) - args = parser.parse_args() - server_args = ServerArgs.from_cli_args(args) + server_args = prepare_server_args(sys.argv[1:]) model_override_args = {} model_override_args["mm_spatial_pool_stride"] = 2 @@ -20,7 +17,7 @@ model_override_args["max_sequence_length"] = 4096 * 2 model_override_args["tokenizer_model_max_length"] = 4096 * 2 model_override_args["model_max_length"] = 4096 * 2 - if "34b" in args.model_path.lower(): + if "34b" in server_args.model_path.lower(): model_override_args["image_token_index"] = 64002 launch_server(server_args, model_override_args, None) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 8a56c02e162..e21f02108c9 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -17,6 +17,7 @@ import argparse import dataclasses +import json import logging import random from typing import List, Optional, Union @@ -95,6 +96,9 @@ class ServerArgs: nnodes: int = 1 node_rank: Optional[int] = None + # Model override args in JSON + json_model_override_args: Optional[dict] = None + def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -455,10 +459,22 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) + # Model override args + parser.add_argument( + "--json-model-override-args", + type=str, + help="A dictionary in JSON string format used to override default model configurations.", + ) + @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size + args.json_model_override_args = ( + json.loads(args.json_model_override_args) + if args.json_model_override_args + else None + ) attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) @@ -482,6 +498,24 @@ def check_server_args(self): self.disable_flashinfer = 
False +def prepare_server_args(args: argparse.Namespace) -> ServerArgs: + """ + Prepare the server arguments from the command line arguments. + + Args: + args: The command line arguments. Typically, it should be `sys.argv[1:]` + to ensure compatibility with `parse_args` when no arguments are passed. + + Returns: + The server arguments. + """ + parser = argparse.ArgumentParser() + ServerArgs.add_cli_args(parser) + raw_args = parser.parse_args(args) + server_args = ServerArgs.from_cli_args(raw_args) + return server_args + + @dataclasses.dataclass class PortArgs: tokenizer_port: int diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index cafcf3f2d59..d5982844ce3 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -19,6 +19,7 @@ "test_triton_attn_backend.py", "test_update_weights.py", "test_vision_openai_server.py", + "test_server_args.py", ], "sampling/penaltylib": glob.glob( "sampling/penaltylib/**/test_*.py", recursive=True diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py new file mode 100644 index 00000000000..71129e3eb15 --- /dev/null +++ b/test/srt/test_server_args.py @@ -0,0 +1,24 @@ +import unittest + +from sglang.srt.server_args import prepare_server_args + + +class TestPrepareServerArgs(unittest.TestCase): + def test_prepare_server_args(self): + server_args = prepare_server_args( + [ + "--model-path", + "model_path", + "--json-model-override-args", + '{"rope_scaling": {"factor": 2.0, "type": "linear"}}', + ] + ) + self.assertEqual(server_args.model_path, "model_path") + self.assertEqual( + server_args.json_model_override_args, + {"rope_scaling": {"factor": 2.0, "type": "linear"}}, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_serving_latency.py b/test/srt/test_serving_latency.py index e762892c8eb..3dae4541a08 100644 --- a/test/srt/test_serving_latency.py +++ b/test/srt/test_serving_latency.py @@ -12,7 +12,7 @@ def test_default(self): "python3", "-m", "sglang.bench_latency", - "--model", + "--model-path", DEFAULT_MODEL_NAME_FOR_TEST, "--batch-size", "1", From e4d68afcf00869a5467f101d176fecc3cd97b7b8 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 9 Sep 2024 04:14:11 -0700 Subject: [PATCH 4/5] [Minor] Many cleanup (#1357) --- benchmark/gsm8k/README.md | 5 - benchmark/gsm8k/bench_other.py | 30 ++-- benchmark/gsm8k/bench_sglang.py | 39 +++-- benchmark/gsm8k/download_data.sh | 2 - benchmark/hellaswag/README.md | 5 - benchmark/hellaswag/bench_other.py | 23 +-- benchmark/hellaswag/bench_sglang.py | 24 +-- .../usage/llava_video/srt_example_llava_v.py | 3 +- python/sglang/bench_serving.py | 71 ++++---- python/sglang/launch_server.py | 3 +- python/sglang/launch_server_llavavid.py | 4 +- python/sglang/srt/constrained/fsm_cache.py | 67 ++++---- .../sglang/srt/managers/controller_multi.py | 6 +- .../sglang/srt/managers/controller_single.py | 5 - .../sglang/srt/managers/tokenizer_manager.py | 4 +- python/sglang/srt/managers/tp_worker.py | 157 +++++++++--------- .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/server.py | 9 +- python/sglang/srt/server_args.py | 40 ++--- python/sglang/test/few_shot_gsm8k.py | 132 +++++++++++++++ python/sglang/test/test_programs.py | 12 +- python/sglang/utils.py | 69 ++++---- test/srt/test_moe_eval_accuracy_large.py | 4 +- test/srt/test_server_args.py | 3 +- 24 files changed, 419 insertions(+), 299 deletions(-) delete mode 100755 benchmark/gsm8k/download_data.sh create mode 100644 python/sglang/test/few_shot_gsm8k.py diff --git a/benchmark/gsm8k/README.md 
b/benchmark/gsm8k/README.md index a7dc04d9a9c..c110f533c79 100644 --- a/benchmark/gsm8k/README.md +++ b/benchmark/gsm8k/README.md @@ -1,8 +1,3 @@ -## Download data -``` -bash download_data.sh -``` - ## Run benchmark ### Benchmark sglang diff --git a/benchmark/gsm8k/bench_other.py b/benchmark/gsm8k/bench_other.py index 2a938d6bb9c..a8bbcfb5c19 100644 --- a/benchmark/gsm8k/bench_other.py +++ b/benchmark/gsm8k/bench_other.py @@ -10,7 +10,7 @@ from tqdm import tqdm from sglang.test.test_utils import add_common_other_args_and_parse, get_call_generate -from sglang.utils import dump_state_text, read_jsonl +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl INVALID = -9999999 @@ -41,24 +41,28 @@ def get_answer_value(answer_str): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + call_generate = get_call_generate(args) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) labels.append(get_answer_value(lines[i]["answer"])) assert all(l != INVALID for l in labels) states = [None] * len(labels) - # Select backend - call_generate = get_call_generate(args) - # Run requests if args.backend != "lmql": # Use thread pool @@ -113,11 +117,13 @@ async def batched_call(batch_size): # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) - print(f"Latency: {latency:.3f}") - print(f"Invalid: {invalid:.3f}") + + # Print results print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") - # Write results + # Dump results dump_state_text(f"tmp_output_{args.backend}.txt", states) with open(args.result_file, "a") as fout: @@ -138,7 +144,7 @@ async def batched_call(batch_size): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--num-shots", type=int, default=5) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_other_args_and_parse(parser) diff --git a/benchmark/gsm8k/bench_sglang.py b/benchmark/gsm8k/bench_sglang.py index d32790fe0c8..9fe9b79baaf 100644 --- a/benchmark/gsm8k/bench_sglang.py +++ b/benchmark/gsm8k/bench_sglang.py @@ -6,11 +6,12 @@ import numpy as np +from sglang.api import set_default_backend from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) -from sglang.utils import dump_state_text, read_jsonl +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl INVALID = -9999999 @@ -41,15 +42,22 @@ def get_answer_value(answer_str): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + set_default_backend(select_sglang_backend(args)) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts 
- k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) labels.append(get_answer_value(lines[i]["answer"])) assert all(l != INVALID for l in labels) @@ -72,15 +80,11 @@ def few_shot_gsm8k(s, question): ########## SGL Program End ########## ##################################### - # Select backend - backend = select_sglang_backend(args) - # Run requests tic = time.time() states = few_shot_gsm8k.run_batch( arguments, temperature=0, - backend=backend, num_threads=args.parallel, progress_bar=True, ) @@ -96,11 +100,20 @@ def few_shot_gsm8k(s, question): # Compute accuracy acc = np.mean(np.array(preds) == np.array(labels)) invalid = np.mean(np.array(preds) == INVALID) - print(f"Latency: {latency:.3f}") - print(f"Invalid: {invalid:.3f}") + + # Compute speed + num_output_tokens = sum( + s.get_meta_info("answer")["completion_tokens"] for s in states + ) + output_throughput = num_output_tokens / latency + + # Print results print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + print(f"Output throughput: {output_throughput:.3f} token/s") - # Write results + # Dump results dump_state_text(f"tmp_output_{args.backend}.txt", states) with open(args.result_file, "a") as fout: @@ -121,7 +134,7 @@ def few_shot_gsm8k(s, question): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=5) + parser.add_argument("--num-shots", type=int, default=5) parser.add_argument("--data-path", type=str, default="test.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_sglang_args_and_parse(parser) diff --git a/benchmark/gsm8k/download_data.sh b/benchmark/gsm8k/download_data.sh deleted file mode 100755 index a9aa7756d2c..00000000000 --- a/benchmark/gsm8k/download_data.sh +++ /dev/null @@ -1,2 +0,0 @@ -wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/train.jsonl -wget https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl \ No newline at end of file diff --git a/benchmark/hellaswag/README.md b/benchmark/hellaswag/README.md index b3e7abc30fd..cb7e65366f9 100644 --- a/benchmark/hellaswag/README.md +++ b/benchmark/hellaswag/README.md @@ -1,8 +1,3 @@ -## Download data -``` -wget https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl -``` - ## Run benchmark ### Benchmark sglang diff --git a/benchmark/hellaswag/bench_other.py b/benchmark/hellaswag/bench_other.py index 5b9ba797bdc..04be4569a90 100644 --- a/benchmark/hellaswag/bench_other.py +++ b/benchmark/hellaswag/bench_other.py @@ -8,7 +8,7 @@ from tqdm import tqdm from sglang.test.test_utils import add_common_other_args_and_parse, get_call_select -from sglang.utils import read_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def get_one_example(lines, i, include_answer): @@ -26,25 +26,29 @@ def get_few_shot_examples(lines, k): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + call_select = get_call_select(args) + + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + 
lines = list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] choices = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) choices.append(lines[i]["endings"]) labels.append(lines[i]["label"]) preds = [None] * len(labels) - # Select backend - call_select = get_call_select(args) - # Run requests if args.backend != "lmql": # Use thread pool @@ -65,7 +69,6 @@ def get_one_answer(i): total=len(questions), ) ) - else: # Use asyncio async def batched_call(batch_size): @@ -108,7 +111,7 @@ async def batched_call(batch_size): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--num-shots", type=int, default=20) parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_other_args_and_parse(parser) diff --git a/benchmark/hellaswag/bench_sglang.py b/benchmark/hellaswag/bench_sglang.py index 2ccf1aaee2b..f09d7256da9 100644 --- a/benchmark/hellaswag/bench_sglang.py +++ b/benchmark/hellaswag/bench_sglang.py @@ -4,11 +4,12 @@ import numpy as np +from sglang.api import set_default_backend from sglang.test.test_utils import ( add_common_sglang_args_and_parse, select_sglang_backend, ) -from sglang.utils import read_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def get_one_example(lines, i, include_answer): @@ -26,16 +27,23 @@ def get_few_shot_examples(lines, k): def main(args): - lines = read_jsonl(args.data_path) + # Select backend + set_default_backend(select_sglang_backend(args)) + + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) # Construct prompts - k = args.num_shot - few_shot_examples = get_few_shot_examples(lines, k) + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) questions = [] choices = [] labels = [] - for i in range(len(lines[: args.num_questions])): + for i in range(len(lines[:num_questions])): questions.append(get_one_example(lines, i, False)) choices.append(lines[i]["endings"]) labels.append(lines[i]["label"]) @@ -56,15 +64,11 @@ def few_shot_hellaswag(s, question, choices): ########## SGL Program End ########## ##################################### - # Select backend - backend = select_sglang_backend(args) - # Run requests tic = time.time() rets = few_shot_hellaswag.run_batch( arguments, temperature=0, - backend=backend, num_threads=args.parallel, progress_bar=True, ) @@ -95,7 +99,7 @@ def few_shot_hellaswag(s, question, choices): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument("--num-shot", type=int, default=20) + parser.add_argument("--num-shots", type=int, default=20) parser.add_argument("--data-path", type=str, default="hellaswag_val.jsonl") parser.add_argument("--num-questions", type=int, default=200) args = add_common_sglang_args_and_parse(parser) diff --git a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py index 
02bab342ac5..c3b8da7d6ab 100644 --- a/examples/frontend_language/usage/llava_video/srt_example_llava_v.py +++ b/examples/frontend_language/usage/llava_video/srt_example_llava_v.py @@ -7,6 +7,7 @@ import argparse import csv +import json import os import time @@ -223,7 +224,7 @@ def batch(video_dir, save_dir, cur_chunk, num_chunks, num_frames=16, batch_size= tokenizer_path=tokenizer_path, port=cur_port, additional_ports=[cur_port + 1, cur_port + 2, cur_port + 3, cur_port + 4], - model_override_args=model_override_args, + json_model_override_args=json.dumps(model_override_args), tp_size=1, ) sgl.set_default_backend(runtime) diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 69d175d8437..d51aee4ec9f 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -298,34 +298,41 @@ class BenchmarkMetrics: median_e2e_latency_ms: float -default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json" +SHAREGPT_URL = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" -def download_sharegpt_dataset(path): - url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json" +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) - print(f"Downloading dataset from {url}") - try: - response = requests.get(url, stream=True) - response.raise_for_status() + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") - total_size = int(response.headers.get("content-length", 0)) - block_size = 8192 + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors - with open(path, "wb") as f, tqdm( - desc="Downloading", - total=total_size, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as progress_bar: - for data in response.iter_content(block_size): - size = f.write(data) - progress_bar.update(size) + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB - print(f"Dataset downloaded and saved to {path}") - except requests.RequestException as e: - raise Exception(f"Failed to download dataset: {e}") + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename def sample_sharegpt_requests( @@ -338,13 +345,8 @@ def sample_sharegpt_requests( raise ValueError("output_len too small") # Download sharegpt if necessary - if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path): - download_sharegpt_dataset(default_sharegpt_path) - dataset_path = default_sharegpt_path - else: - dataset_path = ( - dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path - ) + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. 
with open(dataset_path) as f: @@ -412,15 +414,8 @@ def sample_random_requests( # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens # Download sharegpt if necessary - if not os.path.isfile(dataset_path) and not os.path.isfile( - default_sharegpt_path - ): - download_sharegpt_dataset(default_sharegpt_path) - dataset_path = default_sharegpt_path - else: - dataset_path = ( - dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path - ) + if not os.path.isfile(dataset_path): + dataset_path = download_and_cache_file(SHAREGPT_URL) # Load the dataset. with open(dataset_path) as f: diff --git a/python/sglang/launch_server.py b/python/sglang/launch_server.py index 06aa140d9bc..ce4cb07c2b2 100644 --- a/python/sglang/launch_server.py +++ b/python/sglang/launch_server.py @@ -9,10 +9,9 @@ if __name__ == "__main__": server_args = prepare_server_args(sys.argv[1:]) - model_override_args = server_args.json_model_override_args try: - launch_server(server_args, model_override_args=model_override_args) + launch_server(server_args) except Exception as e: raise e finally: diff --git a/python/sglang/launch_server_llavavid.py b/python/sglang/launch_server_llavavid.py index 6b8d151ee1d..6816dcc112a 100644 --- a/python/sglang/launch_server_llavavid.py +++ b/python/sglang/launch_server_llavavid.py @@ -1,5 +1,6 @@ """Launch the inference server for Llava-video model.""" +import json import sys from sglang.srt.server import launch_server, prepare_server_args @@ -19,5 +20,6 @@ model_override_args["model_max_length"] = 4096 * 2 if "34b" in server_args.model_path.lower(): model_override_args["image_token_index"] = 64002 + server_args.json_model_override_args = json.dumps(model_override_args) - launch_server(server_args, model_override_args, None) + launch_server(server_args) diff --git a/python/sglang/srt/constrained/fsm_cache.py b/python/sglang/srt/constrained/fsm_cache.py index 57c49130622..fd5995dad1c 100644 --- a/python/sglang/srt/constrained/fsm_cache.py +++ b/python/sglang/srt/constrained/fsm_cache.py @@ -16,6 +16,7 @@ """Cache for the compressed finite state machine.""" from outlines.fsm.json_schema import build_regex_from_schema +from transformers import AutoTokenizer from sglang.srt.constrained import RegexGuide, TransformerTokenizer from sglang.srt.constrained.base_tool_cache import BaseToolCache @@ -28,12 +29,9 @@ def __init__( tokenizer_args_dict, enable=True, skip_tokenizer_init=False, - json_schema_mode=False, ): super().__init__(enable=enable) - self.json_schema_mode = json_schema_mode - if ( skip_tokenizer_init or tokenizer_path.endswith(".json") @@ -42,44 +40,37 @@ def __init__( # Do not support TiktokenTokenizer or SentencePieceTokenizer return - from importlib.metadata import version + tokenizer_args_dict.setdefault("padding_side", "left") + tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict) + try: + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + except AttributeError: + # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) + origin_pad_token_id = tokenizer.pad_token_id - if version("outlines") >= "0.0.35": - from transformers import AutoTokenizer + def fset(self, value): + self._value = value - tokenizer_args_dict.setdefault("padding_side", "left") - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_path, **tokenizer_args_dict + type(tokenizer).pad_token_id = property( + fget=type(tokenizer).pad_token_id.fget, fset=fset ) - try: - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - except 
AttributeError: - # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0) - origin_pad_token_id = tokenizer.pad_token_id - - def fset(self, value): - self._value = value - - type(tokenizer).pad_token_id = property( - fget=type(tokenizer).pad_token_id.fget, fset=fset - ) - self.outlines_tokenizer = TransformerTokenizer(tokenizer) - self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token_id = origin_pad_token_id - self.outlines_tokenizer.pad_token = ( - self.outlines_tokenizer.tokenizer.pad_token - ) - self.outlines_tokenizer.vocabulary = ( - self.outlines_tokenizer.tokenizer.get_vocab() - ) - else: - self.outlines_tokenizer = TransformerTokenizer( - tokenizer_path, **tokenizer_args_dict + self.outlines_tokenizer = TransformerTokenizer(tokenizer) + self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token_id = origin_pad_token_id + self.outlines_tokenizer.pad_token = ( + self.outlines_tokenizer.tokenizer.pad_token + ) + self.outlines_tokenizer.vocabulary = ( + self.outlines_tokenizer.tokenizer.get_vocab() ) - def init_value(self, value): - if self.json_schema_mode: - regex = build_regex_from_schema(value, whitespace_pattern=r"[\n\t ]*") - return RegexGuide(regex, self.outlines_tokenizer), regex + def init_value(self, key): + key_type, key_string = key + if key_type == "json": + regex = build_regex_from_schema(key_string, whitespace_pattern=r"[\n\t ]*") + elif key_type == "regex": + regex = key_string else: - return RegexGuide(value, self.outlines_tokenizer) + raise ValueError(f"Invalid key_type: {key_type}") + + return RegexGuide(regex, self.outlines_tokenizer), regex diff --git a/python/sglang/srt/managers/controller_multi.py b/python/sglang/srt/managers/controller_multi.py index ba626d4cffc..e4b316155a4 100644 --- a/python/sglang/srt/managers/controller_multi.py +++ b/python/sglang/srt/managers/controller_multi.py @@ -71,12 +71,10 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args, ): # Parse args self.server_args = server_args self.port_args = port_args - self.model_override_args = model_override_args self.load_balance_method = LoadBalanceMethod.from_str( server_args.load_balance_method ) @@ -114,7 +112,6 @@ def start_dp_worker(self, dp_worker_id: int): self.server_args, self.port_args, pipe_controller_writer, - self.model_override_args, True, gpu_ids, dp_worker_id, @@ -189,14 +186,13 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer, - model_override_args: dict, ): """Start a controller process.""" configure_logger(server_args) try: - controller = ControllerMulti(server_args, port_args, model_override_args) + controller = ControllerMulti(server_args, port_args) except Exception: pipe_writer.send(get_exception_traceback()) raise diff --git a/python/sglang/srt/managers/controller_single.py b/python/sglang/srt/managers/controller_single.py index 2ae37059c10..fe03ca1d476 100644 --- a/python/sglang/srt/managers/controller_single.py +++ b/python/sglang/srt/managers/controller_single.py @@ -40,7 +40,6 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args: dict, gpu_ids: List[int], is_data_parallel_worker: bool, dp_worker_id: int, @@ -76,7 +75,6 @@ def __init__( tp_rank_range, server_args, port_args.nccl_ports[dp_worker_id], - model_override_args, ) # Launch tp rank 0 @@ -85,7 +83,6 @@ def __init__( 0, server_args, port_args.nccl_ports[dp_worker_id], - model_override_args, ) 
self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group @@ -126,7 +123,6 @@ def start_controller_process( server_args: ServerArgs, port_args: PortArgs, pipe_writer: multiprocessing.connection.Connection, - model_override_args: dict, is_data_parallel_worker: bool = False, gpu_ids: List[int] = None, dp_worker_id: int = None, @@ -149,7 +145,6 @@ def start_controller_process( controller = ControllerSingle( server_args, port_args, - model_override_args, gpu_ids, is_data_parallel_worker, dp_worker_id, diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index d0cfed08cd1..d2fa6760129 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -18,6 +18,7 @@ import asyncio import concurrent.futures import dataclasses +import json import logging import multiprocessing as mp import os @@ -77,7 +78,6 @@ def __init__( self, server_args: ServerArgs, port_args: PortArgs, - model_override_args: dict = None, ): self.server_args = server_args @@ -95,7 +95,7 @@ def __init__( self.hf_config = get_config( self.model_path, trust_remote_code=server_args.trust_remote_code, - model_override_args=model_override_args, + model_override_args=json.loads(server_args.json_model_override_args), ) self.is_generation = is_generation_model( self.hf_config.architectures, self.server_args.is_embedding diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index 7bb9c433565..513bc517f58 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -15,13 +15,14 @@ """A tensor parallel worker.""" +import json import logging import multiprocessing import os import pickle import time import warnings -from typing import Any, List, Optional, Union +from typing import Any, List, Optional import torch import torch.distributed @@ -66,6 +67,7 @@ logger = logging.getLogger(__name__) +# Crash on warning if we are running CI tests crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true" @@ -76,11 +78,10 @@ def __init__( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): suppress_other_loggers() - # Copy arguments + # Parse arguments self.gpu_id = gpu_id self.tp_rank = tp_rank self.tp_size = server_args.tp_size @@ -93,9 +94,8 @@ def __init__( server_args.model_path, server_args.trust_remote_code, context_length=server_args.context_length, - model_override_args=model_override_args, + model_override_args=json.loads(server_args.json_model_override_args), ) - self.model_runner = ModelRunner( model_config=self.model_config, mem_fraction_static=server_args.mem_fraction_static, @@ -136,7 +136,7 @@ def __init__( self.max_total_num_tokens - 1, ) - # Sync random seed + # Sync random seed across TP workers server_args.random_seed = broadcast_recv_input( [server_args.random_seed], self.tp_rank, @@ -144,7 +144,7 @@ def __init__( )[0] set_random_seed(server_args.random_seed) - # Print info + # Print debug info logger.info( f"max_total_num_tokens={self.max_total_num_tokens}, " f"max_prefill_tokens={self.max_prefill_tokens}, " @@ -181,7 +181,7 @@ def __init__( self.num_generated_tokens = 0 self.last_stats_tic = time.time() - # Chunked prefill + # Init chunked prefill self.chunked_prefill_size = server_args.chunked_prefill_size self.current_inflight_req = None self.is_mixed_chunk = ( @@ -197,16 +197,6 @@ def __init__( "trust_remote_code": server_args.trust_remote_code, }, 
skip_tokenizer_init=server_args.skip_tokenizer_init, - json_schema_mode=False, - ) - self.json_fsm_cache = FSMCache( - server_args.tokenizer_path, - { - "tokenizer_mode": server_args.tokenizer_mode, - "trust_remote_code": server_args.trust_remote_code, - }, - skip_tokenizer_init=server_args.skip_tokenizer_init, - json_schema_mode=True, ) self.jump_forward_cache = JumpForwardCache() @@ -227,11 +217,12 @@ def exposed_step(self, recv_reqs: List): try: # Recv requests for recv_req in recv_reqs: - if isinstance( - recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput) - ): + if isinstance(recv_req, TokenizedGenerateReqInput): self.handle_generate_request(recv_req) self.do_not_get_new_batch = False + elif isinstance(recv_req, TokenizedEmbeddingReqInput): + self.handle_embedding_request(recv_req) + self.do_not_get_new_batch = False elif isinstance(recv_req, FlushCacheReq): self.flush_cache() elif isinstance(recv_req, AbortReq): @@ -331,57 +322,56 @@ def check_memory(self): def handle_generate_request( self, - recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput], + recv_req: TokenizedGenerateReqInput, ): req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) req.tokenizer = self.tokenizer req.sampling_params = recv_req.sampling_params - if self.model_runner.is_generation: - req.pixel_values = recv_req.pixel_values - if req.pixel_values is not None: - # Use image hash as fake token_ids, which is then used - # for prefix matching - image_hash = hash(tuple(recv_req.image_hashes)) - req.pad_value = [ - (image_hash) % self.model_config.vocab_size, - (image_hash >> 16) % self.model_config.vocab_size, - (image_hash >> 32) % self.model_config.vocab_size, - (image_hash >> 64) % self.model_config.vocab_size, - ] - req.image_sizes = recv_req.image_sizes - ( - req.origin_input_ids, - req.image_offsets, - ) = self.model_runner.model.pad_input_ids( - req.origin_input_ids_unpadded, - req.pad_value, - req.pixel_values, - req.image_sizes, - ) - # Only when pixel values is not None we have modalities - req.modalities = recv_req.modalites - req.return_logprob = recv_req.return_logprob - req.logprob_start_len = recv_req.logprob_start_len - req.top_logprobs_num = recv_req.top_logprobs_num - req.stream = recv_req.stream - - # Init regex fsm fron json + req.pixel_values = recv_req.pixel_values + if req.pixel_values is not None: + # Use image hash as fake token_ids, which is then used + # for prefix matching + image_hash = hash(tuple(recv_req.image_hashes)) + req.pad_value = [ + (image_hash) % self.model_config.vocab_size, + (image_hash >> 16) % self.model_config.vocab_size, + (image_hash >> 32) % self.model_config.vocab_size, + (image_hash >> 64) % self.model_config.vocab_size, + ] + req.image_sizes = recv_req.image_sizes + ( + req.origin_input_ids, + req.image_offsets, + ) = self.model_runner.model.pad_input_ids( + req.origin_input_ids_unpadded, + req.pad_value, + req.pixel_values, + req.image_sizes, + ) + # Only when pixel values is not None we have modalities + req.modalities = recv_req.modalites + req.return_logprob = recv_req.return_logprob + req.logprob_start_len = recv_req.logprob_start_len + req.top_logprobs_num = recv_req.top_logprobs_num + req.stream = recv_req.stream + + # Init regex FSM + if ( + req.sampling_params.json_schema is not None + or req.sampling_params.regex is not None + ): if req.sampling_params.json_schema is not None: - req.regex_fsm, computed_regex_string = self.json_fsm_cache.query( - req.sampling_params.json_schema + req.regex_fsm, 
computed_regex_string = self.regex_fsm_cache.query( + ("json", req.sampling_params.json_schema) ) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - computed_regex_string - ) - - # Init regex fsm elif req.sampling_params.regex is not None: - req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex) - if not self.disable_regex_jump_forward: - req.jump_forward_map = self.jump_forward_cache.query( - req.sampling_params.regex - ) + req.regex_fsm, computed_regex_string = self.regex_fsm_cache.query( + ("regex", req.sampling_params.regex) + ) + if not self.disable_regex_jump_forward: + req.jump_forward_map = self.jump_forward_cache.query( + computed_regex_string + ) # Truncate prompts that are too long if len(req.origin_input_ids) >= self.max_req_input_len: @@ -390,16 +380,32 @@ def handle_generate_request( "the max context length. Truncated!!!" ) req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] + req.sampling_params.max_new_tokens = min( + ( + req.sampling_params.max_new_tokens + if req.sampling_params.max_new_tokens is not None + else 1 << 30 + ), + self.max_req_input_len - 1 - len(req.origin_input_ids), + ) - if self.model_runner.is_generation: - req.sampling_params.max_new_tokens = min( - ( - req.sampling_params.max_new_tokens - if req.sampling_params.max_new_tokens is not None - else 1 << 30 - ), - self.max_req_input_len - 1 - len(req.origin_input_ids), + self.waiting_queue.append(req) + + def handle_embedding_request( + self, + recv_req: TokenizedEmbeddingReqInput, + ): + req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids) + req.tokenizer = self.tokenizer + req.sampling_params = recv_req.sampling_params + + # Truncate prompts that are too long + if len(req.origin_input_ids) >= self.max_req_input_len: + logger.warn( + "Request length is longer than the KV cache pool size or " + "the max context length. Truncated!!!" 
) + req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len] self.waiting_queue.append(req) @@ -892,7 +898,6 @@ def run_tp_server( tp_rank: int, server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): """Run a tensor parallel model server.""" configure_logger(server_args, prefix=f" TP{tp_rank}") @@ -903,7 +908,6 @@ def run_tp_server( tp_rank, server_args, nccl_port, - model_override_args, ) tp_cpu_group = model_server.model_runner.tp_group.cpu_group @@ -920,14 +924,13 @@ def launch_tp_servers( tp_rank_range: List[int], server_args: ServerArgs, nccl_port: int, - model_override_args: dict, ): """Launch multiple tensor parallel servers.""" procs = [] for i in tp_rank_range: proc = multiprocessing.Process( target=run_tp_server, - args=(gpu_ids[i], i, server_args, nccl_port, model_override_args), + args=(gpu_ids[i], i, server_args, nccl_port), ) proc.start() procs.append(proc) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 3d3e0cde9d1..9c82b2a813a 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -18,6 +18,7 @@ import gc import importlib import importlib.resources +import json import logging import pkgutil from functools import lru_cache diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index feaf91dd390..d44d6175220 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -272,7 +272,6 @@ async def retrieve_file_content(file_id: str): def launch_server( server_args: ServerArgs, - model_override_args: Optional[dict] = None, pipe_finish_writer: Optional[mp.connection.Connection] = None, ): """Launch an HTTP server.""" @@ -317,7 +316,6 @@ def launch_server( tp_rank_range, server_args, ports[3], - model_override_args, ) try: @@ -328,7 +326,7 @@ def launch_server( return # Launch processes - tokenizer_manager = TokenizerManager(server_args, port_args, model_override_args) + tokenizer_manager = TokenizerManager(server_args, port_args) if server_args.chat_template: load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template) pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False) @@ -341,7 +339,7 @@ def launch_server( proc_controller = mp.Process( target=start_controller_process, - args=(server_args, port_args, pipe_controller_writer, model_override_args), + args=(server_args, port_args, pipe_controller_writer), ) proc_controller.start() @@ -501,7 +499,6 @@ class Runtime: def __init__( self, log_level: str = "error", - model_override_args: Optional[dict] = None, *args, **kwargs, ): @@ -525,7 +522,7 @@ def __init__( proc = mp.Process( target=launch_server, - args=(self.server_args, model_override_args, pipe_writer), + args=(self.server_args, pipe_writer), ) proc.start() pipe_writer.close() diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e21f02108c9..14dd63b5adb 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -76,6 +76,14 @@ class ServerArgs: dp_size: int = 1 load_balance_method: str = "round_robin" + # Distributed args + nccl_init_addr: Optional[str] = None + nnodes: int = 1 + node_rank: Optional[int] = None + + # Model override args in JSON + json_model_override_args: str = "{}" + # Optimization/debug options disable_flashinfer: bool = False disable_flashinfer_sampling: bool = False @@ -91,14 +99,6 @@ class ServerArgs: enable_mla: bool = False triton_attention_reduce_in_fp32: 
bool = False - # Distributed args - nccl_init_addr: Optional[str] = None - nnodes: int = 1 - node_rank: Optional[int] = None - - # Model override args in JSON - json_model_override_args: Optional[dict] = None - def __post_init__(self): if self.tokenizer_path is None: self.tokenizer_path = self.model_path @@ -385,6 +385,14 @@ def add_cli_args(parser: argparse.ArgumentParser): ) parser.add_argument("--node-rank", type=int, help="The node rank.") + # Model override args + parser.add_argument( + "--json-model-override-args", + type=str, + help="A dictionary in JSON string format used to override default model configurations.", + default=ServerArgs.json_model_override_args, + ) + # Optimization/debug options parser.add_argument( "--disable-flashinfer", @@ -459,22 +467,10 @@ def add_cli_args(parser: argparse.ArgumentParser): help="Turn on memory efficient weight loading with quantization (quantize per layer during loading).", ) - # Model override args - parser.add_argument( - "--json-model-override-args", - type=str, - help="A dictionary in JSON string format used to override default model configurations.", - ) - @classmethod def from_cli_args(cls, args: argparse.Namespace): args.tp_size = args.tensor_parallel_size args.dp_size = args.data_parallel_size - args.json_model_override_args = ( - json.loads(args.json_model_override_args) - if args.json_model_override_args - else None - ) attrs = [attr.name for attr in dataclasses.fields(cls)] return cls(**{attr: getattr(args, attr) for attr in attrs}) @@ -498,7 +494,7 @@ def check_server_args(self): self.disable_flashinfer = False -def prepare_server_args(args: argparse.Namespace) -> ServerArgs: +def prepare_server_args(argv: List[str]) -> ServerArgs: """ Prepare the server arguments from the command line arguments. @@ -511,7 +507,7 @@ def prepare_server_args(args: argparse.Namespace) -> ServerArgs: """ parser = argparse.ArgumentParser() ServerArgs.add_cli_args(parser) - raw_args = parser.parse_args(args) + raw_args = parser.parse_args(argv) server_args = ServerArgs.from_cli_args(raw_args) return server_args diff --git a/python/sglang/test/few_shot_gsm8k.py b/python/sglang/test/few_shot_gsm8k.py new file mode 100644 index 00000000000..18ae2d8c350 --- /dev/null +++ b/python/sglang/test/few_shot_gsm8k.py @@ -0,0 +1,132 @@ +""" +Run few-shot GSM-8K evaluation. 
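Back-reference to the ServerArgs change above: model overrides now travel as a plain JSON string (json_model_override_args, default "{}") and are json.loads()-ed only where the HF config is actually built, instead of being parsed in from_cli_args and threaded through the controllers. A small standalone sketch of that flow, using a bare argparse parser as a stand-in for ServerArgs.add_cli_args; the rope_scaling payload matches the test case later in this patch.

import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument("--json-model-override-args", type=str, default="{}")
args = parser.parse_args(
    ["--json-model-override-args", '{"rope_scaling": {"factor": 2.0, "type": "linear"}}']
)

# The string stays picklable while it is passed between processes; each consumer
# (tokenizer manager, TP worker) parses it right before building the model config.
model_override_args = json.loads(args.json_model_override_args)
assert model_override_args["rope_scaling"]["factor"] == 2.0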
+ +Usage: +python3 -m sglang.test.few_shot_gsm8k --num-questions 200 +""" + +import argparse +import ast +import re +import time + +import numpy as np + +from sglang.api import set_default_backend +from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint +from sglang.utils import download_and_cache_file, dump_state_text, read_jsonl + +INVALID = -9999999 + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def main(args): + # Select backend + set_default_backend(RuntimeEndpoint(f"{args.host}:{args.port}")) + + # Read data + url = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + questions = [] + labels = [] + for i in range(len(lines[:num_questions])): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(l != INVALID for l in labels) + arguments = [{"question": q} for q in questions] + + ##################################### + ######### SGL Program Begin ######### + ##################################### + + import sglang as sgl + + @sgl.function + def few_shot_gsm8k(s, question): + s += few_shot_examples + question + s += sgl.gen( + "answer", max_tokens=512, stop=["Question", "Assistant:", "<|separator|>"] + ) + + ##################################### + ########## SGL Program End ########## + ##################################### + + # Run requests + tic = time.time() + states = few_shot_gsm8k.run_batch( + arguments, + temperature=0, + num_threads=args.parallel, + progress_bar=True, + ) + latency = time.time() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i]["answer"])) + + # print(f"{preds=}") + # print(f"{labels=}") + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Compute speed + num_output_tokens = sum( + s.get_meta_info("answer")["completion_tokens"] for s in states + ) + output_throughput = num_output_tokens / latency + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + print(f"Output throughput: {output_throughput:.3f} token/s") + + # Dump results + dump_state_text("tmp_output_gsm8k.txt", states) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--data-path", type=str, default="test.jsonl") + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--parallel", type=int, default=128) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=30000) + args = parser.parse_args() + main(args) diff --git a/python/sglang/test/test_programs.py 
b/python/sglang/test/test_programs.py index bdecdff2f94..41f466f7307 100644 --- a/python/sglang/test/test_programs.py +++ b/python/sglang/test/test_programs.py @@ -7,7 +7,7 @@ import numpy as np import sglang as sgl -from sglang.utils import fetch_and_cache_jsonl +from sglang.utils import download_and_cache_file, read_jsonl def test_few_shot_qa(): @@ -456,10 +456,6 @@ def gen_character_spec(s): def test_hellaswag_select(): """Benchmark the accuracy of sgl.select on the HellaSwag dataset.""" - url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" - lines = fetch_and_cache_jsonl(url) - - # Construct prompts def get_one_example(lines, i, include_answer): ret = lines[i]["activity_label"] + ": " + lines[i]["ctx"] + " " if include_answer: @@ -472,6 +468,12 @@ def get_few_shot_examples(lines, k): ret += get_one_example(lines, i, True) + "\n\n" return ret + # Read data + url = "https://raw.githubusercontent.com/rowanz/hellaswag/master/data/hellaswag_val.jsonl" + filename = download_and_cache_file(url) + lines = list(read_jsonl(filename)) + + # Construct prompts num_questions = 200 num_shots = 20 few_shot_examples = get_few_shot_examples(lines, num_shots) diff --git a/python/sglang/utils.py b/python/sglang/utils.py index b212f6caa31..621efb5373c 100644 --- a/python/sglang/utils.py +++ b/python/sglang/utils.py @@ -12,7 +12,7 @@ from concurrent.futures import ThreadPoolExecutor from io import BytesIO from json import dumps -from typing import Union +from typing import Optional, Union import numpy as np import requests @@ -38,13 +38,11 @@ def is_same_type(values: list): def read_jsonl(filename: str): """Read a JSONL file.""" - rets = [] with open(filename) as fin: for line in fin: if line.startswith("#"): continue - rets.append(json.loads(line)) - return rets + yield json.loads(line) def dump_state_text(filename: str, states: list, mode: str = "w"): @@ -264,38 +262,35 @@ def __call__(self, *args, **kwargs): return module(*args, **kwargs) -def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"): - """Read and cache a jsonl file from a url.""" +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) # Check if the cache file already exists - if os.path.exists(cache_file): - print("Loading data from cache...") - with open(cache_file, "r") as f: - data = [json.loads(line) for line in f] - else: - print("Downloading data from URL...") - # Stream the response to show the progress bar - response = requests.get(url, stream=True) - response.raise_for_status() # Check for request errors - - # Total size of the file in bytes - total_size = int(response.headers.get("content-length", 0)) - chunk_size = 1024 # Download in chunks of 1KB - - # Use tqdm to display the progress bar - with open(cache_file, "wb") as f, tqdm( - desc=cache_file, - total=total_size, - unit="B", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for chunk in response.iter_content(chunk_size=chunk_size): - f.write(chunk) - bar.update(len(chunk)) - - # Convert the data to a list of dictionaries - with open(cache_file, "r") as f: - data = [json.loads(line) for line in f] - - return data + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in 
bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as f, tqdm( + desc=filename, + total=total_size, + unit="B", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + f.write(chunk) + bar.update(len(chunk)) + + return filename diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index d4b1354b793..b15308dcec2 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -42,7 +42,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.63, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" def test_human_eval(self): args = SimpleNamespace( @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.63, f"{metrics}" + assert metrics["score"] >= 0.62, f"{metrics}" if __name__ == "__main__": diff --git a/test/srt/test_server_args.py b/test/srt/test_server_args.py index 71129e3eb15..d8f31ce1b96 100644 --- a/test/srt/test_server_args.py +++ b/test/srt/test_server_args.py @@ -1,3 +1,4 @@ +import json import unittest from sglang.srt.server_args import prepare_server_args @@ -15,7 +16,7 @@ def test_prepare_server_args(self): ) self.assertEqual(server_args.model_path, "model_path") self.assertEqual( - server_args.json_model_override_args, + json.loads(server_args.json_model_override_args), {"rope_scaling": {"factor": 2.0, "type": "linear"}}, ) From a7c47e0f028c2a9e67cbc99ab67692ec765d3dd0 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Mon, 9 Sep 2024 05:32:41 -0700 Subject: [PATCH 5/5] Add torchao quant (int4/int8/fp8) to llama models (#1341) Co-authored-by: Lianmin Zheng --- python/pyproject.toml | 2 +- python/sglang/srt/layers/torchao_utils.py | 36 +++++++++ .../sglang/srt/model_executor/model_runner.py | 1 + python/sglang/srt/models/llama.py | 22 ++++++ python/sglang/srt/server_args.py | 9 ++- test/srt/test_eval_accuracy_mini.py | 4 +- test/srt/test_moe_eval_accuracy_large.py | 6 +- test/srt/test_torch_compile.py | 6 +- test/srt/test_torchao.py | 73 +++++++++++++++++++ test/srt/test_triton_attn_backend.py | 4 +- 10 files changed, 151 insertions(+), 12 deletions(-) create mode 100644 python/sglang/srt/layers/torchao_utils.py create mode 100644 test/srt/test_torchao.py diff --git a/python/pyproject.toml b/python/pyproject.toml index daf09ea25de..1389822a34b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ [project.optional-dependencies] srt = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hub", "interegular", "packaging", "pillow", "psutil", "pydantic", "python-multipart", - "torch", "uvicorn", "uvloop", "zmq", + "torch", "torchao", "uvicorn", "uvloop", "zmq", "vllm==0.5.5", "outlines>=0.0.44"] openai = ["openai>=1.0", "tiktoken"] anthropic = ["anthropic>=0.20.0"] diff --git a/python/sglang/srt/layers/torchao_utils.py b/python/sglang/srt/layers/torchao_utils.py new file mode 100644 index 00000000000..16eb1f2c5c6 --- /dev/null +++ b/python/sglang/srt/layers/torchao_utils.py @@ -0,0 +1,36 @@ +""" +Common utilities for torchao. 
+""" + +import torch +from torchao.quantization import ( + int4_weight_only, + int8_dynamic_activation_int8_weight, + int8_weight_only, + quantize_, +) + + +def torchao_quantize_param_data(param, torchao_config): + dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False) + dummy_linear.weight = param + if "int8wo" in torchao_config: + quantize_(dummy_linear, int8_weight_only()) + elif "int8dq" in torchao_config: + quantize_(dummy_linear, int8_dynamic_activation_int8_weight()) + elif "int4wo" in torchao_config: + group_size = int(torchao_config.split("-")[-1]) + assert group_size in [ + 32, + 64, + 128, + 256, + ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}" + quantize_(dummy_linear, int4_weight_only(group_size=group_size)) + elif "fp8wo" in torchao_config: + from torchao.quantization import float8_weight_only + + # this requires newer hardware + # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 + quantize_(dummy_linear, float8_weight_only()) + return dummy_linear.weight diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 9c82b2a813a..78f99dcd67b 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -97,6 +97,7 @@ def __init__( "disable_flashinfer_sampling": server_args.disable_flashinfer_sampling, "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32, "enable_mla": server_args.enable_mla, + "torchao_config": server_args.torchao_config, } ) diff --git a/python/sglang/srt/models/llama.py b/python/sglang/srt/models/llama.py index 926d87db8b7..ac53712fca4 100644 --- a/python/sglang/srt/models/llama.py +++ b/python/sglang/srt/models/llama.py @@ -42,6 +42,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.sampler import Sampler +from sglang.srt.layers.torchao_utils import torchao_quantize_param_data +from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -299,6 +301,7 @@ def __init__( super().__init__() self.config = config self.quant_config = quant_config + self.torchao_config = global_server_args_dict["torchao_config"] self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) @@ -361,6 +364,25 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + if self.torchao_config: + if name.endswith("proj.weight") and param.ndim == 2: + params_dict[name] = torchao_quantize_param_data( + param, self.torchao_config + ) + + if self.torchao_config: + # quantizing the loaded, stacked params, e.g. 
"...qkv_proj" + stacked_params = set(entry[0] for entry in stacked_params_mapping) + for param_suffix in stacked_params: + for name in params_dict: + if param_suffix in name: + param = params_dict[name] + params_dict[name] = torchao_quantize_param_data( + param, self.torchao_config + ) + + self.load_state_dict(params_dict, assign=True) + class Phi3ForCausalLM(LlamaForCausalLM): pass diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 14dd63b5adb..3dfb1dc4117 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -95,6 +95,7 @@ class ServerArgs: disable_custom_all_reduce: bool = False enable_mixed_chunk: bool = False enable_torch_compile: bool = False + torchao_config: str = "" enable_p2p_check: bool = False enable_mla: bool = False triton_attention_reduce_in_fp32: bool = False @@ -443,7 +444,13 @@ def add_cli_args(parser: argparse.ArgumentParser): parser.add_argument( "--enable-torch-compile", action="store_true", - help="Optimize the model with torch.compile, experimental feature.", + help="Optimize the model with torch.compile. Experimental feature.", + ) + parser.add_argument( + "--torchao-config", + type=str, + default=ServerArgs.torchao_config, + help="Optimize the model with torchao. Experimental feature. Current choices are: int8dq, int8wo, int4wo-, fp8wo", ) parser.add_argument( "--enable-p2p-check", diff --git a/test/srt/test_eval_accuracy_mini.py b/test/srt/test_eval_accuracy_mini.py index 25aa0ca116b..6ddd97d9405 100644 --- a/test/srt/test_eval_accuracy_mini.py +++ b/test/srt/test_eval_accuracy_mini.py @@ -29,12 +29,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 if __name__ == "__main__": diff --git a/test/srt/test_moe_eval_accuracy_large.py b/test/srt/test_moe_eval_accuracy_large.py index b15308dcec2..b6027b61cba 100644 --- a/test/srt/test_moe_eval_accuracy_large.py +++ b/test/srt/test_moe_eval_accuracy_large.py @@ -42,7 +42,7 @@ def test_mmlu(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.62, f"{metrics}" + assert metrics["score"] >= 0.625, f"{metrics}" def test_human_eval(self): args = SimpleNamespace( @@ -54,7 +54,7 @@ def test_human_eval(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.42, f"{metrics}" + assert metrics["score"] >= 0.425, f"{metrics}" def test_mgsm_en(self): args = SimpleNamespace( @@ -66,7 +66,7 @@ def test_mgsm_en(self): ) metrics = run_eval(args) - assert metrics["score"] >= 0.62, f"{metrics}" + assert metrics["score"] >= 0.625, f"{metrics}" if __name__ == "__main__": diff --git a/test/srt/test_torch_compile.py b/test/srt/test_torch_compile.py index e8cafa15d25..40f47d6b6b5 100644 --- a/test/srt/test_torch_compile.py +++ b/test/srt/test_torch_compile.py @@ -22,7 +22,7 @@ def setUpClass(cls): cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, - other_args=["--enable-torch-compile", "--disable-radix-cache"], + other_args=["--enable-torch-compile"], ) @classmethod @@ -34,12 +34,12 @@ def test_mmlu(self): base_url=self.base_url, model=self.model, eval_name="mmlu", - num_examples=32, + num_examples=64, num_threads=32, ) metrics = run_eval(args) - assert metrics["score"] >= 0.6 + assert metrics["score"] >= 0.65 def run_decode(self, max_new_tokens): response = requests.post( diff --git a/test/srt/test_torchao.py b/test/srt/test_torchao.py new file mode 100644 
index 00000000000..d2084e7d53a
--- /dev/null
+++ b/test/srt/test_torchao.py
@@ -0,0 +1,73 @@
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_child_process
+from sglang.test.run_eval import run_eval
+from sglang.test.test_utils import (
+    DEFAULT_MODEL_NAME_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    popen_launch_server,
+)
+
+
+class TestTorchCompile(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
+        cls.base_url = DEFAULT_URL_FOR_TEST
+        cls.process = popen_launch_server(
+            cls.model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=["--torchao-config", "int4wo-128"],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_child_process(cls.process.pid)
+
+    def test_mmlu(self):
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=64,
+            num_threads=32,
+        )
+
+        metrics = run_eval(args)
+        assert metrics["score"] >= 0.65
+
+    def run_decode(self, max_new_tokens):
+        response = requests.post(
+            self.base_url + "/generate",
+            json={
+                "text": "The capital of France is",
+                "sampling_params": {
+                    "temperature": 0,
+                    "max_new_tokens": max_new_tokens,
+                },
+                "ignore_eos": True,
+            },
+        )
+        return response.json()
+
+    def test_throughput(self):
+        import time
+
+        max_tokens = 256
+
+        tic = time.time()
+        res = self.run_decode(max_tokens)
+        tok = time.time()
+        print(res["text"])
+        throughput = max_tokens / (tok - tic)
+        print(f"Throughput: {throughput} tokens/s")
+        assert throughput >= 210
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attn_backend.py
index a94ca921240..b3f65ac13a6 100644
--- a/test/srt/test_triton_attn_backend.py
+++ b/test/srt/test_triton_attn_backend.py
@@ -32,12 +32,12 @@ def test_mmlu(self):
             base_url=self.base_url,
             model=self.model,
             eval_name="mmlu",
-            num_examples=32,
+            num_examples=64,
             num_threads=32,
         )
 
         metrics = run_eval(args)
-        assert metrics["score"] >= 0.6
+        assert metrics["score"] >= 0.65
 
 
 if __name__ == "__main__":
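Note on the torchao integration above: torchao_quantize_param_data wraps an existing 2-D weight in a throwaway nn.Linear, lets torchao's quantize_() swap in a quantized tensor subclass in place, and returns the resulting weight; llama.py then re-assigns those weights via load_state_dict(..., assign=True). A hedged usage sketch of that pattern, assuming torchao is installed; int8_weight_only is shown because it is the most portable of the options the helper supports, and the function name here is illustrative, not part of the patch.

import torch
from torchao.quantization import int8_weight_only, quantize_


def quantize_weight_int8wo(param: torch.nn.Parameter) -> torch.nn.Parameter:
    # Same trick as torchao_quantize_param_data: weight shape is (out, in),
    # so Linear(in_features=shape[1], out_features=shape[0]) matches exactly.
    dummy = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
    dummy.weight = param
    quantize_(dummy, int8_weight_only())
    return dummy.weight


weight = torch.nn.Parameter(torch.randn(256, 512, dtype=torch.bfloat16))
qweight = quantize_weight_int8wo(weight)

# Per the patch, "int4wo-<group_size>" configs additionally parse the group size
# from the --torchao-config string (32/64/128/256), and "fp8wo" requires newer
# GPUs (CUDA arch >= 89); both dispatch through the same helper.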