diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py
index ac69ab875b9..1d18d305fcb 100644
--- a/python/sglang/test/runners.py
+++ b/python/sglang/test/runners.py
@@ -50,6 +50,12 @@ def get_dtype_str(torch_dtype):
         raise NotImplementedError()
 
 
+def get_top_logprobs(logits, k):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    logprobs, top_indices = torch.topk(logprobs, k=k, dim=-1)
+    return logprobs
+
+
 @dataclass
 class ModelOutput:
     output_strs: List[str] = None
@@ -108,7 +114,8 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
             if prompts is not None:
                 if self.is_generation:
                     output_strs = []
-                    prefill_logprobs = []
+                    top_input_logprobs = []
+                    top_output_logprobs = []
                     for p in prompts:
                         if isinstance(p, str):
                             input_ids = self.tokenizer.encode(
@@ -117,32 +124,43 @@ def start_model_process(self, in_queue, out_queue, model_path, torch_dtype):
                         else:
                             input_ids = torch.tensor([p], device="cuda")
 
-                        output_ids = self.model.generate(
-                            input_ids, do_sample=False, max_new_tokens=max_new_tokens
+                        outputs = self.model.generate(
+                            input_ids,
+                            do_sample=False,
+                            temperature=None,
+                            top_p=None,
+                            max_new_tokens=max_new_tokens,
+                            return_dict_in_generate=True,
+                            output_scores=True,
                         )
                         output_strs.append(
-                            self.tokenizer.decode(output_ids[0][len(input_ids[0]) :])
+                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
                         )
+                        # outputs.scores: (num_token, 1, vocab_size)
+                        top_output_logprobs.append(
+                            [
+                                get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                                for logits in outputs.scores
+                            ]
+                        )
+                        del outputs
 
-                        logits = self.model.forward(input_ids).logits[0]
-                        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
-                        logprobs, top_indices = torch.topk(
-                            logprobs, k=NUM_TOP_LOGPROBS, dim=-1
+                        input_logits = self.model.forward(input_ids).logits[0]
+                        top_input_logprobs.append(
+                            get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
                         )
-                        # print("index", top_indices)
-                        prefill_logprobs.append(logprobs.tolist())
-                        del logits
-                        del logprobs
+                        del input_logits
 
                     out_queue.put(
                         ModelOutput(
-                            output_strs=output_strs, top_input_logprobs=prefill_logprobs
+                            output_strs=output_strs,
+                            top_input_logprobs=top_input_logprobs,
+                            top_output_logprobs=top_output_logprobs,
                         )
                     )
                 else:
                     logits = self.model.encode(prompts).tolist()
-
                     out_queue.put(ModelOutput(embed_logits=logits))
 
     def forward(
@@ -194,6 +212,7 @@ def forward(
             # the return value contains logprobs from prefill
             output_strs = []
             top_input_logprobs = []
+            top_output_logprobs = []
            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
             for prompt in prompts:
                 response = self.runtime.generate(
@@ -219,9 +238,17 @@ def forward(
                         ]
                     ]
                 )
+                top_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                        for x in response["meta_info"]["output_top_logprobs"]
+                    ]
+                )
 
             return ModelOutput(
-                output_strs=output_strs, top_input_logprobs=top_input_logprobs
+                output_strs=output_strs,
+                top_input_logprobs=top_input_logprobs,
+                top_output_logprobs=top_output_logprobs,
             )
         else:
             response = self.runtime.encode(prompts)
diff --git a/test/srt/models/test_generation_models.py b/test/srt/models/test_generation_models.py
index 08288c510c9..46854b3e869 100644
--- a/test/srt/models/test_generation_models.py
+++ b/test/srt/models/test_generation_models.py
@@ -21,9 +21,9 @@ from sglang.test.runners import DEFAULT_PROMPTS, HFRunner, SRTRunner
 
 MODELS = [
-    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 1),
-    ("google/gemma-2-2b", 1, 3, 3e-2, 1),
-    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 1),
+    ("meta-llama/Meta-Llama-3.1-8B-Instruct", 1, 1.1, 3e-2, 4e-2, 1),
+    ("google/gemma-2-2b", 1, 3, 3e-2, 5e-2, 1),
+    ("Alibaba-NLP/gte-Qwen2-1.5B-instruct", 1, None, 6e-2, 4e-2, 1),
 ]
 TORCH_DTYPES = [torch.float16]
@@ -70,6 +70,7 @@ def assert_close_prefill_logits_and_output_strs(
         torch_dtype,
         max_new_tokens,
         prefill_tolerance,
+        output_tolerance,
         rouge_threshold,
         long_context_tolerance,
     ) -> None:
@@ -89,15 +90,37 @@ def assert_close_prefill_logits_and_output_strs(
             srt_outputs = srt_runner.forward(prompts, max_new_tokens=max_new_tokens)
 
         for i in range(len(prompts)):
+            # input logprobs comparison
             hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
             srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
-
-            print("max_diff", torch.max(abs(hf_logprobs - srt_logprobs)))
-            if hf_logprobs.shape[0] <= 100:
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
                 assert torch.all(
                     abs(hf_logprobs - srt_logprobs) < prefill_tolerance
                 ), f"prefill logprobs are not all close with model_path={model_path} prompts={prompts} prefill_tolerance={prefill_tolerance}"
 
+            # output logprobs comparison
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+            # print(
+            #     "output logprobs diff",
+            #     [
+            #         float(torch.max(abs(hf_logprobs[j] - srt_logprobs[j])))
+            #         for j in range(max_new_tokens)
+            #     ],
+            # )
+            print(
+                "output logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(
+                    abs(hf_logprobs - srt_logprobs) < output_tolerance
+                ), f"output logprobs are not all close with model_path={model_path} prompts={prompts}... output_tolerance={output_tolerance}"
+
+        # output strings comparison
         print(f"hf_outputs.output_strs={hf_outputs.output_strs}")
         print(f"srt_outputs.output_strs={srt_outputs.output_strs}")
         rouge_l_scores = calculate_rouge_l(
@@ -114,6 +137,7 @@ def test_prefill_logits_and_output_strs(self):
             tp_size,
             long_context_tolerance,
             prefill_tolerance,
+            output_tolerance,
             rouge_threshold,
         ) in MODELS:
             for torch_dtype in TORCH_DTYPES:
@@ -125,6 +149,7 @@ def test_prefill_logits_and_output_strs(self):
                     torch_dtype,
                     max_new_tokens,
                     prefill_tolerance=prefill_tolerance,
+                    output_tolerance=output_tolerance,
                     rouge_threshold=rouge_threshold,
                     long_context_tolerance=long_context_tolerance,
                 )
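Below is a minimal, self-contained sketch (not part of the diff) of the HuggingFace-side pattern this change relies on: greedy generate() with return_dict_in_generate=True and output_scores=True, then the new get_top_logprobs helper applied to each per-token score tensor. The model name ("gpt2"), the prompt, and k=5 are placeholder choices for illustration only; the real runner uses the model under test and NUM_TOP_LOGPROBS.

    import torch
    import torch.nn.functional as F
    from transformers import AutoModelForCausalLM, AutoTokenizer

    def get_top_logprobs(logits, k):
        # Same helper the diff adds: log-softmax over the vocab, keep the top-k values.
        logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
        logprobs, _top_indices = torch.topk(logprobs, k=k, dim=-1)
        return logprobs

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer.encode("The capital of France is", return_tensors="pt")
    outputs = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=4,
        return_dict_in_generate=True,
        output_scores=True,  # one (batch, vocab_size) score tensor per generated token
    )
    # outputs.scores has one entry per new token; [0] selects the single batch element.
    top_output_logprobs = [
        get_top_logprobs(scores[0], 5).tolist() for scores in outputs.scores
    ]
    print(top_output_logprobs)  # k log-probabilities for each of the 4 generated tokens

These per-token values are what the test then compares, elementwise, against the "output_top_logprobs" returned by the SRT runtime, under the new output_tolerance threshold.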