
[Frontend] merge beam search implementations #9296

Merged: 2 commits into vllm-project:main on Oct 14, 2024

Conversation

LunrEclipse (Contributor)

Merged the beam_search implementations of AsyncEngine and MQLLMEngine into EngineClient (Protocol).

Manual testing was conducted to verify that requests still run in parallel and that the output is correct.

Server side:

$ vllm serve meta-llama/Meta-Llama-3-8B

Client side:

Completion

from openai import OpenAI

# Point the OpenAI client at the local vLLM server; the API key is
# arbitrary since the server was started without one.
client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="key123",
)

prompt = "Capital of France is"

try:
    completion = client.completions.create(
        model="meta-llama/Meta-Llama-3-8B",
        prompt=prompt,
        max_tokens=4,
        # use_beam_search and best_of are vLLM extensions, so they are
        # passed via extra_body rather than as native OpenAI parameters.
        extra_body={'use_beam_search': True, 'best_of': 3}
    )
    print(completion.choices[0].text)
except Exception as e:
    print(e)

Chat

from openai import OpenAI

client = OpenAI(
    base_url="http://localhost:8000/v1",
    api_key="key123",
)

prompt = "Capital of France is"

try:
    completion = client.chat.completions.create(
        model="meta-llama/Meta-Llama-3-8B",
        messages=[
            {"role": "system", "content": "You are a helpful AI assistant."},
            {"role": "user", "content": prompt}
        ],
        max_tokens=10,
        # The beam search options (and here temperature as well) are
        # passed through extra_body.
        extra_body={'use_beam_search': True, 'best_of': 3, 'temperature': 0}
    )
    print(completion)
    print(completion.choices[0].message.content)
except Exception as e:
    print(e)

@LunrEclipse LunrEclipse marked this pull request as ready for review October 11, 2024 20:30
@LunrEclipse LunrEclipse changed the title merge beam search implementations [Frontend] merge beam search implementations Oct 11, 2024

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default. Instead, only the fastcheck CI runs, covering a small, essential subset of tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to be added to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add ready label to the PR
  • Enable auto-merge.

🚀

njhill (Member) left a comment:


Thanks @LunrEclipse, this looks good to me. I think along with this we should change EngineClient from a Protocol to an ABC; it doesn't make much sense to have a method implementation in a Protocol.
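For context, a minimal sketch of the distinction (names simplified, not vLLM's actual definitions): a Protocol describes a structural interface for static type checking, whereas an ABC is designed to carry concrete shared methods that subclasses inherit.

from abc import ABC, abstractmethod

class EngineClientABC(ABC):
    """Sketch of the suggestion: an ABC can mix abstract methods with a
    concrete shared implementation, which a Protocol is not meant for."""

    @abstractmethod
    async def generate(self, prompt: str):
        """Each engine client (AsyncEngine, MQLLMEngine) provides this."""
        ...

    async def beam_search(self, prompt: str, beam_width: int = 3):
        # Concrete method inherited by every engine client; in the real
        # PR this would hold the merged beam search loop (elided here).
        return await self.generate(prompt)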

Review threads on vllm/engine/protocol.py (outdated; resolved)
Diff excerpt:

logprob_obj.logprob)

if token_id == tokenizer.eos_token_id and \
        not ignore_eos:
nFunctor (Contributor) commented Oct 12, 2024:


Since I've tried to implement stop logic elsewhere, I'd like to know why we're handling EOS like this instead of putting ignore_eos into the sampling params. Strictly speaking this is not the goal of the PR, so feel free to ignore this comment. Thanks.
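For illustration, a hedged sketch of the alternative being asked about (the function names here are invented; vLLM's SamplingParams does expose an ignore_eos field, but this is not the PR's code):

# Current shape of the check in the excerpt above: EOS handling is gated
# by a separate ignore_eos argument threaded through the beam search loop.
def is_finished(token_id: int, eos_token_id: int, ignore_eos: bool) -> bool:
    return token_id == eos_token_id and not ignore_eos

# Alternative: read the flag off the sampling params, so all stop logic
# lives in one place.
def is_finished_via_params(token_id: int, eos_token_id: int, params) -> bool:
    return token_id == eos_token_id and not params.ignore_eos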

Member:

@nFunctor yes, this PR is just consolidating the existing logic; let's address that in your follow-on PR.

youkaichao (Member):

@njhill thanks for shepherding this PR!
@LunrEclipse please address the review from @njhill .

I'll be afk and will hand it over to @njhill for review.

LunrEclipse (Contributor, Author):

@njhill Thank you for the review! I've gone ahead and pushed changes based on your feedback.

njhill added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 14, 2024
njhill (Member) left a comment:


Thanks @LunrEclipse.

Not related to this PR specifically, but couldn't the beam search impl still be kept behind the EngineClient.generate API? I.e. we just intercept the existing beam_search and associated params in SamplingParams ... so that the outward-facing API remains the same?
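A hedged sketch of that idea (class and helper names are hypothetical, not vLLM's API): generate() inspects the sampling params and dispatches internally, so callers never see a separate beam_search entry point.

class EngineClientSketch:
    """Sketch only: dispatch inside generate() so callers keep one API."""

    async def generate(self, prompt, sampling_params, request_id):
        # Intercept the beam search flag and route to the beam search
        # path; otherwise fall through to normal generation. The two
        # helper methods below are placeholders for the real internals.
        if getattr(sampling_params, "use_beam_search", False):
            async for output in self._run_beam_search(prompt, sampling_params, request_id):
                yield output
        else:
            async for output in self._run_generate(prompt, sampling_params, request_id):
                yield output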

LunrEclipse (Contributor, Author):

@njhill Yeah, it's definitely doable if we add logic to the EngineClient.generate methods to check whether beam_search is enabled and yield the beam search results there, rather than checking inside the engine itself.

simon-mo merged commit 4d31cd4 into vllm-project:main on Oct 14, 2024 (66 of 69 checks passed)
cumulative_logprob=beam.cum_logprob,
token_ids=beam.tokens,
index=i,
logprobs=beam.cum_logprob,
russellb (Member):

@njhill I hadn't got to it yet, but FYI, mypy complains about this line. It's passing a float where it expects a dict, at least according to the typing.

njhill (Member):


Thanks @russellb, yes, this looks wrong! Though it's not really due to this PR, which just moved/consolidated the existing logic.

Probably we should keep a logprobs list in BeamSearchSequence in addition to tokens, and set this.

I think this new external beam search impl still needs a bit more work in general.
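A rough sketch of that suggestion (field names assumed; this is not the merged code): BeamSearchSequence would carry one logprob dict per generated token alongside tokens, and the output's logprobs would then be built from that list instead of reusing cum_logprob.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class BeamSearchSequence:
    tokens: List[int]
    cum_logprob: float = 0.0
    # Suggested addition: one {token_id: logprob} mapping per generated
    # token, so the final output can report real per-token logprobs
    # rather than a float where a dict is expected.
    logprobs: List[Dict[int, float]] = field(default_factory=list)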

russellb (Member):


Yeah, I knew you had just moved the code. I just wanted to highlight it in case it was a super quick fix for you. Thanks for sharing your thoughts! I'll probably get to it at some point as I keep hacking through the type checking. It seems pretty valuable since it's found multiple bugs in my digging so far!

youkaichao (Member):

> Thanks @LunrEclipse.
>
> Not related to this PR specifically, but couldn't the beam search impl still be kept behind the EngineClient.generate API? I.e. we just intercept the existing beam_search and associated params in SamplingParams ... so that the outward-facing API remains the same?

It is possible, following the spirit of #9302; we can just create another BeamSearchSequenceGroup class. @LunrEclipse @njhill
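Purely illustrative (the actual design in #9302 may differ), a skeleton of such a grouping class:

from dataclasses import dataclass, field
from typing import List

@dataclass
class BeamSearchSequence:  # minimal stand-in, see the sketch above
    tokens: List[int]
    cum_logprob: float = 0.0

@dataclass
class BeamSearchSequenceGroup:
    """Hypothetical container for all live beams of one request, keeping
    the engine-facing request/output types uniform."""
    request_id: str
    beams: List[BeamSearchSequence] = field(default_factory=list)
    beam_width: int = 3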

njhill (Member) commented Oct 23, 2024:

@youkaichao yes, I think we should refactor things a bit, and also move the impl into beam_search.py instead of protocol.py, etc.

charlifu pushed a commit to charlifu/vllm that referenced this pull request Oct 23, 2024
Alvant pushed a commit to compressa-ai/vllm that referenced this pull request Oct 26, 2024
garg-amit pushed a commit to garg-amit/vllm that referenced this pull request Oct 28, 2024
sumitd2 pushed a commit to sumitd2/vllm that referenced this pull request Nov 14, 2024
KuntaiDu pushed a commit to KuntaiDu/vllm that referenced this pull request Nov 20, 2024
mfournioux pushed a commit to mfournioux/vllm that referenced this pull request Nov 20, 2024
Labels: ready (ONLY add when PR is ready to merge/full CI is needed)
6 participants