diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index 05f3c3b314d1a..a847a68a6ef71 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -349,6 +349,7 @@ steps:
   - vllm/
   - tests/models
   commands:
+    - pytest -v -s models/test_transformers.py
     - pytest -v -s models/test_registry.py
     - pytest -v -s models/test_initialization.py
@@ -485,6 +486,7 @@ steps:
   - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
   - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m 'distributed(num_gpus=2)'
   # Avoid importing model tests that cause CUDA reinitialization error
+  - pytest models/test_transformers.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/encoder_decoder/language/test_bart.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m 'distributed(num_gpus=2)'
   - pytest models/decoder_only/vision_language/test_models.py -v -s -m 'distributed(num_gpus=2)'
diff --git a/docs/source/models/supported_models.md b/docs/source/models/supported_models.md
index afaad8818bdcb..4a099646964f2 100644
--- a/docs/source/models/supported_models.md
+++ b/docs/source/models/supported_models.md
@@ -40,6 +40,82 @@ If vLLM successfully returns text (for generative models) or hidden states (for
 Otherwise, please refer to [Adding a New Model](#new-model) for instructions on how to implement your model in vLLM.
 Alternatively, you can [open an issue on GitHub](https://github.com/vllm-project/vllm/issues/new/choose) to request vLLM support.
 
+### Transformers fallback
+
+After the merge of , `vllm` can fall back to models that are available in `transformers`. This does not work for all models yet, but most decoder language models are supported, and vision-language model support is planned!
+
+To check whether the backend being used is `transformers`, you can simply do this:
+
+```python
+from vllm import LLM
+llm = LLM(model=..., task="generate") # Name or path of your model
+llm.apply_model(lambda model: print(model.__class__))
+```
+
+If the printed class is `TransformersModel`, the model is running on the `transformers` backend!
+
+#### Supported features
+
+##### LoRA and quantization
+
+Neither is supported yet! Make sure to open an issue and we'll work on this together with the `transformers` team!
+
+Usually, `transformers` models load adapter weights via the `load_adapter` API, which depends on PEFT. We need to work a bit to either use this API (for now this would result in some weights not being marked as loaded) or replace the relevant modules accordingly.
+
+Hints as to how this might look:
+
+```python
+class TransformersModel(nn.Module, SupportsLoRA):
+    def __init__(self, *, vllm_config, prefix=""):
+        ...
+        self.model.load_adapter(vllm_config.load_config.model_loader_extra_config["qlora_adapter_name_or_path"])
+```
+
+The blocker is that you need to specify the supported LoRA layers, when we would ideally want to load whatever is inside the checkpoint!
+
+##### Remote code
+
+This fallback also means that any model on the Hub that can be used in `transformers` with `trust_remote_code=True`, and that correctly implements attention, can be used in production!
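+
+You can also force the Transformers implementation for a model that already has a native vLLM implementation via the new `model_impl` option (`--model-impl transformers` on the CLI). A minimal sketch, assuming `model_impl` is forwarded to the engine like any other `LLM` keyword argument:
+
+```python
+from vllm import LLM
+
+# Force the Transformers backend even though a vLLM implementation exists
+llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct", model_impl="transformers")
+llm.apply_model(lambda model: print(model.__class__))  # expected: TransformersModel
+```
+
+For a remote-code model from the Hub, the same check applies as long as you also pass `trust_remote_code=True`: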
+
+```python
+from vllm import LLM
+llm = LLM(model=..., task="generate", trust_remote_code=True) # Name or path of your model
+llm.apply_model(lambda model: print(model.__class__))
+```
+
+A model just needs the following two things:
+
+```python
+from transformers import PreTrainedModel
+from torch import nn
+
+class MyAttention(nn.Module):
+
+    def forward(self, hidden_states, **kwargs): # <- kwargs are required
+
+        ...
+        attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            **kwargs,
+        )
+        ...
+
+class MyModel(PreTrainedModel):
+    _supports_attention_backend = True
+```
+
+Here is what happens in the background:
+
+1. The config is loaded.
+2. The `MyModel` Python class is loaded from the `auto_map`, and we check that the model sets `_supports_attention_backend = True`.
+3. The `TransformersModel` backend is used. See `vllm/model_executor/models/transformers.py`, which leverages `self.config._attn_implementation = "vllm"`, hence the need to register the attention function in `ALL_ATTENTION_FUNCTIONS`.
+
+That's it!
+
 ### ModelScope
 
 To use models from [ModelScope](https://www.modelscope.cn) instead of HuggingFace Hub, set an environment variable:
diff --git a/requirements-common.txt b/requirements-common.txt
index e5248572ce4d4..97e33a6dbd880 100644
--- a/requirements-common.txt
+++ b/requirements-common.txt
@@ -5,7 +5,7 @@ requests >= 2.26.0
 tqdm
 blake3
 py-cpuinfo
-transformers >= 4.48.2 # Required for Bamba.
+transformers >= 4.48.2 # Required for Bamba model and Transformers backend.
 tokenizers >= 0.19.1 # Required for Llama 3.
 protobuf # Required by LlamaTokenizer.
 fastapi >= 0.107.0, < 0.113.0; python_version < '3.9'
diff --git a/tests/models/registry.py b/tests/models/registry.py
index d0dbbf00e0c51..8a0ade4fa2074 100644
--- a/tests/models/registry.py
+++ b/tests/models/registry.py
@@ -281,12 +281,17 @@ def check_available_online(
                              speculative_model="ibm-fms/llama-160m-accelerator"),  # noqa: E501
 }
 
+_FALLBACK_MODEL = {
+    "TransformersModel": _HfExamplesInfo("ArthurZ/Ilama-3.2-1B", trust_remote_code=True),  # noqa: E501
+}
+
 _EXAMPLE_MODELS = {
     **_TEXT_GENERATION_EXAMPLE_MODELS,
     **_EMBEDDING_EXAMPLE_MODELS,
     **_CROSS_ENCODER_EXAMPLE_MODELS,
     **_MULTIMODAL_EXAMPLE_MODELS,
     **_SPECULATIVE_DECODING_EXAMPLE_MODELS,
+    **_FALLBACK_MODEL,
 }
 
diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py
index ef665baa1804d..f2a505596ce69 100644
--- a/tests/models/test_oot_registration.py
+++ b/tests/models/test_oot_registration.py
@@ -15,7 +15,9 @@ def test_plugin(dummy_opt_path):
     os.environ["VLLM_PLUGINS"] = ""
     with pytest.raises(Exception) as excinfo:
         LLM(model=dummy_opt_path, load_format="dummy")
-    assert "are not supported for now" in str(excinfo.value)
+    error_msg = "has no vLLM implementation and " \
+                "the Transformers implementation is not compatible with vLLM."
+    assert (error_msg in str(excinfo.value))
 
 
 @fork_new_process_for_each_test
diff --git a/tests/models/test_transformers.py b/tests/models/test_transformers.py
new file mode 100644
index 0000000000000..c6536f37cbdc8
--- /dev/null
+++ b/tests/models/test_transformers.py
@@ -0,0 +1,75 @@
+"""Test the functionality of the Transformers backend.
+
+Run `pytest tests/models/test_transformers.py`.
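+
+The distributed test requires 2 GPUs and can be selected with, e.g.,
+`pytest tests/models/test_transformers.py -k test_distributed`.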
+""" +from contextlib import nullcontext +from typing import Type + +import pytest + +from ..conftest import HfRunner, VllmRunner +from ..utils import multi_gpu_test +from .utils import check_logprobs_close + + +def check_implementation( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + example_prompts: list[str], + model: str, + **kwargs, +): + max_tokens = 32 + num_logprobs = 5 + + with vllm_runner(model, **kwargs) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + + with hf_runner(model) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) + + +@pytest.mark.parametrize( + "model,model_impl", + [ + ("meta-llama/Llama-3.2-1B-Instruct", "transformers"), + ("openai-community/gpt2", "transformers"), + ("ArthurZ/Ilama-3.2-1B", "auto"), # CUSTOM CODE + ("meta-llama/Llama-3.2-1B-Instruct", "auto"), + ]) # trust_remote_code=True by default +def test_models(hf_runner, vllm_runner, example_prompts, model, + model_impl) -> None: + + maybe_raises = nullcontext() + if model == "openai-community/gpt2" and model_impl == "transformers": + # Model is not backend compatible + maybe_raises = pytest.raises( + ValueError, + match="The Transformers implementation.*not compatible with vLLM") + + with maybe_raises: + check_implementation(hf_runner, + vllm_runner, + example_prompts, + model, + model_impl=model_impl) + + +@multi_gpu_test(num_gpus=2) +def test_distributed( + hf_runner, + vllm_runner, + example_prompts, +): + kwargs = {"model_impl": "transformers", "tensor_parallel_size": 2} + check_implementation(hf_runner, vllm_runner, example_prompts, + "meta-llama/Llama-3.2-1B-Instruct", **kwargs) diff --git a/vllm/config.py b/vllm/config.py index d2d59c7059e94..d70a637956edf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -83,6 +83,12 @@ def compute_hash(self) -> str: ... +class ModelImpl(str, enum.Enum): + AUTO = "auto" + VLLM = "vllm" + TRANSFORMERS = "transformers" + + class ModelConfig: """Configuration for the model. @@ -167,6 +173,12 @@ class ModelConfig: `logits_processors` extra completion argument. Defaults to None, which allows no processors. generation_config: Configuration parameter file for generation. + model_impl: Which implementation of the model to use: + "auto" will try to use the vLLM implementation if it exists and + fall back to the Transformers implementation if no vLLM + implementation is available. + "vllm" will use the vLLM model implementation. + "transformers" will use the Transformers model implementation. override_generation_config: Override the generation config with the given config. 
""" @@ -230,6 +242,7 @@ def __init__( generation_config: Optional[str] = None, enable_sleep_mode: bool = False, override_generation_config: Optional[Dict[str, Any]] = None, + model_impl: Union[str, ModelImpl] = ModelImpl.AUTO, ) -> None: self.model = model self.tokenizer = tokenizer @@ -241,6 +254,7 @@ def __init__( self.code_revision = code_revision self.rope_scaling = rope_scaling self.rope_theta = rope_theta + self.model_impl = model_impl if hf_overrides is None: hf_overrides = {} diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7c0e8c214066f..40c6fb4567993 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -13,10 +13,10 @@ from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat, DecodingConfig, DeviceConfig, HfOverrides, KVTransferConfig, LoadConfig, LoadFormat, LoRAConfig, - ModelConfig, ObservabilityConfig, ParallelConfig, - PoolerConfig, PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig, TaskOption, TokenizerPoolConfig, - VllmConfig) + ModelConfig, ModelImpl, ObservabilityConfig, + ParallelConfig, PoolerConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig, TaskOption, + TokenizerPoolConfig, VllmConfig) from vllm.executor.executor_base import ExecutorBase from vllm.logger import init_logger from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS @@ -199,6 +199,7 @@ class EngineArgs: generation_config: Optional[str] = None override_generation_config: Optional[Dict[str, Any]] = None enable_sleep_mode: bool = False + model_impl: str = "auto" calculate_kv_scales: Optional[bool] = None @@ -378,6 +379,18 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'qualified names that can be passed with the `logits_processors` ' 'extra completion argument. 
Defaults to None, which allows no '
            'processors.')
+        parser.add_argument(
+            '--model-impl',
+            type=str,
+            default=EngineArgs.model_impl,
+            choices=[f.value for f in ModelImpl],
+            help='Which implementation of the model to use.\n\n'
+            '* "auto" will try to use the vLLM implementation if it exists '
+            'and fall back to the Transformers implementation if no vLLM '
+            'implementation is available.\n'
+            '* "vllm" will use the vLLM model implementation.\n'
+            '* "transformers" will use the Transformers model '
+            'implementation.\n')
 
         # Parallel arguments
         parser.add_argument(
             '--distributed-executor-backend',
@@ -1017,6 +1030,7 @@ def create_model_config(self) -> ModelConfig:
             generation_config=self.generation_config,
             override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
+            model_impl=self.model_impl,
         )
 
     def create_load_config(self) -> LoadConfig:
diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py
index 084ca53b123db..eb334c1fdf255 100644
--- a/vllm/model_executor/model_loader/utils.py
+++ b/vllm/model_executor/model_loader/utils.py
@@ -2,17 +2,22 @@
 """Utilities for selecting and loading models."""
 import contextlib
 from dataclasses import dataclass, field
-from typing import Dict, List, Tuple, Type
+from typing import Dict, List, Optional, Tuple, Type
 
 import torch
+import transformers
 from torch import nn
+from transformers.dynamic_module_utils import get_class_from_dynamic_module
 
-from vllm.config import ModelConfig
+from vllm.config import ModelConfig, ModelImpl
+from vllm.logger import init_logger
 from vllm.model_executor.models import ModelRegistry
 from vllm.model_executor.models.adapters import (as_classification_model,
                                                  as_embedding_model,
                                                  as_reward_model)
 
+logger = init_logger(__name__)
+
 
 @contextlib.contextmanager
 def set_default_torch_dtype(dtype: torch.dtype):
@@ -23,6 +28,50 @@ def set_default_torch_dtype(dtype: torch.dtype):
     torch.set_default_dtype(old_dtype)
 
 
+def is_transformers_impl_compatible(
+        arch: str,
+        module: Optional[transformers.PreTrainedModel] = None) -> bool:
+    mod = module or getattr(transformers, arch, None)
+    if mod is None:
+        return False
+    if hasattr(mod, "is_backend_compatible"):
+        return mod.is_backend_compatible()
+    else:
+        return mod._supports_flex_attn
+
+
+def resolve_transformers_fallback(model_config: ModelConfig,
+                                  architectures: list[str]):
+    for i, arch in enumerate(architectures):
+        if arch == "TransformersModel":
+            continue
+        custom_module = None
+        auto_map = getattr(model_config.hf_config, "auto_map", None)
+        if auto_map is not None and "AutoModel" in auto_map:
+            custom_module = get_class_from_dynamic_module(
+                model_config.hf_config.auto_map["AutoModel"],
+                model_config.model)
+        # TODO(Isotr0py): Further clean up these raises.
+        # Perhaps handle them in _ModelRegistry._raise_for_unsupported?
+        if model_config.model_impl == ModelImpl.TRANSFORMERS:
+            if not is_transformers_impl_compatible(arch, custom_module):
+                raise ValueError(
+                    f"The Transformers implementation of {arch} is not "
+                    "compatible with vLLM.")
+            architectures[i] = "TransformersModel"
+        if model_config.model_impl == ModelImpl.AUTO:
+            if not is_transformers_impl_compatible(arch, custom_module):
+                raise ValueError(
+                    f"{arch} has no vLLM implementation and the Transformers "
+                    "implementation is not compatible with vLLM.")
+            logger.warning(
+                "%s has no vLLM implementation, falling back to Transformers "
+                "implementation. 
Some features may not be supported and " + "performance may not be optimal.", arch) + architectures[i] = "TransformersModel" + return architectures + + def get_model_architecture( model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: architectures = getattr(model_config.hf_config, "architectures", []) @@ -38,6 +87,14 @@ def get_model_architecture( and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] + vllm_supported_archs = ModelRegistry.get_supported_archs() + is_vllm_supported = any(arch in vllm_supported_archs + for arch in architectures) + if (not is_vllm_supported + or model_config.model_impl == ModelImpl.TRANSFORMERS): + architectures = resolve_transformers_fallback(model_config, + architectures) + model_cls, arch = ModelRegistry.resolve_model_cls(architectures) if model_config.task == "embed": model_cls = as_embedding_model(model_cls) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 40bbc7d16b81b..962f95f10fc51 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -184,6 +184,10 @@ "MedusaModel": ("medusa", "Medusa"), "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), } + +_FALLBACK_MODEL = { + "TransformersModel": ("transformers", "TransformersModel"), +} # yapf: enable _VLLM_MODELS = { @@ -192,6 +196,7 @@ **_CROSS_ENCODER_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, + **_FALLBACK_MODEL, } @@ -378,7 +383,12 @@ def _normalize_archs( if not architectures: logger.warning("No model architectures are specified") - return architectures + normalized_arch = [] + for model in architectures: + if model not in self.models: + model = "TransformersModel" + normalized_arch.append(model) + return normalized_arch def inspect_model_cls( self, diff --git a/vllm/model_executor/models/transformers.py b/vllm/model_executor/models/transformers.py new file mode 100644 index 0000000000000..ff1ae0ac85bac --- /dev/null +++ b/vllm/model_executor/models/transformers.py @@ -0,0 +1,264 @@ +# Copyright 2024 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Wrapper around `transformers` models""" +import re +from typing import Iterable, List, Optional, Set, Tuple, Union + +import torch +from torch import nn +from transformers import AutoModel, PreTrainedModel +from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed.utils import divide +from vllm.logger import init_logger +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from .utils import maybe_prefix + +logger = init_logger(__name__) + + +def vllm_flash_attention_forward( + # Transformers args + module: torch.nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: torch.Tensor, + # Transformers kwargs + scaling: float = None, + # vLLM kwargs + attn_metadata: AttentionMetadata = None, + attention_instances: list[Attention] = None, + **kwargs): + self_attn = attention_instances[module.layer_idx] + if scaling is not None: + self_attn.impl.scale = float(scaling) + hidden = query.shape[-2] + query, key, value = (x.transpose(1, 2) for x in (query, key, value)) + query, key, value = (x.reshape(hidden, -1) for x in (query, key, value)) + return self_attn.forward( + query, + key, + value, + kv_cache=None, # argument not used + attn_metadata=attn_metadata), None + + +ALL_ATTENTION_FUNCTIONS["vllm"] = vllm_flash_attention_forward + + +# Linear Layer that is compatible with transformers internal forward +# TODO: This is a temporary solution, we should find a better way to integrate +class HFColumnParallelLinear(ColumnParallelLinear): + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input)[0] + + +class HFRowParallelLinear(RowParallelLinear): + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return super().forward(input)[0] + + +def replace_tp_linear_class(orig_module: nn.Linear, + style: str, + quant_config=None): + """ + In model configurations, we use a neutral type (string) to specify parallel + styles, here we use it to translate nn.Linear into vllm-style tp Linear. 
+ + Quant config is not supported yet + """ + + if not isinstance(style, str): + raise ValueError( + f"Unsupported parallel style type {type(style)}, expected str") + + input_size = orig_module.in_features + output_size = orig_module.out_features + bias = orig_module.bias is not None + + if style == "colwise": + return HFColumnParallelLinear( + input_size, + output_size, + bias, + ) + elif style == "rowwise": + return HFRowParallelLinear( + input_size, + output_size, + bias, + ) + # We don't consider colwise_rep since it's used in lm_head + else: + raise ValueError(f"Unsupported parallel style value: {style}") + + +class TransformersModel(nn.Module): + embedding_padding_modules = ["lm_head"] + embedding_modules = ["embed_tokens" + ] # TODO transformers will have a util to get it + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: + super().__init__() + logger.info("Using Transformers backend.") + + self.vllm_config = vllm_config + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + self.quant_config = quant_config + self.config = config + self.vocab_size = config.vocab_size + self.unpadded_vocab_size = config.vocab_size + + self.model: PreTrainedModel = AutoModel.from_config( + self.config, + attn_implementation="vllm", + torch_dtype=vllm_config.model_config.dtype, + trust_remote_code=vllm_config.model_config.trust_remote_code, + ) + prefix = self.model.base_model_prefix + + # MLP modifications + self.tensor_parallelize(self.model) + + # Attention modifications (assumes 1 attention op per hidden layer) + tp_size = get_tensor_model_parallel_world_size() + self.attention_instances = [ + Attention( + num_heads=divide(config.num_attention_heads, tp_size), + head_size=config.head_dim, + # NOTE: We use Llama scale as default, if it's set by + # Transformers, it's updated in vllm_flash_attention_forward + scale=config.head_dim**-0.5, + num_kv_heads=divide(config.num_key_value_heads, tp_size), + cache_config=cache_config, + quant_config=None, + prefix=f"{i}.attn") for i in range(config.num_hidden_layers) + ] + + # Model modifications + self.replace_vocab_embed_class(self.model) + + # ForCausalLM modifications + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=None, + prefix=maybe_prefix(prefix, "lm_head")) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.get_input_embeddings().weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, logit_scale) + self.sampler = get_sampler() + + def log_replacement(self, name: str, old_module: nn.Module, + new_module: nn.Module): + logger.debug("%s: %s -> %s", name, old_module, new_module) + + def tensor_parallelize(self, module: nn.Module, prefix: str = ""): + if (self.config.base_model_tp_plan is None + and self.vllm_config.parallel_config.tensor_parallel_size > 1): + raise ValueError( + "Trying to run tensor parallelization but the model does not " + "support it yet!") + + for child_name, child_module in module.named_children(): + qual_name = prefix + child_name + for pattern, style in self.config.base_model_tp_plan.items(): + if re.match(pattern, qual_name) and isinstance( + child_module, nn.Linear): + new_module = replace_tp_linear_class( + child_module, style, self.quant_config) + setattr(module, child_name, new_module) + self.log_replacement(qual_name, child_module, new_module) + else: + 
self.tensor_parallelize(child_module, prefix=f"{qual_name}.") + + def replace_vocab_embed_class(self, module: nn.Module): + # Use native set input embeddings + new_module = VocabParallelEmbedding( + self.vocab_size, + self.config.hidden_size, + org_num_embeddings=self.config.vocab_size, + quant_config=None, + ) + self.log_replacement("input embedding", + self.model.get_input_embeddings(), new_module) + self.model.set_input_embeddings(new_module) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], # argument not used + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + model_output = self.model( + input_ids[None, ...], + use_cache=False, + position_ids=positions[None, ...], + attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, + attention_instances=self.attention_instances, + return_dict=False)[0][0, ...] # we remove batch dimension for now + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: + + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if name not in params_dict: + name = f"{self.model.base_model_prefix}.{name}" + if name in params_dict: + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params