
Add support for logprobs in OpenAI chat API #852

Merged

Conversation

yichuan520030910320 (Collaborator)

Thank you for your contribution; we really appreciate it. The following instructions will help improve your pull request and make it easier to receive feedback. If there are any items you don't understand, don't worry. Just submit the pull request and ask the maintainers for help.

Motivation

Fix #839

Modification

Add support for logprobs in the OpenAI chat API and match the output format of the real OpenAI API.
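
For illustration (not part of the change itself), here is a minimal client-side sketch of what this enables, assuming an sglang OpenAI-compatible server on localhost:30000 and the openai Python client; the model name is a placeholder.

```python
# Minimal sketch (illustrative, not from the PR): request per-token logprobs
# from an OpenAI-compatible chat endpoint. URL, API key, and model name are
# placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Qwen/Qwen2-7B-Instruct",  # placeholder model name
    messages=[{"role": "user", "content": "What is the capital of France?"}],
    max_tokens=16,
    logprobs=True,    # return logprobs for the generated tokens
    top_logprobs=5,   # also return the 5 most likely alternatives per token
)

# Mirrors the real OpenAI output shape: choices[0].logprobs.content is a list
# of per-token entries, each with .token, .logprob, and .top_logprobs.
for entry in response.choices[0].logprobs.content:
    print(entry.token, entry.logprob)
```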

Checklist

  1. Ensure pre-commit (pre-commit run --all-files) or other linting tools are used to fix potential lint issues.
  2. Confirm that modifications are covered by complete unit tests. If not, please add more unit tests for correctness.
  3. Modify documentation as needed, such as docstrings or example tutorials.

@yichuan520030910320 (Collaborator, Author)

cc @merrymercy @hnyls2002 @Ying1123 @zhyncs for review. I think it is ready to merge.

@isukharev

I've just tested the code from this branch and encountered an error:

Exception in ControllerSingle:
Traceback (most recent call last):
  File "/home/user/sources/sglang/python/sglang/srt/managers/controller_single.py", line 166, in start_controller_process
    controller.loop_for_forward()
  File "/home/user/sources/sglang/python/sglang/srt/managers/controller_single.py", line 103, in loop_for_forward
    out_pyobjs = self.tp_server.exposed_step(recv_reqs)
  File "/home/user/sources/sglang/python/sglang/srt/managers/tp_worker.py", line 209, in exposed_step
    self.forward_step()
  File "/home/user/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/sources/sglang/python/sglang/srt/managers/tp_worker.py", line 225, in forward_step
    self.forward_prefill_batch(new_batch)
  File "/home/user/sources/sglang/python/sglang/srt/managers/tp_worker.py", line 532, in forward_prefill_batch
    output = self.model_runner.forward(batch, ForwardMode.EXTEND)
  File "/home/user/sources/sglang/python/sglang/srt/model_executor/model_runner.py", line 386, in forward
    return self.forward_extend(batch)
  File "/home/user/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/sources/sglang/python/sglang/srt/model_executor/model_runner.py", line 354, in forward_extend
    return self.model.forward(
  File "/home/user/sglang/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/user/sources/sglang/python/sglang/srt/models/qwen2.py", line 288, in forward
    return self.logits_processor(
  File "/home/user/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/user/sglang/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/user/sources/sglang/python/sglang/srt/layers/logits_processor.py", line 209, in forward
    all_logits = all_logits[:, : self.config.vocab_size].float()
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 8.71 GiB. GPU

@hnyls2002 (Collaborator)

@isukharev Computing logprobs takes more memory; try reducing --mem-fraction-static. We are figuring out how to trade off memory usage against a larger batch size.
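
For a rough sense of scale (a back-of-the-envelope estimate, not a figure from this thread): the logprob path materializes a float32 logits tensor over every prefill token, so the extra memory grows with vocabulary size. Assuming Qwen2's 152,064-token vocabulary (a qwen2 model appears in the traceback above):

```python
# Back-of-the-envelope estimate (illustrative, not from the PR): casting the
# full prefill logits to float32 costs roughly
# (#prefill tokens) x (vocab size) x 4 bytes.
vocab_size = 152_064      # Qwen2 vocabulary size (assumed from the qwen2 model in the traceback)
prefill_tokens = 15_000   # hypothetical size of the extend batch

gib = prefill_tokens * vocab_size * 4 / 2**30
print(f"~{gib:.2f} GiB")  # ~8.5 GiB, on the order of the 8.71 GiB allocation above
```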

@isukharev

@hnyls2002 Thanks! After reducing --mem-fraction-static to 0.7, the previous error is gone, but generation speed dropped by a factor of 15 relative to the current master version, and the cache hit rate is now consistently 0.00% instead of the previous ~90%.

Ying1123 merged commit ca600e8 into sgl-project:main on Aug 1, 2024 (2 checks passed).
@Ying1123 (Member) commented on Aug 1, 2024

@isukharev When you want to get the logprob of the prompt, the prefix radix cache will be turned off. This is because we only cache the KV cache, not the logits.

Do you need logprobs for the prompt, or only for the generation? If you only need them for generation, we can implement an additional flag just for your use case; then you should see a similar cache hit rate.

@isukharev

@Ying1123 We only need this for generation; we are using an LLM as a CrossEncoder, as shown here:
https://github.com/openai/openai-cookbook/blob/main/examples/Search_reranking_with_cross-encoders.ipynb
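
For reference, here is a minimal sketch of that reranking pattern using only generation-token logprobs, in the spirit of the linked cookbook: ask the model for a one-token Yes/No relevance judgment and read the logprob of "Yes". The endpoint, model name, and prompt wording are placeholders, not taken from this PR.

```python
# Sketch of LLM-as-CrossEncoder reranking with generation logprobs only.
# Endpoint, model name, and prompt are placeholders.
import math

from openai import OpenAI

client = OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

def relevance_score(query: str, document: str) -> float:
    """Return an estimate of P('Yes') for 'is this document relevant?'."""
    resp = client.chat.completions.create(
        model="Qwen/Qwen2-7B-Instruct",  # placeholder model name
        messages=[{
            "role": "user",
            "content": f"Query: {query}\nDocument: {document}\n"
                       "Is the document relevant to the query? Answer Yes or No.",
        }],
        max_tokens=1,      # only the first generated token matters
        logprobs=True,
        top_logprobs=5,
    )
    # Scan the top alternatives of the single generated token for "Yes".
    for cand in resp.choices[0].logprobs.content[0].top_logprobs:
        if cand.token.strip().lower() == "yes":
            return math.exp(cand.logprob)
    return 0.0

docs = ["Paris is the capital of France.", "Bananas are yellow."]
ranked = sorted(docs, key=lambda d: relevance_score("capital of France", d), reverse=True)
print(ranked)
```

Since only the first generated token's logprob is needed here, a generation-only logprob flag would let the shared prompt prefix keep hitting the radix cache.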

@Ying1123 (Member) commented on Aug 1, 2024

@yichuan520030910320 @hnyls2002 It should be easy to do. Can you implement this feature?

Successfully merging this pull request may close: [Bug] Chat completions logprobs support (#839)