Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[torch.compile] support encoder based models #10613

Merged
merged 3 commits into from
Nov 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions tests/compile/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,16 @@ class TestSetting:
method="encode",
fullgraph=True,
),
# encoder-based embedding model (BERT)
TestSetting(
model="BAAI/bge-base-en-v1.5",
model_args=["--task", "embedding"],
pp_size=1,
tp_size=1,
attn_backend="XFORMERS",
method="encode",
fullgraph=True,
),
# vision language model
TestSetting(
model="microsoft/Phi-3.5-vision-instruct",
Expand Down
17 changes: 7 additions & 10 deletions vllm/model_executor/models/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from transformers import BertConfig

from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, PoolerConfig, VllmConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
Expand Down Expand Up @@ -92,14 +93,14 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
return pooled_output


@support_torch_compile
class BertEncoder(nn.Module):

def __init__(self,
config: BertConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = ""):
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.layer = nn.ModuleList([
BertLayer(config=config,
cache_config=cache_config,
Expand Down Expand Up @@ -336,12 +337,8 @@ def __init__(self,
add_pooling_layer: bool = False):
super().__init__()
config = vllm_config.model_config.hf_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
self.embeddings = embedding_class(config)
self.encoder = BertEncoder(config,
cache_config,
quant_config,
self.encoder = BertEncoder(vllm_config=vllm_config,
prefix=f"{prefix}.encoder")
self.pooler = BertPooler(config) if add_pooling_layer else None

Expand Down