From 5008e5bbe9f1b62e6ddd45b81425327aa5eed90d Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 27 Jul 2024 19:23:10 -0700 Subject: [PATCH 1/5] Rename prefill_token_logprobs -> input_token_logprobs; decode_token_logprobs -> output_token_logprobs --- examples/usage/choices_logprob.py | 8 +-- examples/usage/cot_decoding.py | 14 ++-- examples/usage/json_logprobs.py | 26 +++---- .../sglang/lang/backend/runtime_endpoint.py | 8 +-- python/sglang/lang/interpreter.py | 8 +-- python/sglang/srt/layers/logits_processor.py | 68 +++++++++---------- .../managers/controller/cuda_graph_runner.py | 8 +-- .../srt/managers/controller/infer_batch.py | 12 ++-- .../srt/managers/controller/tp_worker.py | 50 +++++++------- .../sglang/srt/managers/tokenizer_manager.py | 16 ++--- .../sglang/srt/models/llama_classification.py | 6 +- python/sglang/srt/openai_api/adapter.py | 66 +++++++++--------- test/srt/test_httpserver_decode_stream.py | 8 +-- 13 files changed, 147 insertions(+), 151 deletions(-) diff --git a/examples/usage/choices_logprob.py b/examples/usage/choices_logprob.py index e261668f8a6..6cd733fe90a 100644 --- a/examples/usage/choices_logprob.py +++ b/examples/usage/choices_logprob.py @@ -20,8 +20,8 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) - print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("logprobs of choice 1", meta_info["input_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["input_token_logprobs"][1]) print("-" * 50) # Run a batch @@ -34,8 +34,8 @@ def main(): print("questions:", question) print("choice:", state["tool"]) meta_info = state.get_meta_info("tool") - print("logprobs of choice 1", meta_info["prefill_token_logprobs"][0]) - print("logprobs of choice 2", meta_info["prefill_token_logprobs"][1]) + print("logprobs of choice 1", meta_info["input_token_logprobs"][0]) + print("logprobs of choice 2", meta_info["input_token_logprobs"][1]) print("-" * 50) diff --git a/examples/usage/cot_decoding.py b/examples/usage/cot_decoding.py index 5f9cd68d4f4..38fbde855bf 100644 --- a/examples/usage/cot_decoding.py +++ b/examples/usage/cot_decoding.py @@ -31,7 +31,7 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): top_logprobs_num=get_top_k, return_text_in_logprobs=True, ) - logprobs = step_0.get_meta_info("get_top_k")["decode_top_logprobs"][0] + logprobs = step_0.get_meta_info("get_top_k")["output_top_logprobs"][0] print("Decoding step 0:", ", ".join(pformat(token[2]) for token in logprobs)) for idx, (f, token) in enumerate(zip(forks, logprobs)): @@ -55,9 +55,9 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): ) # calculate probability disparity between the top and secondary tokens - x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]] - x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["decode_top_logprobs"]] - tokens = [xt[0][2] for xt in f.get_meta_info("answer")["decode_top_logprobs"]] + x1s = [exp(xt[0][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]] + x2s = [exp(xt[1][0]) for xt in f.get_meta_info("answer")["output_top_logprobs"]] + tokens = [xt[0][2] for xt in f.get_meta_info("answer")["output_top_logprobs"]] delta = (sum(x1s) - sum(x2s)) / len(x1s) # extract the answer span (without the '<|end_of_text|>' token) @@ -81,19 +81,19 @@ def cot_decoding(s, question, get_top_k, is_chat_model, verbose): answer_tokens = [ xt[0][2] 
for xt in answer_forks[idx].get_meta_info("answer_span")[ - "decode_top_logprobs" + "output_top_logprobs" ] ] answer_x1s = [ exp(xt[0][0]) for xt in answer_forks[idx].get_meta_info("answer_span")[ - "decode_top_logprobs" + "output_top_logprobs" ] ] answer_x2s = [ exp(xt[1][0]) for xt in answer_forks[idx].get_meta_info("answer_span")[ - "decode_top_logprobs" + "output_top_logprobs" ] ] diff --git a/examples/usage/json_logprobs.py b/examples/usage/json_logprobs.py index 6b5b9c8fcea..fa0e1b81f33 100644 --- a/examples/usage/json_logprobs.py +++ b/examples/usage/json_logprobs.py @@ -56,14 +56,14 @@ def srt_api_request(name): # fout.write(json.dumps(res, indent=4)) meta_info = res["meta_info"] - assert len(meta_info["prefill_token_logprobs"]) == len( - meta_info["prefill_top_logprobs"] + assert len(meta_info["input_token_logprobs"]) == len( + meta_info["input_top_logprobs"] ) - assert len(meta_info["decode_token_logprobs"]) == len( - meta_info["decode_top_logprobs"] + assert len(meta_info["output_token_logprobs"]) == len( + meta_info["output_top_logprobs"] ) - assert len(meta_info["prefill_token_logprobs"]) == meta_info["prompt_tokens"] - assert len(meta_info["decode_token_logprobs"]) == meta_info["completion_tokens"] - 1 + assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"] + assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1 return res @@ -72,11 +72,11 @@ def pretty_print(res): meta_info = res["meta_info"] print("\n\n", "=" * 30, "Prefill", "=" * 30) - for i in range(len(meta_info["prefill_token_logprobs"])): - print(f"{str(meta_info['prefill_token_logprobs'][i][2].encode()): <20}", end="") + for i in range(len(meta_info["input_token_logprobs"])): + print(f"{str(meta_info['input_token_logprobs'][i][2].encode()): <20}", end="") top_ks = ( - [str(t[2].encode()) for t in meta_info["prefill_top_logprobs"][i]] - if meta_info["prefill_top_logprobs"][i] + [str(t[2].encode()) for t in meta_info["input_top_logprobs"][i]] + if meta_info["input_top_logprobs"][i] else [] ) for top_k in top_ks: @@ -84,9 +84,9 @@ def pretty_print(res): print() print("\n\n", "=" * 30, "Decode", "=" * 30) - for i in range(len(meta_info["decode_token_logprobs"])): - print(f"{str(meta_info['decode_token_logprobs'][i][2].encode()): <20}", end="") - top_ks = [str(t[2].encode()) for t in meta_info["decode_top_logprobs"][i]] + for i in range(len(meta_info["output_token_logprobs"])): + print(f"{str(meta_info['output_token_logprobs'][i][2].encode()): <20}", end="") + top_ks = [str(t[2].encode()) for t in meta_info["output_top_logprobs"][i]] for top_k in top_ks: print(f"{top_k: <15}", end="") print() diff --git a/python/sglang/lang/backend/runtime_endpoint.py b/python/sglang/lang/backend/runtime_endpoint.py index 772577336f9..929b3b6ada5 100644 --- a/python/sglang/lang/backend/runtime_endpoint.py +++ b/python/sglang/lang/backend/runtime_endpoint.py @@ -253,14 +253,14 @@ def select( r["meta_info"]["normalized_prompt_logprob"] for r in obj ] decision = choices[np.argmax(normalized_prompt_logprobs)] - prefill_token_logprobs = [r["meta_info"]["prefill_token_logprobs"] for r in obj] - decode_token_logprobs = [r["meta_info"]["decode_token_logprobs"] for r in obj] + input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj] + output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj] return ( decision, normalized_prompt_logprobs, - prefill_token_logprobs, - decode_token_logprobs, + input_token_logprobs, + output_token_logprobs, ) def 
concatenate_and_append(self, src_rids: List[str], dst_rid: str): diff --git a/python/sglang/lang/interpreter.py b/python/sglang/lang/interpreter.py index 573b9970bad..ddf755ca2a0 100644 --- a/python/sglang/lang/interpreter.py +++ b/python/sglang/lang/interpreter.py @@ -541,16 +541,16 @@ def _execute_select(self, expr: SglSelect): ( decision, normalized_prompt_logprobs, - prefill_token_logprobs, - decode_token_logprobs, + input_token_logprobs, + output_token_logprobs, ) = self.backend.select(self, expr.choices, expr.temperature) if expr.name is not None: name = expr.name self.variables[name] = decision self.meta_info[name] = { "normalized_prompt_logprobs": normalized_prompt_logprobs, - "prefill_token_logprobs": prefill_token_logprobs, - "decode_token_logprobs": decode_token_logprobs, + "input_token_logprobs": input_token_logprobs, + "output_token_logprobs": output_token_logprobs, } self.variable_event[name].set() self.text_ += decision diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index e9ed66f4fa6..ad95f0ca736 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -23,12 +23,12 @@ class LogitProcessorOutput: # The normlaized logprobs of prompts. shape: [#seq] normalized_prompt_logprobs: torch.Tensor # The logprobs of prefill tokens. shape: [#token, vocab_size] - prefill_token_logprobs: torch.Tensor + input_token_logprobs: torch.Tensor # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - prefill_top_logprobs: List + input_top_logprobs: List # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id) - decode_top_logprobs: List + output_top_logprobs: List @dataclasses.dataclass @@ -58,20 +58,16 @@ def __init__(self, config): self.tp_size = get_tensor_model_parallel_world_size() def _get_normalized_prompt_logprobs( - self, prefill_token_logprobs, logits_metadata: LogitsMetadata + self, input_token_logprobs, logits_metadata: LogitsMetadata ): - logprobs_cumsum = torch.cumsum( - prefill_token_logprobs, dim=0, dtype=torch.float32 - ) + logprobs_cumsum = torch.cumsum(input_token_logprobs, dim=0, dtype=torch.float32) start = logits_metadata.extend_start_loc.clone() end = start + logits_metadata.extend_seq_lens - 2 - start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) - end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1) + start.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) + end.clamp_(min=0, max=input_token_logprobs.shape[0] - 1) sum_logp = ( - logprobs_cumsum[end] - - logprobs_cumsum[start] - + prefill_token_logprobs[start] + logprobs_cumsum[end] - logprobs_cumsum[start] + input_token_logprobs[start] ) normalized_prompt_logprobs = sum_logp / ( (logits_metadata.extend_seq_lens - 1).clamp(min=1) @@ -83,34 +79,34 @@ def _get_normalized_prompt_logprobs( def get_top_logprobs(all_logprobs, logits_metadata: LogitsMetadata): # TODO: vectorize the code below if logits_metadata.forward_mode == ForwardMode.DECODE: - decode_top_logprobs = [] + output_top_logprobs = [] for i in range(all_logprobs.shape[0]): k = logits_metadata.top_logprobs_nums[i] t = all_logprobs[i].topk(k) v_cpu = t.values.tolist() p_cpu = t.indices.tolist() - decode_top_logprobs.append(list(zip(v_cpu, p_cpu))) - return None, decode_top_logprobs + output_top_logprobs.append(list(zip(v_cpu, p_cpu))) + return None, output_top_logprobs else: - prefill_top_logprobs, decode_top_logprobs = [], 
[] + input_top_logprobs, output_top_logprobs = [], [] pt = 0 extend_seq_lens_cpu = logits_metadata.extend_seq_lens.tolist() for i, extend_seq_len in enumerate(extend_seq_lens_cpu): if extend_seq_len == 0: - prefill_top_logprobs.append([]) - decode_top_logprobs.append([]) + input_top_logprobs.append([]) + output_top_logprobs.append([]) continue k = logits_metadata.top_logprobs_nums[i] t = all_logprobs[pt : pt + extend_seq_len].topk(k) vs_cpu = t.values.tolist() ps_cpu = t.indices.tolist() - prefill_top_logprobs.append( + input_top_logprobs.append( [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)] ) - decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) + output_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1]))) pt += extend_seq_len - return prefill_top_logprobs, decode_top_logprobs + return input_top_logprobs, output_top_logprobs def forward( self, @@ -150,9 +146,9 @@ def forward( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, - prefill_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, ) else: # When logprob is requested, compute the logits for all tokens. @@ -164,19 +160,19 @@ def forward( x > 0 for x in logits_metadata.top_logprobs_nums ) if return_top_logprob: - decode_top_logprobs = self.get_top_logprobs( + output_top_logprobs = self.get_top_logprobs( last_logprobs, logits_metadata )[1] else: - decode_top_logprobs = None + output_top_logprobs = None return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, - prefill_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=decode_top_logprobs, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=output_top_logprobs, ) else: all_logits = torch.matmul(hidden_states, weight.T) @@ -193,32 +189,32 @@ def forward( x > 0 for x in logits_metadata.top_logprobs_nums ) if return_top_logprob: - prefill_top_logprobs, decode_top_logprobs = self.get_top_logprobs( + input_top_logprobs, output_top_logprobs = self.get_top_logprobs( all_logprobs, logits_metadata ) else: - prefill_top_logprobs = decode_top_logprobs = None + input_top_logprobs = output_top_logprobs = None last_logprobs = all_logprobs[last_index] # Compute the logprobs and normalized logprobs for the prefill tokens. # Note that we pad a zero at the end of each sequence for easy computation. 
- prefill_token_logprobs = all_logprobs[ + input_token_logprobs = all_logprobs[ torch.arange(all_logprobs.shape[0], device="cuda"), torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]), ] normalized_prompt_logprobs = self._get_normalized_prompt_logprobs( - prefill_token_logprobs, logits_metadata + input_token_logprobs, logits_metadata ) return LogitProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_top_logprobs=decode_top_logprobs, + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_top_logprobs=output_top_logprobs, ) diff --git a/python/sglang/srt/managers/controller/cuda_graph_runner.py b/python/sglang/srt/managers/controller/cuda_graph_runner.py index 2bdb33cff69..b795c853e1a 100644 --- a/python/sglang/srt/managers/controller/cuda_graph_runner.py +++ b/python/sglang/srt/managers/controller/cuda_graph_runner.py @@ -226,9 +226,9 @@ def replay(self, batch: Batch): next_token_logits=output.next_token_logits[:raw_bs], next_token_logprobs=None, normalized_prompt_logprobs=None, - prefill_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, ) # Extract logprobs @@ -242,7 +242,7 @@ def replay(self, batch: Batch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=batch.top_logprobs_nums, ) - output.decode_top_logprobs = LogitsProcessor.get_top_logprobs( + output.output_top_logprobs = LogitsProcessor.get_top_logprobs( output.next_token_logprobs, logits_metadata )[1] diff --git a/python/sglang/srt/managers/controller/infer_batch.py b/python/sglang/srt/managers/controller/infer_batch.py index 5ef3552ba89..47dd043a5b7 100644 --- a/python/sglang/srt/managers/controller/infer_batch.py +++ b/python/sglang/srt/managers/controller/infer_batch.py @@ -124,10 +124,10 @@ def __init__(self, rid, origin_input_text, origin_input_ids): self.logprob_start_len = 0 self.top_logprobs_num = 0 self.normalized_prompt_logprob = None - self.prefill_token_logprobs = None - self.prefill_top_logprobs = None - self.decode_token_logprobs = [] - self.decode_top_logprobs = [] + self.input_token_logprobs = None + self.input_top_logprobs = None + self.output_token_logprobs = [] + self.output_top_logprobs = [] # The tokens is prefilled but need to be considered as decode tokens # and should be updated for the decode logprobs self.last_update_decode_tokens = 0 @@ -244,8 +244,8 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state): k = k + 1 else: break - self.decode_token_logprobs = self.decode_token_logprobs[:k] - self.decode_top_logprobs = self.decode_top_logprobs[:k] + self.output_token_logprobs = self.output_token_logprobs[:k] + self.output_top_logprobs = self.output_top_logprobs[:k] self.logprob_start_len = prompt_tokens + k self.last_update_decode_tokens = len(self.output_ids) - k diff --git a/python/sglang/srt/managers/controller/tp_worker.py b/python/sglang/srt/managers/controller/tp_worker.py index eb6f826cc01..8e0525b07c1 100644 --- a/python/sglang/srt/managers/controller/tp_worker.py +++ b/python/sglang/srt/managers/controller/tp_worker.py @@ -455,7 +455,7 @@ def forward_prefill_batch(self, batch: Batch): torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() - output.prefill_token_logprobs = 
output.prefill_token_logprobs.tolist() + output.input_token_logprobs = output.input_token_logprobs.tolist() output.normalized_prompt_logprobs = ( output.normalized_prompt_logprobs.tolist() ) @@ -481,24 +481,24 @@ def add_logprob_return_values(self, i, req, pt, next_token_ids, output): if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] - if req.prefill_token_logprobs is None: + if req.input_token_logprobs is None: # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored. - req.prefill_token_logprobs = list( + req.input_token_logprobs = list( zip( - output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1], + output.input_token_logprobs[pt : pt + req.extend_input_len - 1], req.input_ids[-req.extend_input_len + 1 :], ) ) if req.logprob_start_len == 0: - req.prefill_token_logprobs = [ + req.input_token_logprobs = [ (None, req.input_ids[0]) - ] + req.prefill_token_logprobs + ] + req.input_token_logprobs if req.last_update_decode_tokens != 0: - req.decode_token_logprobs.extend( + req.output_token_logprobs.extend( list( zip( - output.prefill_token_logprobs[ + output.input_token_logprobs[ pt + req.extend_input_len - req.last_update_decode_tokens : pt @@ -510,21 +510,21 @@ def add_logprob_return_values(self, i, req, pt, next_token_ids, output): ) ) - req.decode_token_logprobs.append( + req.output_token_logprobs.append( (output.next_token_logprobs[i], next_token_ids[i]) ) if req.top_logprobs_num > 0: - if req.prefill_top_logprobs is None: - req.prefill_top_logprobs = output.prefill_top_logprobs[i] + if req.input_top_logprobs is None: + req.input_top_logprobs = output.input_top_logprobs[i] if req.logprob_start_len == 0: - req.prefill_top_logprobs = [None] + req.prefill_top_logprobs + req.input_top_logprobs = [None] + req.input_top_logprobs if req.last_update_decode_tokens != 0: - req.decode_top_logprobs.extend( - output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :] + req.output_top_logprobs.extend( + output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :] ) - req.decode_top_logprobs.append(output.decode_top_logprobs[i]) + req.output_top_logprobs.append(output.output_top_logprobs[i]) def cache_filled_batch(self, batch: Batch): req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy() @@ -589,11 +589,11 @@ def forward_decode_batch(self, batch: Batch): req.check_finished() if req.return_logprob: - req.decode_token_logprobs.append( + req.output_token_logprobs.append( (next_token_logprobs[i], next_token_id) ) if req.top_logprobs_num > 0: - req.decode_top_logprobs.append(output.decode_top_logprobs[i]) + req.output_top_logprobs.append(output.output_top_logprobs[i]) self.handle_finished_requests(batch) @@ -645,16 +645,16 @@ def handle_finished_requests(self, batch: Batch): } if req.return_logprob: ( - meta_info["prefill_token_logprobs"], - meta_info["decode_token_logprobs"], - meta_info["prefill_top_logprobs"], - meta_info["decode_top_logprobs"], + meta_info["input_token_logprobs"], + meta_info["output_token_logprobs"], + meta_info["input_top_logprobs"], + meta_info["output_top_logprobs"], meta_info["normalized_prompt_logprob"], ) = ( - req.prefill_token_logprobs, - req.decode_token_logprobs, - req.prefill_top_logprobs, - req.decode_top_logprobs, + req.input_token_logprobs, + req.output_token_logprobs, + req.input_top_logprobs, + req.output_top_logprobs, req.normalized_prompt_logprob, ) output_meta_info.append(meta_info) diff --git a/python/sglang/srt/managers/tokenizer_manager.py 
b/python/sglang/srt/managers/tokenizer_manager.py index 34890d699cb..953520986f5 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -448,23 +448,23 @@ def convert_logprob_style( return_text_in_logprobs: bool, ): if return_logprob: - ret["meta_info"]["prefill_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["prefill_token_logprobs"], return_text_in_logprobs + ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs ) - ret["meta_info"]["decode_token_logprobs"] = self.detokenize_logprob_tokens( - ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs + ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens( + ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs ) if top_logprobs_num > 0: - ret["meta_info"]["prefill_top_logprobs"] = ( + ret["meta_info"]["input_top_logprobs"] = ( self.detokenize_top_logprobs_tokens( - ret["meta_info"]["prefill_top_logprobs"], + ret["meta_info"]["input_top_logprobs"], return_text_in_logprobs, ) ) - ret["meta_info"]["decode_top_logprobs"] = ( + ret["meta_info"]["output_top_logprobs"] = ( self.detokenize_top_logprobs_tokens( - ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs + ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs ) ) return ret diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index 96b1ac01e05..a12bbc91a9d 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -54,9 +54,9 @@ def forward( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, - prefill_token_logprobs=torch.ones_like(input_ids), - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=torch.ones_like(input_ids), + input_top_logprobs=None, + output_top_logprobs=None, ) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 760a46e9b02..6f354cef27a 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -140,29 +140,29 @@ async def generate_stream_resp(): if request.logprobs: # The first chunk and echo is enabled. 
if not stream_buffer and request.echo: - prefill_token_logprobs = content["meta_info"][ - "prefill_token_logprobs" + input_token_logprobs = content["meta_info"][ + "input_token_logprobs" ] - prefill_top_logprobs = content["meta_info"][ - "prefill_top_logprobs" + input_top_logprobs = content["meta_info"][ + "input_top_logprobs" ] else: - prefill_token_logprobs = None - prefill_top_logprobs = None + input_token_logprobs = None + input_top_logprobs = None logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=content["meta_info"][ - "decode_token_logprobs" + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=content["meta_info"][ + "output_token_logprobs" ][n_prev_token:], - decode_top_logprobs=content["meta_info"][ - "decode_top_logprobs" + output_top_logprobs=content["meta_info"][ + "output_top_logprobs" ][n_prev_token:], ) n_prev_token = len( - content["meta_info"]["decode_token_logprobs"] + content["meta_info"]["output_token_logprobs"] ) else: logprobs = None @@ -218,17 +218,17 @@ async def generate_stream_resp(): if request.logprobs: if request.echo: - prefill_token_logprobs = ret_item["meta_info"]["prefill_token_logprobs"] - prefill_top_logprobs = ret_item["meta_info"]["prefill_top_logprobs"] + input_token_logprobs = ret_item["meta_info"]["input_token_logprobs"] + input_top_logprobs = ret_item["meta_info"]["input_top_logprobs"] else: - prefill_token_logprobs = None - prefill_top_logprobs = None + input_token_logprobs = None + input_top_logprobs = None logprobs = to_openai_style_logprobs( - prefill_token_logprobs=prefill_token_logprobs, - prefill_top_logprobs=prefill_top_logprobs, - decode_token_logprobs=ret_item["meta_info"]["decode_token_logprobs"], - decode_top_logprobs=ret_item["meta_info"]["decode_top_logprobs"], + input_token_logprobs=input_token_logprobs, + input_top_logprobs=input_top_logprobs, + output_token_logprobs=ret_item["meta_info"]["output_token_logprobs"], + output_top_logprobs=ret_item["meta_info"]["output_top_logprobs"], ) else: logprobs = None @@ -401,10 +401,10 @@ async def generate_stream_resp(): def to_openai_style_logprobs( - prefill_token_logprobs=None, - decode_token_logprobs=None, - prefill_top_logprobs=None, - decode_top_logprobs=None, + input_token_logprobs=None, + output_token_logprobs=None, + input_top_logprobs=None, + output_top_logprobs=None, ): ret_logprobs = LogProbs() @@ -425,13 +425,13 @@ def append_top_logprobs(top_logprobs): else: ret_logprobs.top_logprobs.append(None) - if prefill_token_logprobs is not None: - append_token_logprobs(prefill_token_logprobs) - if decode_token_logprobs is not None: - append_token_logprobs(decode_token_logprobs) - if prefill_top_logprobs is not None: - append_top_logprobs(prefill_top_logprobs) - if decode_top_logprobs is not None: - append_top_logprobs(decode_top_logprobs) + if input_token_logprobs is not None: + append_token_logprobs(input_token_logprobs) + if output_token_logprobs is not None: + append_token_logprobs(output_token_logprobs) + if input_top_logprobs is not None: + append_top_logprobs(input_top_logprobs) + if output_top_logprobs is not None: + append_top_logprobs(output_top_logprobs) return ret_logprobs diff --git a/test/srt/test_httpserver_decode_stream.py b/test/srt/test_httpserver_decode_stream.py index 38f090b7d1b..955c368d154 100644 --- a/test/srt/test_httpserver_decode_stream.py +++ b/test/srt/test_httpserver_decode_stream.py @@ -40,14 +40,14 @@ 
def test_decode_stream(url, return_logprob, top_logprobs_num): data = json.loads(chunk[5:].strip("\n")) if return_logprob: - assert data["meta_info"]["prefill_token_logprobs"] is not None - assert data["meta_info"]["decode_token_logprobs"] is not None + assert data["meta_info"]["input_token_logprobs"] is not None + assert data["meta_info"]["output_token_logprobs"] is not None assert data["meta_info"]["normalized_prompt_logprob"] is not None for logprob, token_id, token_text in data["meta_info"][ - "decode_token_logprobs" + "output_token_logprobs" ][prev:]: print(f"{token_text:12s}\t{logprob}\t{token_id}", flush=True) - prev = len(data["meta_info"]["decode_token_logprobs"]) + prev = len(data["meta_info"]["output_token_logprobs"]) else: output = data["text"].strip() print(output[prev:], end="", flush=True) From d4c42c1a3a22b73b8877293eeef426ca1661f62a Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 27 Jul 2024 19:27:22 -0700 Subject: [PATCH 2/5] fixg --- python/sglang/srt/layers/logits_processor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index ad95f0ca736..ec63b4a142d 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -22,12 +22,12 @@ class LogitProcessorOutput: # The normlaized logprobs of prompts. shape: [#seq] normalized_prompt_logprobs: torch.Tensor - # The logprobs of prefill tokens. shape: [#token, vocab_size] + # The logprobs of input tokens. shape: [#token, vocab_size] input_token_logprobs: torch.Tensor - # The logprob and id of the top-k tokens in prefill positions. shape [#seq, #token, k] of Tuple(logprob, token_id) + # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k] of Tuple(logprob, token_id) input_top_logprobs: List - # The logprob and id of the top-k tokens in decode positions. shape [#seq, #token, k] of Tuple(logprob, token_id) + # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k] of Tuple(logprob, token_id) output_top_logprobs: List From 09d32777ed5569c296dbd9a7edbaacce7b81b5c2 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 27 Jul 2024 19:33:15 -0700 Subject: [PATCH 3/5] improve docs --- docs/sampling_params.md | 43 +++++++++++++------------ python/sglang/srt/managers/io_struct.py | 4 +-- 2 files changed, 24 insertions(+), 23 deletions(-) diff --git a/docs/sampling_params.md b/docs/sampling_params.md index 6299c59539d..0ea07c01759 100644 --- a/docs/sampling_params.md +++ b/docs/sampling_params.md @@ -13,7 +13,7 @@ class GenerateReqInput: # The image input. It can be a file name, a url, or base64 encoded string. # See also python/sglang/srt/utils.py:load_image. image_data: Optional[Union[List[str], str]] = None - # The sampling_params. + # The sampling_params. See descriptions below. sampling_params: Union[List[Dict], Dict] = None # The request id. rid: Optional[Union[List[str], str]] = None @@ -23,7 +23,7 @@ class GenerateReqInput: logprob_start_len: Optional[Union[List[int], int]] = None # The number of top logprobs to return. top_logprobs_num: Optional[Union[List[int], int]] = None - # Whether to detokenize tokens in logprobs. + # Whether to detokenize tokens in text in the returned logprobs. return_text_in_logprobs: bool = False # Whether to stream output. 
stream: bool = False @@ -32,27 +32,28 @@ class GenerateReqInput: The `sampling_params` follows this format ```python -class SamplingParams: - def __init__( - self, - max_new_tokens: int = 16, - stop: Optional[Union[str, List[str]]] = None, - temperature: float = 1.0, - top_p: float = 1.0, - top_k: int = -1, - frequency_penalty: float = 0.0, - presence_penalty: float = 0.0, - ignore_eos: bool = False, - skip_special_tokens: bool = True, - dtype: Optional[str] = None, - regex: Optional[str] = None, - ) -> None: +# The maximum number of output tokens +max_new_tokens: int = 16, +# Stop when hitting any of the strings in this list. +stop: Optional[Union[str, List[str]]] = None, +# Sampling temperature +temperature: float = 1.0, +# Top-p sampling +top_p: float = 1.0, +# Top-k sampling +top_k: int = -1, +# Whether to ignore EOS token. +ignore_eos: bool = False, +# Whether to skip the special tokens during detokenization. +skip_special_tokens: bool = True, +# Whether to add spaces between special tokens during detokenization. +spaces_between_special_tokens: bool = True, +# Constrains the output to follow a given regular expression. +regex: Optional[str] = None, +# Do parallel sampling and return `n` outputs. +n: int = 1, ``` -- `max_new_tokens`, `stop`, `temperature`, `top_p`, `top_k` are common sampling parameters. -- `ignore_eos` means ignoring the EOS token and continue decoding, which is helpful for benchmarking purposes. -- `regex` constrains the output to follow a given regular expression. - ## Examples ### Normal diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 40226d1d1dc..95b68d2d6e4 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -20,7 +20,7 @@ class GenerateReqInput: # The image input. It can be a file name, a url, or base64 encoded string. # See also python/sglang/srt/utils.py:load_image. image_data: Optional[Union[List[str], str]] = None - # The sampling_params. + # The sampling_params. See descriptions below. sampling_params: Union[List[Dict], Dict] = None # The request id. rid: Optional[Union[List[str], str]] = None @@ -30,7 +30,7 @@ class GenerateReqInput: logprob_start_len: Optional[Union[List[int], int]] = None # The number of top logprobs to return. top_logprobs_num: Optional[Union[List[int], int]] = None - # Whether to detokenize tokens in logprobs. + # Whether to detokenize tokens in text in the returned logprobs. return_text_in_logprobs: bool = False # Whether to stream output. 
stream: bool = False From 9d3a5efcdffd5eeb289d581fe9cf6fc7bad763e7 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 27 Jul 2024 19:47:36 -0700 Subject: [PATCH 4/5] update --- test/srt/test_httpserver_decode.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index 7e169f3e423..1c3bdcb88f6 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -13,14 +13,15 @@ import requests -def test_decode(url, return_logprob, top_logprobs_num, return_text): +def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False, n=1): response = requests.post( url + "/generate", json={ "text": "The capital of France is", "sampling_params": { - "temperature": 0, + "temperature": 0 if n == 1 else 0.5, "max_new_tokens": 32, + "n": n, }, "stream": False, "return_logprob": return_logprob, @@ -41,8 +42,9 @@ def test_decode(url, return_logprob, top_logprobs_num, return_text): url = f"{args.host}:{args.port}" - test_decode(url, False, 0, False) - test_decode(url, True, 0, False) - test_decode(url, True, 0, True) - test_decode(url, True, 3, False) - test_decode(url, True, 3, True) + test_decode(url) + test_decode(url, n=3) + + for top_logprobs_num in [0, 3]: + for return_text in [True, False]: + test_decode(url, return_logprob=True, top_logprobs_num=top_logprobs_num, return_text=return_text) From d582d5162bd0427a334c4182a3b05cf80d8be177 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 27 Jul 2024 19:49:08 -0700 Subject: [PATCH 5/5] lint --- test/srt/test_httpserver_decode.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/test/srt/test_httpserver_decode.py b/test/srt/test_httpserver_decode.py index 1c3bdcb88f6..57517a15b00 100644 --- a/test/srt/test_httpserver_decode.py +++ b/test/srt/test_httpserver_decode.py @@ -47,4 +47,9 @@ def test_decode(url, return_logprob=False, top_logprobs_num=0, return_text=False for top_logprobs_num in [0, 3]: for return_text in [True, False]: - test_decode(url, return_logprob=True, top_logprobs_num=top_logprobs_num, return_text=return_text) + test_decode( + url, + return_logprob=True, + top_logprobs_num=top_logprobs_num, + return_text=return_text, + )
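
For reference, a minimal client sketch (not part of the patches above) showing the renamed `meta_info` fields end to end. It is modeled on `examples/usage/json_logprobs.py` and the HTTP server tests touched by this series; the server address is an assumption and should point at a running SGLang instance.

```python
# Sketch only: the host/port below are assumptions (the tests read them from
# command-line arguments); the request fields and meta_info keys follow the
# examples and tests updated in this patch series.
import requests

url = "http://localhost:30000"  # assumed address of a running SGLang server

response = requests.post(
    url + "/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"temperature": 0, "max_new_tokens": 32},
        "return_logprob": True,
        "top_logprobs_num": 3,
        "return_text_in_logprobs": True,
    },
)
meta_info = response.json()["meta_info"]

# The old prefill_*/decode_* keys are replaced by input_*/output_*.
assert len(meta_info["input_token_logprobs"]) == meta_info["prompt_tokens"]
assert len(meta_info["output_token_logprobs"]) == meta_info["completion_tokens"] - 1

# With return_text_in_logprobs=True, each entry is (logprob, token_id, token_text).
for logprob, token_id, token_text in meta_info["output_token_logprobs"]:
    print(f"{token_text:12s}\t{logprob}\t{token_id}")
```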