Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Misc] Move registry to its own file #9064

Merged
merged 2 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/models/adding_model.rst
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a
5. Register your model
----------------------

Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/__init__.py>`_.
Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/registry.py>`_.

6. Out-of-Tree Model Integration
--------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@
import pytest
import torch.cuda

from vllm.model_executor.models import _MODELS, ModelRegistry
from vllm.model_executor.models import ModelRegistry
from vllm.platforms import current_platform

from ..utils import fork_new_process_for_each_test


@pytest.mark.parametrize("model_arch", _MODELS)
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Ensure all model classes can be imported successfully
ModelRegistry.resolve_model_cls(model_arch)
Expand Down
3 changes: 1 addition & 2 deletions vllm/lora/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@
from vllm.lora.punica import PunicaWrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models.interfaces import (SupportsLoRA,
supports_multimodal)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer
from vllm.utils import is_pin_memory_available
Expand Down
5 changes: 2 additions & 3 deletions vllm/model_executor/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,8 @@
get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator,
initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator,
safetensors_weights_iterator)
from vllm.model_executor.models.interfaces import (has_inner_state,
supports_lora,
supports_multimodal)
from vllm.model_executor.models import (has_inner_state, supports_lora,
supports_multimodal)
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.utils import is_pin_memory_available
Expand Down
333 changes: 12 additions & 321 deletions vllm/model_executor/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,325 +1,16 @@
import importlib
import string
import subprocess
import sys
import uuid
from functools import lru_cache, partial
from typing import Callable, Dict, List, Optional, Tuple, Type, Union

import torch.nn as nn

from vllm.logger import init_logger
from vllm.utils import is_hip

from .interfaces import supports_multimodal, supports_pp

logger = init_logger(__name__)

_GENERATION_MODELS = {
"AquilaModel": ("llama", "LlamaForCausalLM"),
"AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2
"ArcticForCausalLM": ("arctic", "ArcticForCausalLM"),
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
"CohereForCausalLM": ("commandr", "CohereForCausalLM"),
"DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"),
"DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"),
"DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"),
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"),
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
"GemmaForCausalLM": ("gemma", "GemmaForCausalLM"),
"Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"),
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
"GraniteForCausalLM": ("granite", "GraniteForCausalLM"),
"GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"),
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
"JambaForCausalLM": ("jamba", "JambaForCausalLM"),
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
# For decapoda-research/llama-*
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
"MistralForCausalLM": ("llama", "LlamaForCausalLM"),
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
"QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"),
# transformers's mpt class has lower case
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
"MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"),
"MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"),
"NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"),
"OlmoForCausalLM": ("olmo", "OlmoForCausalLM"),
"OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"),
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
"PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),
"Qwen2VLForConditionalGeneration":
("qwen2_vl", "Qwen2VLForConditionalGeneration"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"),
"Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"),
"SolarForCausalLM": ("solar", "SolarForCausalLM"),
"XverseForCausalLM": ("xverse", "XverseForCausalLM"),
# NOTE: The below models are for speculative decoding only
"MedusaModel": ("medusa", "Medusa"),
"EAGLEModel": ("eagle", "EAGLE"),
"MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
}

_EMBEDDING_MODELS = {
"MistralModel": ("llama_embedding", "LlamaEmbeddingModel"),
"Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),
}

_MULTIMODAL_MODELS = {
"Blip2ForConditionalGeneration":
("blip2", "Blip2ForConditionalGeneration"),
"ChameleonForConditionalGeneration":
("chameleon", "ChameleonForConditionalGeneration"),
"FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"),
"InternVLChatModel": ("internvl", "InternVLChatModel"),
"LlavaForConditionalGeneration": ("llava",
"LlavaForConditionalGeneration"),
"LlavaNextForConditionalGeneration": ("llava_next",
"LlavaNextForConditionalGeneration"),
"LlavaNextVideoForConditionalGeneration":
("llava_next_video", "LlavaNextVideoForConditionalGeneration"),
"LlavaOnevisionForConditionalGeneration":
("llava_onevision", "LlavaOnevisionForConditionalGeneration"),
"MiniCPMV": ("minicpmv", "MiniCPMV"),
"PaliGemmaForConditionalGeneration": ("paligemma",
"PaliGemmaForConditionalGeneration"),
"Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),
"PixtralForConditionalGeneration": ("pixtral",
"PixtralForConditionalGeneration"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2VLForConditionalGeneration": ("qwen2_vl",
"Qwen2VLForConditionalGeneration"),
"UltravoxModel": ("ultravox", "UltravoxModel"),
"MllamaForConditionalGeneration": ("mllama",
"MllamaForConditionalGeneration"),
}
_CONDITIONAL_GENERATION_MODELS = {
"BartModel": ("bart", "BartForConditionalGeneration"),
"BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"),
}

_MODELS = {
**_GENERATION_MODELS,
**_EMBEDDING_MODELS,
**_MULTIMODAL_MODELS,
**_CONDITIONAL_GENERATION_MODELS,
}

# Architecture -> type.
# out of tree models
_OOT_MODELS: Dict[str, Type[nn.Module]] = {}

# Models not supported by ROCm.
_ROCM_UNSUPPORTED_MODELS: List[str] = []

