-
-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Model]: Add transformers
backend support
#11330
Conversation
Co-authored-by: Isotr0py <[email protected]>
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
🚀 |
Hello @ArthurZucker! This is very exciting! I know this PR is still a draft, but could you provide some context on the scope of this effort? Is it to support any model on |
Yep, overall this should support We are refactor our models to make sure it's propagated to as many models as possible! |
Might not have time to finish this week, will make it ready for next week 🎄 |
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
…orted Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Harry Mellor <[email protected]>
Benchmarks on A100 using the following command: python benchmarks/benchmark_throughput.py --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --dataset ShareGPT_V3_unfiltered_cleaned_split.json Results:
|
Signed-off-by: Harry Mellor <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice work, LGTM! I've tested locally with Llama by running gsm8k evals, where I see good accuracy and slightly less throughput as we would expect.
vLLM impl:
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|████████████| 1319/1319 [00:45<00:00, 29.09it/s, est. speed input: 25375.03 toks/s, output: 2839.50 toks/s]
Running generate_until requests: 100%|████████████| 1319/1319 [00:45<00:00, 28.98it/s]
2025-01-31:20:32:00,673 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7801|± |0.0114|
| | |strict-match | 5|exact_match|↑ |0.7566|± |0.0118|
Transformers impl:
lm_eval --model vllm --model_args pretrained=meta-llama/Llama-3.1-8B-Instruct,model_impl=transformers --tasks gsm8k --num_fewshot 5 --batch_size auto
...
Processed prompts: 100%|█████████████| 1319/1319 [00:54<00:00, 24.36it/s, est. speed input: 21247.29 toks/s, output: 2378.45 toks/s]
Running generate_until requests: 100%|███████████████| 1319/1319 [00:54<00:00, 24.25it/s]
2025-01-31:20:30:10,932 INFO [evaluation_tracker.py:269] Output path not provided, skipping saving results aggregated
vllm (pretrained=meta-llama/Llama-3.1-8B-Instruct,model_impl=transformers), gen_kwargs: (None), limit: None, num_fewshot: 5, batch_size: auto
|Tasks|Version| Filter |n-shot| Metric | |Value | |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k| 3|flexible-extract| 5|exact_match|↑ |0.7763|± |0.0115|
| | |strict-match | 5|exact_match|↑ |0.7521|± |0.0119|
One thing I would like to note is that it seems V1 is not supported yet. Running VLLM_USE_V1=1
with Llama results in an error about the input_embeds forward pass arg. This is only used for multimodal models currently so we could get around this for this case by ignoring it.
ERROR 01-31 20:27:32 core.py:208] File "/home/mgoin/code/vllm/vllm/v1/worker/gpu_model_runner.py", line 870, in _dummy_run
ERROR 01-31 20:27:32 core.py:208] hidden_states = model(
ERROR 01-31 20:27:32 core.py:208] ^^^^^^
ERROR 01-31 20:27:32 core.py:208] File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
ERROR 01-31 20:27:32 core.py:208] return self._call_impl(*args, **kwargs)
ERROR 01-31 20:27:32 core.py:208] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-31 20:27:32 core.py:208] File "/home/mgoin/venvs/vllm/lib/python3.12/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
ERROR 01-31 20:27:32 core.py:208] return forward_call(*args, **kwargs)
ERROR 01-31 20:27:32 core.py:208] ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ERROR 01-31 20:27:32 core.py:208] TypeError: TransformersModel.forward() got an unexpected keyword argument 'inputs_embeds'
#12599 has been merged, can you merge from main to fix the merge conflicts? |
This pull request has merge conflicts that must be resolved before it can be |
Signed-off-by: Isotr0py <[email protected]>
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's get this merged first! We can add BNB and LoRA support in other following PR.
Signed-off-by: Isotr0py <[email protected]>
Please fix the failing tests |
Signed-off-by: Isotr0py <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Please also add the distributed transformers test to the distributed tests CI |
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
# Adds support for `transformers` as a backend Following huggingface/transformers#35235, a bunch of models should already be supported, we are ramping up support for more models. Thanks @Isotr0py for the TP support, and @hmellor for his help as well! This includes: - `trust_remote_code=True` support: any model on the hub, if it implements attention the correct way can be natively supported!! - tensor parallel support --------- Signed-off-by: Harry Mellor <[email protected]> Signed-off-by: Isotr0py <[email protected]> Co-authored-by: Isotr0py <[email protected]> Co-authored-by: Harry Mellor <[email protected]> Co-authored-by: Isotr0py <[email protected]> Co-authored-by: Cyrus Leung <[email protected]> Co-authored-by: Michael Goin <[email protected]> Co-authored-by: Isotr0py <[email protected]> Signed-off-by: Felix Marty <[email protected]>
Adds support for
transformers
as a backendFollowing huggingface/transformers#35235, a bunch of models should already be supported, we are ramping up support for more models.
Thanks @Isotr0py for the TP support, and @hmellor for his help as well!
This includes:
trust_remote_code=True
support: any model on the hub, if it implements attention the correct way can be natively supported!!