From a744b5284543cd7489c4d05bb97ebad41a07dc73 Mon Sep 17 00:00:00 2001
From: Robert Shaw <114415538+robertgshaw2-neuralmagic@users.noreply.github.com>
Date: Thu, 17 Oct 2024 19:21:01 -0400
Subject: [PATCH] Support `BERTModel` (first `encoder-only` embedding model) (#9056)

Signed-off-by: Max de Bayser
Signed-off-by: Max de Bayser
Co-authored-by: Andrew Feldman
Co-authored-by: afeldman-nm <156691304+afeldman-nm@users.noreply.github.com>
Co-authored-by: Woosuk Kwon
Co-authored-by: laishzh
Co-authored-by: Max de Bayser
Co-authored-by: Max de Bayser
Co-authored-by: Cyrus Leung
Signed-off-by: Tyler Michael Smith
---
 .../embedding/language/test_embedding.py      |  14 +-
 vllm/attention/backends/abstract.py           |   7 +-
 vllm/attention/backends/xformers.py           |  59 ++-
 vllm/model_executor/layers/pooler.py          |  12 +-
 vllm/model_executor/models/bert.py            | 427 ++++++++++++++++++
 vllm/model_executor/models/registry.py        |   1 +
 6 files changed, 505 insertions(+), 15 deletions(-)
 create mode 100644 vllm/model_executor/models/bert.py

diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py
index 5f704d854e5dc..39b6bbaf43180 100644
--- a/tests/models/embedding/language/test_embedding.py
+++ b/tests/models/embedding/language/test_embedding.py
@@ -6,21 +6,31 @@
 
 from ..utils import check_embeddings_close
 
+# Embedding models checked against sentence-transformers outputs
 MODELS = [
     "intfloat/e5-mistral-7b-instruct",
+    "BAAI/bge-base-en-v1.5",
     "BAAI/bge-multilingual-gemma2",
 ]
 
+ENCODER_ONLY = [
+    "BAAI/bge-base-en-v1.5",
+]
+
 
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("dtype", ["half"])
 def test_models(
+    monkeypatch,
     hf_runner,
     vllm_runner,
     example_prompts,
-    model: str,
+    model,
     dtype: str,
 ) -> None:
+    if model in ENCODER_ONLY:
+        monkeypatch.setenv("VLLM_ATTENTION_BACKEND", "XFORMERS")
+
     # The example_prompts has ending "\n", for example:
     # "Write a short story about a robot that dreams for the first time.\n"
     # sentence_transformers will strip the input texts, see:
@@ -33,7 +43,7 @@ def test_models(
                    is_sentence_transformer=True) as hf_model:
         hf_outputs = hf_model.encode(example_prompts)
 
-    with vllm_runner(model, dtype=dtype) as vllm_model:
+    with vllm_runner(model, dtype=dtype, max_model_len=None) as vllm_model:
         vllm_outputs = vllm_model.encode(example_prompts)
 
     check_embeddings_close(
diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py
index 2bc36ff18a96b..9ea89eca01f5b 100644
--- a/vllm/attention/backends/abstract.py
+++ b/vllm/attention/backends/abstract.py
@@ -15,8 +15,11 @@
 
 class AttentionType(Enum):
     DECODER = auto()  # Decoder attention between previous layer Q/K/V
-    ENCODER = auto()  # Encoder attention between previous layer Q/K/V
-    ENCODER_DECODER = auto()  # Attention between dec. Q and enc. K/V
+    ENCODER = auto(
+    )  # Encoder attention between previous layer Q/K/V for encoder-decoder
+    ENCODER_ONLY = auto()  # Encoder attention between previous layer Q/K/V
+    ENCODER_DECODER = auto(
+    )  # Attention between dec. Q and enc. K/V for encoder-decoder
 
 
 class AttentionBackend(ABC):
diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py
index 25b86176f630e..650bc6ec7750a 100644
--- a/vllm/attention/backends/xformers.py
+++ b/vllm/attention/backends/xformers.py
@@ -287,13 +287,15 @@ def _get_attn_bias(
     * Appropriate attention bias value given the attention type
     '''
 
-    if attn_type == AttentionType.DECODER:
+    if (attn_type == AttentionType.DECODER
+            or attn_type == AttentionType.ENCODER_ONLY):
         return attn_metadata.attn_bias
     elif attn_type == AttentionType.ENCODER:
         return attn_metadata.encoder_attn_bias
-    else:
-        # attn_type == AttentionType.ENCODER_DECODER
+    elif attn_type == AttentionType.ENCODER_DECODER:
         return attn_metadata.cross_attn_bias
+    else:
+        raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
 
 def _set_attn_bias(
@@ -313,7 +315,8 @@ def _set_attn_bias(
                  encoder/decoder cross-attention
     '''
 
-    if attn_type == AttentionType.DECODER:
+    if (attn_type == AttentionType.DECODER
+            or attn_type == AttentionType.ENCODER_ONLY):
         attn_metadata.attn_bias = attn_bias
     elif attn_type == AttentionType.ENCODER:
         attn_metadata.encoder_attn_bias = attn_bias
@@ -371,6 +374,12 @@ def _get_seq_len_block_table_args(
         # No block tables associated with encoder attention
         return (attn_metadata.encoder_seq_lens_tensor,
                 attn_metadata.max_encoder_seq_len, None)
+    elif attn_type == AttentionType.ENCODER_ONLY:
+        assert is_prompt, "Should not have decode for encoder only model."
+
+        # No block tables associated with encoder attention
+        return (attn_metadata.seq_lens_tensor,
+                attn_metadata.max_prefill_seq_len, None)
     else:
         raise AttributeError(f"Invalid attention type {str(attn_type)}")
 
@@ -479,7 +488,10 @@ def forward(
             * ENCODER: no KV caching; pass encoder sequence
                 attributes (encoder_seq_lens/encoder_seq_lens_tensor/
                 max_encoder_seq_len) to kernel, in lieu of decoder
-                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len)
+                sequence attributes (seq_lens/seq_lens_tensor/max_seq_len).
+                Used for encoder branch of encoder-decoder models.
+            * ENCODER_ONLY: no kv_caching, uses the normal attention
+                attributes (seq_lens/seq_lens_tensor/max_seq_len).
             * ENCODER_DECODER: cross-attention behavior;
                 use cross-attention block table for caching KVs derived
                 from encoder hidden states; since KV sequence lengths
@@ -509,6 +521,7 @@ def forward(
                 and (not attn_metadata.is_all_encoder_attn_metadata_set)):
             raise AttributeError("Encoder attention requires setting "
                                  "encoder metadata attributes.")
+
         elif (attn_type == AttentionType.ENCODER_DECODER
               and (not attn_metadata.is_all_cross_attn_metadata_set)):
             raise AttributeError("Encoder/decoder cross-attention "
@@ -609,6 +622,8 @@ def forward(
                 assert out.shape == output[:num_prefill_tokens].shape
                 output[:num_prefill_tokens] = out
             else:
+                assert attn_type != AttentionType.ENCODER_ONLY, (
+                    "Encoder-only models should not have prefix attention.")
 
                 assert prefill_meta.query_start_loc is not None
                 assert prefill_meta.max_query_len is not None
@@ -638,6 +653,8 @@ def forward(
                 output[:num_prefill_tokens] = out
 
         if decode_meta := attn_metadata.decode_metadata:
+            assert attn_type != AttentionType.ENCODER_ONLY, (
+                "Encoder-only models should not have decode metadata.")
 
             (
                 seq_lens_arg,
@@ -703,36 +720,60 @@ def _run_memory_efficient_xformers_forward(
                           None, :].expand(value.shape[0], self.num_kv_heads,
                                           self.num_queries_per_kv,
                                           value.shape[-1])
+
         # Set attention bias if not provided. This typically happens at
         # the very attention layer of every iteration.
         # FIXME(woosuk): This is a hack.
         attn_bias = _get_attn_bias(attn_metadata, attn_type)
         if attn_bias is None:
             if self.alibi_slopes is None:
+
+                # Cross attention block of decoder branch of encoder-decoder
+                # model uses seq_lens for dec / encoder_seq_lens for enc
                 if (attn_type == AttentionType.ENCODER_DECODER):
                     assert attn_metadata.seq_lens is not None
                     assert attn_metadata.encoder_seq_lens is not None
 
-                    # Default enc/dec cross-attention mask is non-causal
+                    # Cross-attention mask is non-causal
                     attn_bias = BlockDiagonalMask.from_seqlens(
                         attn_metadata.seq_lens, attn_metadata.encoder_seq_lens)
+
+                # Encoder branch of encoder-decoder model uses
+                # attn_metadata.encoder_seq_lens
                 elif attn_type == AttentionType.ENCODER:
+
                     assert attn_metadata.encoder_seq_lens is not None
 
-                    # Default encoder self-attention mask is non-causal
+                    # Encoder self-attention mask is non-causal
                     attn_bias = BlockDiagonalMask.from_seqlens(
                         attn_metadata.encoder_seq_lens)
-                else:
+
+                # Self-attention block of encoder-only model just
+                # uses the seq_lens directly.
+                elif attn_type == AttentionType.ENCODER_ONLY:
                     assert attn_metadata.seq_lens is not None
 
-                    # Default decoder self-attention mask is causal
+                    # Encoder self-attention mask is non-causal
+                    attn_bias = BlockDiagonalMask.from_seqlens(
+                        attn_metadata.seq_lens)
+
+                # Self-attention block of decoder branch just
+                # uses the seq_lens directly
+                elif attn_type == AttentionType.DECODER:
+                    assert attn_metadata.seq_lens is not None
+
+                    # Decoder self-attention mask is causal
                     attn_bias = BlockDiagonalCausalMask.from_seqlens(
                         attn_metadata.seq_lens)
+                else:
+                    raise ValueError(f"Unknown AttentionType: {attn_type}")
+
                 if self.sliding_window is not None:
                     attn_bias = attn_bias.make_local_attention(
                         self.sliding_window)
                 attn_bias = [attn_bias]
             else:
+                assert attn_type == AttentionType.DECODER
                 assert attn_metadata.seq_lens is not None
                 attn_bias = _make_alibi_bias(self.alibi_slopes,
                                              self.num_kv_heads, query.dtype,
diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py
index 76ccb3dfe0a65..3455a4ccf282f 100644
--- a/vllm/model_executor/layers/pooler.py
+++ b/vllm/model_executor/layers/pooler.py
@@ -12,6 +12,7 @@ class PoolingType(IntEnum):
     """Enumeration for different types of pooling methods."""
     LAST = 0
     ALL = 1
+    CLS = 2
 
 
 class Pooler(nn.Module):
@@ -23,12 +24,13 @@ class Pooler(nn.Module):
     3. Returns structured results as `PoolerOutput`.
 
     Attributes:
-        pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+        pooling_type: The type of pooling to use (LAST, ALL, CLS).
         normalize: Whether to normalize the pooled data.
     """
 
     def __init__(self, pooling_type: PoolingType, normalize: bool):
         super().__init__()
+
         self.pooling_type = pooling_type
         self.normalize = normalize
 
@@ -38,10 +40,16 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         """Pools specific information from hidden states based on metadata."""
+
         prompt_lens = PoolingTensors.from_pooling_metadata(
             pooling_metadata, hidden_states.device).prompt_lens
 
-        if self.pooling_type == PoolingType.LAST:
+        if self.pooling_type is PoolingType.CLS:
+            first_token_flat_indices = torch.zeros_like(prompt_lens)
+            first_token_flat_indices[1:] += torch.cumsum(prompt_lens,
+                                                         dim=0)[:-1]
+            pooled_data = hidden_states[first_token_flat_indices]
+        elif self.pooling_type == PoolingType.LAST:
             last_token_flat_indices = torch.cumsum(prompt_lens, dim=0) - 1
             pooled_data = hidden_states[last_token_flat_indices]
         elif self.pooling_type == PoolingType.ALL:
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
new file mode 100644
index 0000000000000..4c0a0e303e655
--- /dev/null
+++ b/vllm/model_executor/models/bert.py
@@ -0,0 +1,427 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import BertConfig
+
+from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.backends.xformers import XFormersImpl
+from vllm.config import CacheConfig
+from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                               QKVParallelLinear,
+                                               RowParallelLinear)
+from vllm.model_executor.layers.pooler import Pooler, PoolingType
+from vllm.model_executor.layers.quantization.base_config import (
+    QuantizationConfig)
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    VocabParallelEmbedding)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.pooling_metadata import PoolingMetadata
+from vllm.sequence import IntermediateTensors, PoolerOutput
+
+
+class BertEmbedding(nn.Module):
+    """Sum word, position and token-type embeddings, then apply LayerNorm."""
+
+    def __init__(self, config: BertConfig):
+
+        super().__init__()
+        self.size = config.hidden_size
+        self.word_embeddings = VocabParallelEmbedding(config.vocab_size,
+                                                      config.hidden_size)
+        self.position_embeddings = VocabParallelEmbedding(
+            config.max_position_embeddings, config.hidden_size)
+        self.token_type_embeddings = VocabParallelEmbedding(
+            config.type_vocab_size, config.hidden_size)
+        self.LayerNorm = nn.LayerNorm(config.hidden_size,
+                                      eps=config.layer_norm_eps)
+        self.position_ids = nn.Parameter(
+            torch.empty((1, config.max_position_embeddings)), )
+
+        self.position_embedding_type = config.position_embedding_type
+        if self.position_embedding_type != "absolute":
+            raise ValueError("Only 'absolute' position_embedding_type" +
+                             " is supported")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        input_shape = input_ids.size()
+
+        # Input embeddings.
+        inputs_embeds = self.word_embeddings(input_ids)
+
+        # Position embeddings.
+        position_embeddings = self.position_embeddings(position_ids)
+
+        # Token type embeddings. (TODO: move off hotpath?)
+        token_type_embeddings = self.token_type_embeddings(
+            torch.zeros(input_shape,
+                        dtype=torch.long,
+                        device=inputs_embeds.device))
+
+        embeddings = inputs_embeds + token_type_embeddings + position_embeddings
+        embeddings = self.LayerNorm(embeddings)
+        return embeddings
+
+
+class BertEncoder(nn.Module):
+    """Apply `config.num_hidden_layers` BertLayer modules in sequence."""
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.layer = nn.ModuleList([
+            BertLayer(config=config,
+                      cache_config=cache_config,
+                      quant_config=quant_config,
+                      prefix=f"{prefix}.layer.{layer_idx}")
+            for layer_idx in range(config.num_hidden_layers)
+        ])
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        for i, layer in enumerate(self.layer):
+            hidden_states = layer(hidden_states, kv_caches[i], attn_metadata)
+        return hidden_states
+
+
+class BertLayer(nn.Module):
+    """Attention block followed by the intermediate and output sub-layers."""
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.attention = BertAttention(
+            hidden_size=config.hidden_size,
+            num_attention_heads=config.num_attention_heads,
+            layer_norm_eps=config.layer_norm_eps,
+            cache_config=cache_config,
+            quant_config=quant_config,
+            prefix=f"{prefix}.attention")
+
+        self.intermediate = BertIntermediate(
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            prefix=f"{prefix}.intermediate")
+
+        self.output = BertOutput(hidden_size=config.hidden_size,
+                                 intermediate_size=config.intermediate_size,
+                                 layer_norm_eps=config.layer_norm_eps,
+                                 quant_config=quant_config,
+                                 prefix=f"{prefix}.output")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: Optional[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+    ):
+        attn_output = self.attention(hidden_states, kv_cache, attn_metadata)
+        intermediate_output = self.intermediate(attn_output)
+        output = self.output(intermediate_output, attn_output)
+        return output
+
+
+class BertAttention(nn.Module):
+    """Self-attention plus its output projection with residual LayerNorm."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        layer_norm_eps: float,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+
+        self.self = BertSelfAttention(hidden_size=hidden_size,
+                                      num_attention_heads=num_attention_heads,
+                                      cache_config=cache_config,
+                                      quant_config=quant_config,
+                                      prefix=f"{prefix}.self")
+
+        self.output = BertSelfOutput(hidden_size=hidden_size,
+                                     layer_norm_eps=layer_norm_eps,
+                                     quant_config=quant_config,
+                                     prefix=f"{prefix}.output")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        self_output = self.self(hidden_states, kv_cache, attn_metadata)
+        return self.output(self_output, hidden_states)
+
+
+class BertSelfAttention(nn.Module):
+    """Fused-QKV self-attention run with `AttentionType.ENCODER_ONLY`."""
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_attention_heads: int,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+
+        self.total_num_heads = num_attention_heads
+        assert self.total_num_heads % tp_size == 0
+
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = self.total_num_heads
+        self.head_dim = self.hidden_size // self.total_num_heads
+        assert self.head_dim * self.total_num_heads == self.hidden_size
+
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size=self.hidden_size,
+            head_size=self.head_dim,
+            total_num_heads=self.total_num_heads,
+            total_num_kv_heads=self.total_num_kv_heads,
+            bias=True,
+            quant_config=quant_config,
+            prefix=f"{prefix}.qkv_proj")
+
+        self.attn = Attention(num_heads=self.num_heads,
+                              head_size=self.head_dim,
+                              scale=self.scaling,
+                              num_kv_heads=self.num_kv_heads,
+                              cache_config=cache_config,
+                              quant_config=quant_config,
+                              prefix=f"{prefix}.attn")
+
+        if not isinstance(self.attn.impl, XFormersImpl):
+            raise ValueError(
+                "Encoder-only models currently require XFORMERS attention "
+                "backend. Set VLLM_ATTENTION_BACKEND=XFORMERS to use BERT.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        output = self.attn(q,
+                           k,
+                           v,
+                           kv_cache,
+                           attn_metadata,
+                           attn_type=AttentionType.ENCODER_ONLY)
+        return output
+
+
+class BertSelfOutput(nn.Module):
+    """Dense projection, residual add and LayerNorm after self-attention."""
+
+    def __init__(self,
+                 hidden_size: int,
+                 layer_norm_eps: float,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.dense = RowParallelLinear(input_size=hidden_size,
+                                       output_size=hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.dense")
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor,
+                input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertIntermediate(nn.Module):
+    """Dense projection to `intermediate_size` followed by the activation."""
+
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 hidden_act: str,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.dense = ColumnParallelLinear(input_size=hidden_size,
+                                          output_size=intermediate_size,
+                                          bias=True,
+                                          quant_config=quant_config,
+                                          prefix=f"{prefix}.dense")
+        self.intermediate_act_fn = get_act_fn(hidden_act)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.intermediate_act_fn(hidden_states)
+        return hidden_states
+
+
+class BertOutput(nn.Module):
+    """Dense projection to `hidden_size`, residual add and LayerNorm."""
+
+    def __init__(self,
+                 hidden_size: int,
+                 intermediate_size: int,
+                 layer_norm_eps: float,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+
+        self.dense = RowParallelLinear(input_size=intermediate_size,
+                                       output_size=hidden_size,
+                                       bias=True,
+                                       quant_config=quant_config,
+                                       prefix=f"{prefix}.dense")
+
+        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+
+    def forward(self, hidden_states: torch.Tensor,
+                input_tensor: torch.Tensor) -> torch.Tensor:
+        hidden_states, _ = self.dense(hidden_states)
+        hidden_states = self.LayerNorm(hidden_states + input_tensor)
+        return hidden_states
+
+
+class BertModel(nn.Module):
+    """BERT embeddings and encoder; returns the encoder's hidden states."""
+
+    def __init__(self,
+                 config: BertConfig,
+                 cache_config: Optional[CacheConfig] = None,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 prefix: str = ""):
+        super().__init__()
+        self.embeddings = BertEmbedding(config)
+        self.encoder = BertEncoder(config,
+                                   cache_config,
+                                   quant_config,
+                                   prefix=f"{prefix}.encoder")
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds
+        else:
+            hidden_states = self.embeddings(input_ids=input_ids,
+                                            position_ids=position_ids)
+
+        return self.encoder(hidden_states, kv_caches, attn_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "query", "q"),
+            ("qkv_proj", "key", "k"),
+            ("qkv_proj", "value", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "pooler" in name:
+                continue
+            for (param_name, weight_name, shard_id) in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader",
+                                        default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+class BertEmbeddingModel(nn.Module):
+    """A model that uses Bert to provide embedding functionalities.
+
+   This class encapsulates the BertModel and provides an interface for
+   embedding operations and customized pooling functions.
+
+   Attributes:
+       model: An instance of BertModel used for forward operations.
+       _pooler: An instance of Pooler used for pooling operations.
+   """
+
+    def __init__(
+        self,
+        config: BertConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.model = BertModel(config, cache_config, quant_config)
+        self._pooler = Pooler(pooling_type=PoolingType.CLS, normalize=True)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        kv_caches: List[torch.Tensor],
+        attn_metadata: AttentionMetadata,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return self.model(input_ids=input_ids,
+                          position_ids=positions,
+                          kv_caches=kv_caches,
+                          inputs_embeds=inputs_embeds,
+                          intermediate_tensors=intermediate_tensors,
+                          attn_metadata=attn_metadata)
+
+    def pooler(
+        self,
+        hidden_states: torch.Tensor,
+        pooling_metadata: PoolingMetadata,
+    ) -> Optional[PoolerOutput]:
+        return self._pooler(hidden_states, pooling_metadata)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(weights)
diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py
index 03a67e3712d72..f442ce0f63e3e 100644
--- a/vllm/model_executor/models/registry.py
+++ b/vllm/model_executor/models/registry.py
@@ -87,6 +87,7 @@
 
 _EMBEDDING_MODELS = {
     # [Text-only]
+    "BertModel": ("bert", "BertEmbeddingModel"),
     "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"),
     "MistralModel": ("llama", "LlamaEmbeddingModel"),
     "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"),