# Models partially supported by ROCm.
# Architecture -> Reason.
_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in "
"Triton flash attention. For half-precision SWA support, "
"please use CK flash attention by setting "
"`VLLM_USE_TRITON_FLASH_ATTN=0`")
_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = {
"Qwen2ForCausalLM":
_ROCM_SWA_REASON,
"MistralForCausalLM":
_ROCM_SWA_REASON,
"MixtralForCausalLM":
_ROCM_SWA_REASON,
"PaliGemmaForConditionalGeneration":
("ROCm flash attention does not yet "
"fully support 32-bit precision on PaliGemma"),
"Phi3VForCausalLM":
("ROCm Triton flash attention may run into compilation errors due to "
"excessive use of shared memory. If this happens, disable Triton FA "
"by setting `VLLM_USE_TRITON_FLASH_ATTN=0`")
}


class ModelRegistry:

@staticmethod
def _get_module_cls_name(model_arch: str) -> Tuple[str, str]:
module_relname, cls_name = _MODELS[model_arch]
return f"vllm.model_executor.models.{module_relname}", cls_name

@staticmethod
@lru_cache(maxsize=128)
def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch not in _MODELS:
return None

module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)
module = importlib.import_module(module_name)
return getattr(module, cls_name, None)

@staticmethod
def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]:
if model_arch in _OOT_MODELS:
return _OOT_MODELS[model_arch]

if is_hip():
if model_arch in _ROCM_UNSUPPORTED_MODELS:
raise ValueError(
f"Model architecture {model_arch} is not supported by "
"ROCm for now.")
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
logger.warning(
"Model architecture %s is partially supported by ROCm: %s",
model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])

return None

@staticmethod
def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return model

return ModelRegistry._try_get_model_stateful(model_arch)

@staticmethod
def resolve_model_cls(
architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

for arch in architectures:
model_cls = ModelRegistry._try_load_model_cls(arch)
if model_cls is not None:
return (model_cls, arch)

raise ValueError(
f"Model architectures {architectures} are not supported for now. "
f"Supported architectures: {ModelRegistry.get_supported_archs()}")

@staticmethod
def get_supported_archs() -> List[str]:
return list(_MODELS.keys()) + list(_OOT_MODELS.keys())

@staticmethod
def register_model(model_arch: str, model_cls: Type[nn.Module]):
if model_arch in _MODELS:
logger.warning(
"Model architecture %s is already registered, and will be "
"overwritten by the new model class %s.", model_arch,
model_cls.__name__)

_OOT_MODELS[model_arch] = model_cls

@staticmethod
@lru_cache(maxsize=128)
def _check_stateless(
func: Callable[[Type[nn.Module]], bool],
model_arch: str,
*,
default: Optional[bool] = None,
) -> bool:
"""
Run a boolean function against a model and return the result.
If the model is not found, returns the provided default value.
If the model is not already imported, the function is run inside a
subprocess to avoid initializing CUDA for the main program.
"""
model = ModelRegistry._try_get_model_stateless(model_arch)
if model is not None:
return func(model)

if model_arch not in _MODELS and default is not None:
return default

module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch)

valid_name_characters = string.ascii_letters + string.digits + "._"
if any(s not in valid_name_characters for s in module_name):
raise ValueError(f"Unsafe module name detected for {model_arch}")
if any(s not in valid_name_characters for s in cls_name):
raise ValueError(f"Unsafe class name detected for {model_arch}")
if any(s not in valid_name_characters for s in func.__module__):
raise ValueError(f"Unsafe module name detected for {func}")
if any(s not in valid_name_characters for s in func.__name__):
raise ValueError(f"Unsafe class name detected for {func}")

err_id = uuid.uuid4()

stmts = ";".join([
f"from {module_name} import {cls_name}",
f"from {func.__module__} import {func.__name__}",
f"assert {func.__name__}({cls_name}), '{err_id}'",
])

result = subprocess.run([sys.executable, "-c", stmts],
capture_output=True)

if result.returncode != 0:
err_lines = [line.decode() for line in result.stderr.splitlines()]
if err_lines and err_lines[-1] != f"AssertionError: {err_id}":
err_str = "\n".join(err_lines)
raise RuntimeError(
"An unexpected error occurred while importing the model in "
f"another process. Error log:\n{err_str}")

return result.returncode == 0

@staticmethod
def is_embedding_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

return any(arch in _EMBEDDING_MODELS for arch in architectures)

@staticmethod
def is_multimodal_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

is_mm = partial(ModelRegistry._check_stateless,
supports_multimodal,
default=False)

return any(is_mm(arch) for arch in architectures)

@staticmethod
def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool:
if isinstance(architectures, str):
architectures = [architectures]
if not architectures:
logger.warning("No model architectures are specified")

is_pp = partial(ModelRegistry._check_stateless,
supports_pp,
default=False)

return any(is_pp(arch) for arch in architectures)

from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal,
SupportsPP, has_inner_state, supports_lora,
supports_multimodal, supports_pp)
from .registry import ModelRegistry

__all__ = [
"ModelRegistry",
"HasInnerState",
"has_inner_state",
"SupportsLoRA",
"supports_lora",
"SupportsMultiModal",
"supports_multimodal",
"SupportsPP",
"supports_pp",
]
Loading
Loading