[Model]: Add transformers backend support (vllm-project#11330)
# Adds support for `transformers` as a backend

Following huggingface/transformers#35235, a number of models should already be
supported, and we are ramping up support for more models.

Thanks @Isotr0py for the TP support, and @hmellor for his help as well!
This includes:
- `trust_remote_code=True` support: any model on the Hub that implements attention the correct way can be supported natively!
- tensor parallel support

---------

Signed-off-by: Harry Mellor <[email protected]>
Signed-off-by: Isotr0py <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Harry Mellor <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
Co-authored-by: Michael Goin <[email protected]>
Co-authored-by: Isotr0py <[email protected]>
7 people authored Feb 3, 2025
1 parent 1298a40 commit a1a2aaa
Showing 11 changed files with 528 additions and 9 deletions.
2 changes: 2 additions & 0 deletions .buildkite/test-pipeline.yaml
@@ -349,6 +349,7 @@ steps:
- vllm/
- tests/models
commands:
- pytest -v -s models/test_transformers.py
- pytest -v -s models/test_registry.py
- pytest -v -s models/test_initialization.py

@@ -485,6 +486,7 @@ steps:
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
- TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
# Avoid importing model tests that cause CUDA reinitialization error
- pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
- pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
76 changes: 76 additions & 0 deletions docs/source/models/supported_models.md
@@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.

### Transformers fallback

After the merge of <gh-pr:11330>, `vllm` can fall back to model implementations that are available in `transformers`. This does not work for all models yet, but most decoder language models are supported, and vision-language model support is planned!

To check if the backend is `transformers`, you can simply do this:

```python
from vllm import LLM
llm = LLM(model=..., task="generate") # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```

If the printed class is `TransformersModel`, it means the model is running on the `transformers` backend!

#### Supported features

##### LoRA and quantization

Neither is supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!

Usually, `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need a bit of work either to use this API (for now this would result in some weights not being marked as loaded) or to replace the relevant modules accordingly.

A hint of what this could look like:

```python
class TransformersModel(nn.Module, SupportsLoRA):

    def __init__(self, *, vllm_config, prefix=""):
        ...
        self.model.load_adapter(
            vllm_config.load_config.model_loader_extra_config[
                "qlora_adapter_name_or_path"])
```

The blocker is that you need to specify the supported LoRA layers, whereas ideally we would want to load whatever is inside the checkpoint!
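
As a purely illustrative sketch of the kind of declaration this currently requires (the attribute name follows vLLM's existing `SupportsLoRA` models; nothing here is part of this PR):

```python
from torch import nn

from vllm.model_executor.models.interfaces import SupportsLoRA


class TransformersModel(nn.Module, SupportsLoRA):
    # Illustrative only: native vLLM models enumerate the LoRA-capable
    # modules up front, which is exactly what we would like to avoid
    # hard-coding for arbitrary Transformers checkpoints.
    supported_lora_modules = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
```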

##### Remote code

This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True`, and that implements attention correctly, can be used in production!

```python
from vllm import LLM
llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
llm.apply_model(lambda model: print(model.__class__))
```

A model just needs the following two things:

```python
from torch import nn
from transformers import PreTrainedModel
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


class MyAttention(nn.Module):

    def forward(self, hidden_states, **kwargs):  # <- kwargs are required
        ...
        attention_interface = ALL_ATTENTION_FUNCTIONS[
            self.config._attn_implementation]
        attn_output, attn_weights = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            **kwargs,
        )
        ...


class MyModel(PreTrainedModel):
    _supports_attention_backend = True
```

Here is what happens in the background:

1. The config is loaded.
2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model `_supports_attention_backend`.
3. The `TransformersModel` backend is used. See `vllm/model_executor/models/transformers.py`, which leverages `self.config._attn_implementation = "vllm"`, hence the need to use `ALL_ATTENTION_FUNCTIONS` (a condensed sketch of this resolution follows the list).
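
For the curious, here is a condensed sketch of that resolution step, mirroring (and simplifying) the `resolve_transformers_fallback` helper added to `vllm/model_executor/model_loader/utils.py` in this PR; it only shows the `auto` path and omits the warning and custom-module handling:

```python
from vllm.model_executor.model_loader.utils import (
    is_transformers_impl_compatible)


def resolve_fallback_sketch(architectures: list[str]) -> list[str]:
    """Swap architectures that have no vLLM implementation for the
    Transformers backend, provided the model is compatible with it."""
    for i, arch in enumerate(architectures):
        if arch == "TransformersModel":
            continue
        if not is_transformers_impl_compatible(arch):
            raise ValueError(
                f"{arch} has no vLLM implementation and the Transformers "
                "implementation is not compatible with vLLM.")
        architectures[i] = "TransformersModel"
    return architectures
```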

That's it!
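
If you would rather opt in to the fallback explicitly than rely on auto-detection, this PR also adds a `model_impl` engine argument (exposed on the CLI as `--model-impl`). A minimal, untested sketch:

```python
from vllm import LLM

# "auto" (the default) only falls back when no vLLM implementation exists;
# "transformers" forces the Transformers backend for this model.
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct",
          task="generate",
          model_impl="transformers")
llm.apply_model(lambda model: print(model.__class__))  # TransformersModel
```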

### ModelScope

To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
2 changes: 1 addition & 1 deletion requirements-common.txt
@@ -5,7 +5,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.48.2 # Required for Bamba.
transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
5 changes: 5 additions & 0 deletions tests/models/registry.py
@@ -281,12 +281,17 @@ def check_available_online(
speculative_model="ibm-fms/llama-160m-accelerator"), # noqa: E501
}

_FALLBACK_MODEL = {
"TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True), # noqa: E501
}

_EXAMPLE_MODELS = {
**_TEXT_GENERATION_EXAMPLE_MODELS,
**_EMBEDDING_EXAMPLE_MODELS,
**_CROSS_ENCODER_EXAMPLE_MODELS,
**_MULTIMODAL_EXAMPLE_MODELS,
**_SPECULATIVE_DECODING_EXAMPLE_MODELS,
**_FALLBACK_MODEL,
}


4 changes: 3 additions & 1 deletion tests/models/test_oot_registration.py
@@ -15,7 +15,9 @@ def test_plugin(dummy_opt_path):
os.environ["VLLM_PLUGINS"] = ""
with pytest.raises(Exception) as excinfo:
LLM(model=dummy_opt_path, load_format="dummy")
assert "are not supported for now" in str(excinfo.value)
error_msg = "has no vLLM implementation and " \
"the Transformers implementation is not compatible with vLLM."
assert (error_msg in str(excinfo.value))


@fork_new_process_for_each_test
75 changes: 75 additions & 0 deletions tests/models/test_transformers.py
@@ -0,0 +1,75 @@
"""Test the functionality of the Transformers backend.
Run `pytest tests/models/test_transformers.py`.
"""
from contextlib import nullcontext
from typing import Type

import pytest

from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test
from .utils import check_logprobs_close


def check_implementation(
hf_runner: Type[HfRunner],
vllm_runner: Type[VllmRunner],
example_prompts: list[str],
model: str,
**kwargs,
):
max_tokens = 32
num_logprobs = 5

with vllm_runner(model, **kwargs) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)

check_logprobs_close(
outputs_0_lst=hf_outputs,
outputs_1_lst=vllm_outputs,
name_0="hf",
name_1="vllm",
)


@pytest.mark.parametrize(
"model,model_impl",
[
("meta-llama/Llama-3.2-1B-Instruct", "transformers"),
("openai-community/gpt2", "transformers"),
("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE
("meta-llama/Llama-3.2-1B-Instruct", "auto"),
]) # trust_remote_code=True by default
def test_models(hf_runner, vllm_runner, example_prompts, model,
model_impl) -> None:

maybe_raises = nullcontext()
if model == "openai-community/gpt2" and model_impl == "transformers":
# Model is not backend compatible
maybe_raises = pytest.raises(
ValueError,
match="The Transformers implementation.*not compatible with vLLM")

with maybe_raises:
check_implementation(hf_runner,
vllm_runner,
example_prompts,
model,
model_impl=model_impl)


@multi_gpu_test(num_gpus=2)
def test_distributed(
hf_runner,
vllm_runner,
example_prompts,
):
kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2}
check_implementation(hf_runner, vllm_runner, example_prompts,
"meta-llama/Llama-3.2-1B-Instruct", **kwargs)
14 changes: 14 additions & 0 deletions vllm/config.py
@@ -83,6 +83,12 @@ def compute_hash(self) -> str:
...


class ModelImpl(str, enum.Enum):
AUTO = "auto"
VLLM = "vllm"
TRANSFORMERS = "transformers"


class ModelConfig:
"""Configuration for the model.
@@ -167,6 +173,12 @@ class ModelConfig:
`logits_processors` extra completion argument. Defaults to None,
which allows no processors.
generation_config: Configuration parameter file for generation.
model_impl: Which implementation of the model to use:
"auto" will try to use the vLLM implementation if it exists and
fall back to the Transformers implementation if no vLLM
implementation is available.
"vllm" will use the vLLM model implementation.
"transformers" will use the Transformers model implementation.
override_generation_config: Override the generation config with the
given config.
"""
@@ -230,6 +242,7 @@ def __init__(
generation_config: Optional[str] = None,
enable_sleep_mode: bool = False,
override_generation_config: Optional[Dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -241,6 +254,7 @@
self.code_revision = code_revision
self.rope_scaling = rope_scaling
self.rope_theta = rope_theta
self.model_impl = model_impl

if hf_overrides is None:
hf_overrides = {}
22 changes: 18 additions & 4 deletions vllm/engine/arg_utils.py
@@ -13,10 +13,10 @@
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides,
KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PoolerConfig, PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
ModelConfig, ModelImpl, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@@ -199,6 +199,7 @@ class EngineArgs:
generation_config: Optional[str] = None
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
model_impl: str = "auto"

calculate_kv_scales: Optional[bool] = None

@@ -378,6 +379,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
'qualified names that can be passed with the `logits_processors` '
'extra completion argument. Defaults to None, which allows no '
'processors.')
parser.add_argument(
'--model-impl',
type=str,
default=EngineArgs.model_impl,
choices=[f.value for f in ModelImpl],
help='Which implementation of the model to use.\n\n'
'* "auto" will try to use the vLLM implementation if it exists '
'and fall back to the Transformers implementation if no vLLM '
'implementation is available.\n'
'* "vllm" will use the vLLM model implementation.\n'
'* "transformers" will use the Transformers model '
'implementation.\n')
# Parallel arguments
parser.add_argument(
'--distributed-executor-backend',
@@ -1017,6 +1030,7 @@ def create_model_config(self) -> ModelConfig:
generation_config=self.generation_config,
override_generation_config=self.override_generation_config,
enable_sleep_mode=self.enable_sleep_mode,
model_impl=self.model_impl,
)

def create_load_config(self) -> LoadConfig:
61 changes: 59 additions & 2 deletions vllm/model_executor/model_loader/utils.py
@@ -2,17 +2,22 @@
"""Utilities for selecting and loading models."""
import contextlib
from dataclasses import dataclass, field
from typing import Dict, List, Tuple, Type
from typing import Dict, List, Optional, Tuple, Type

import torch
import transformers
from torch import nn
from transformers.dynamic_module_utils import get_class_from_dynamic_module

from vllm.config import ModelConfig
from vllm.config import ModelConfig, ModelImpl
from vllm.logger import init_logger
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.adapters import (as_classification_model,
as_embedding_model,
as_reward_model)

logger = init_logger(__name__)


@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
@@ -23,6 +28,50 @@ def set_default_torch_dtype(dtype: torch.dtype):
torch.set_default_dtype(old_dtype)


def is_transformers_impl_compatible(
arch: str,
module: Optional[transformers.PreTrainedModel] = None) -> bool:
mod = module or getattr(transformers, arch, None)
if mod is None:
return False
if hasattr(mod, "supports_backend"):
return mod.is_backend_compatible()
else:
return mod._supports_flex_attn


def resolve_transformers_fallback(model_config: ModelConfig,
architectures: list[str]):
for i, arch in enumerate(architectures):
if arch == "TransformersModel":
continue
custom_module = None
auto_map = getattr(model_config.hf_config, "auto_map", None)
if auto_map is not None and "AutoModel" in auto_map:
custom_module = get_class_from_dynamic_module(
model_config.hf_config.auto_map["AutoModel"],
model_config.model)
# TODO(Isotr0py): Further clean up these raises.
# perhaps handled them in _ModelRegistry._raise_for_unsupported?
if model_config.model_impl == ModelImpl.TRANSFORMERS:
if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError(
f"The Transformers implementation of {arch} is not "
"compatible with vLLM.")
architectures[i] = "TransformersModel"
if model_config.model_impl == ModelImpl.AUTO:
if not is_transformers_impl_compatible(arch, custom_module):
raise ValueError(
f"{arch} has no vLLM implementation and the Transformers "
"implementation is not compatible with vLLM.")
logger.warning(
"%s has no vLLM implementation, falling back to Transformers "
"implementation. Some features may not be supported and "
"performance may not be optimal.", arch)
architectures[i] = "TransformersModel"
return architectures


def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
@@ -38,6 +87,14 @@ def get_model_architecture(
and "MixtralForCausalLM" in architectures):
architectures = ["QuantMixtralForCausalLM"]

vllm_supported_archs = ModelRegistry.get_supported_archs()
is_vllm_supported = any(arch in vllm_supported_archs
for arch in architectures)
if (not is_vllm_supported
or model_config.model_impl == ModelImpl.TRANSFORMERS):
architectures = resolve_transformers_fallback(model_config,
architectures)

model_cls, arch = ModelRegistry.resolve_model_cls(architectures)
if model_config.task == "embed":
model_cls = as_embedding_model(model_cls)