Add support for embedding models: Text2Vec, M3E, GTE #3541

Open · wants to merge 4 commits into base: main
3 changes: 3 additions & 0 deletions docs/model_support.md
@@ -37,6 +37,7 @@ After these steps, the new model should be compatible with most FastChat features
- example: `python3 -m fastchat.serve.cli --model-path meta-llama/Llama-2-7b-chat-hf`
- Vicuna, Alpaca, LLaMA, Koala
- example: `python3 -m fastchat.serve.cli --model-path lmsys/vicuna-7b-v1.5`
+- [Alibaba-NLP/gte-large-en-v1.5](https://huggingface.co/Alibaba-NLP/gte-large-en-v1.5)
- [allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b)
- [BAAI/AquilaChat-7B](https://huggingface.co/BAAI/AquilaChat-7B)
- [BAAI/AquilaChat2-7B](https://huggingface.co/BAAI/AquilaChat2-7B)
@@ -66,6 +67,7 @@ After these steps, the new model should be compatible with most FastChat features
- [lmsys/fastchat-t5-3b-v1.0](https://huggingface.co/lmsys/fastchat-t5)
- [meta-math/MetaMath-7B-V1.0](https://huggingface.co/meta-math/MetaMath-7B-V1.0)
- [Microsoft/Orca-2-7b](https://huggingface.co/microsoft/Orca-2-7b)
+- [moka-ai/m3e-large](https://huggingface.co/moka-ai/m3e-large)
- [mosaicml/mpt-7b-chat](https://huggingface.co/mosaicml/mpt-7b-chat)
- example: `python3 -m fastchat.serve.cli --model-path mosaicml/mpt-7b-chat`
- [Neutralzz/BiLLa-7B-SFT](https://huggingface.co/Neutralzz/BiLLa-7B-SFT)
@@ -81,6 +83,7 @@ After these steps, the new model should be compatible with most FastChat features
- [Qwen/Qwen-7B-Chat](https://huggingface.co/Qwen/Qwen-7B-Chat)
- [rishiraj/CatPPT](https://huggingface.co/rishiraj/CatPPT)
- [Salesforce/codet5p-6b](https://huggingface.co/Salesforce/codet5p-6b)
+- [shibing624/text2vec-base-multilingual](https://huggingface.co/shibing624/text2vec-base-multilingual)
- [StabilityAI/stablelm-tuned-alpha-7b](https://huggingface.co/stabilityai/stablelm-tuned-alpha-7b)
- [tenyx/TenyxChat-7B-v1](https://huggingface.co/tenyx/TenyxChat-7B-v1)
- [TinyLlama/TinyLlama-1.1B-Chat-v1.0](https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0)
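Once one of the newly listed models is matched by an embedding adapter, it can be served through FastChat's OpenAI-compatible API and queried via the embeddings endpoint. Below is a minimal sketch of that flow, assuming the standard three-process FastChat stack is already running and that the worker registered the model under its path basename (`m3e-large` here):

```python
# A minimal sketch, assuming the usual FastChat stack is up, e.g.:
#   python3 -m fastchat.serve.controller
#   python3 -m fastchat.serve.model_worker --model-path moka-ai/m3e-large
#   python3 -m fastchat.serve.openai_api_server --host localhost --port 8000
import requests

resp = requests.post(
    "http://localhost:8000/v1/embeddings",
    json={"model": "m3e-large", "input": "hello world"},
)
resp.raise_for_status()
vector = resp.json()["data"][0]["embedding"]
print(len(vector))  # dimensionality of the returned embedding
```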
101 changes: 59 additions & 42 deletions fastchat/model/model_adapter.py
@@ -144,6 +144,34 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("one_shot")


+class BaseEmbeddingModelAdapter(BaseModelAdapter):
+    """The base embedding model adapter"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "embedding" in model_path.lower()
+
+    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
+        revision = from_pretrained_kwargs.get("revision", "main")
+        model = AutoModel.from_pretrained(
+            model_path,
+            **from_pretrained_kwargs,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_path, trust_remote_code=True, revision=revision
+        )
+        if hasattr(model.config, "max_position_embeddings") and hasattr(
+            tokenizer, "model_max_length"
+        ):
+            model.config.max_sequence_length = min(
+                model.config.max_position_embeddings, tokenizer.model_max_length
+            )
+        model.use_cls_pooling = True
+        model.eval()
+        return model, tokenizer


# A global registry for all model adapters
# TODO (lmzheng): make it a priority queue.
model_adapters: List[BaseModelAdapter] = []
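For context, the `(model, tokenizer)` pair returned by `load_model` above is consumed by the worker's embedding path. The sketch below illustrates the mechanism that the `use_cls_pooling` flag and the `max_sequence_length` clamp feed into; it is an illustration under those assumptions, not the worker's actual code:

```python
import torch

@torch.inference_mode()
def embed(model, tokenizer, texts):
    # Truncate to the length negotiated in load_model() above.
    max_len = getattr(model.config, "max_sequence_length", 512)
    inputs = tokenizer(
        texts, padding=True, truncation=True,
        max_length=max_len, return_tensors="pt",
    )
    outputs = model(**inputs)
    if getattr(model, "use_cls_pooling", False):
        # CLS pooling: hidden state of the first token of each sequence.
        emb = outputs.last_hidden_state[:, 0]
    else:
        # Mean pooling over non-padding tokens, a common alternative.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        emb = (outputs.last_hidden_state * mask).sum(1) / mask.sum(1)
    # L2-normalize so dot products equal cosine similarities.
    return torch.nn.functional.normalize(emb, p=2, dim=1)
```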
@@ -1836,64 +1864,49 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
return get_conv_template("qwen-7b-chat")


-class BGEAdapter(BaseModelAdapter):
+class BGEAdapter(BaseEmbeddingModelAdapter):
"""The model adapter for BGE (e.g., BAAI/bge-large-en-v1.5)"""

use_fast_tokenizer = False

def match(self, model_path: str):
return "bge" in model_path.lower()

-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        model.use_cls_pooling = True
-        model.eval()
-        return model, tokenizer
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")


-class E5Adapter(BaseModelAdapter):
+class E5Adapter(BaseEmbeddingModelAdapter):
"""The model adapter for E5 (e.g., intfloat/e5-large-v2)"""

use_fast_tokenizer = False

def match(self, model_path: str):
return "e5-" in model_path.lower()

-    def load_model(self, model_path: str, from_pretrained_kwargs: dict):
-        revision = from_pretrained_kwargs.get("revision", "main")
-        model = AutoModel.from_pretrained(
-            model_path,
-            **from_pretrained_kwargs,
-        )
-        tokenizer = AutoTokenizer.from_pretrained(
-            model_path, trust_remote_code=True, revision=revision
-        )
-        if hasattr(model.config, "max_position_embeddings") and hasattr(
-            tokenizer, "model_max_length"
-        ):
-            model.config.max_sequence_length = min(
-                model.config.max_position_embeddings, tokenizer.model_max_length
-            )
-        return model, tokenizer
-
-    def get_default_conv_template(self, model_path: str) -> Conversation:
-        return get_conv_template("one_shot")
+class Text2VecAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for text2vec (e.g., shibing624/text2vec-base-chinese)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "text2vec" in model_path.lower()
+
+
+class M3EAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for m3e (e.g., moka-ai/m3e-large)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "m3e-" in model_path.lower()
+
+
+class GTEAdapter(BaseEmbeddingModelAdapter):
+    """The model adapter for gte (e.g., Alibaba-NLP/gte-large-en-v1.5)"""
+
+    use_fast_tokenizer = False
+
+    def match(self, model_path: str):
+        return "gte-" in model_path.lower()


class AquilaChatAdapter(BaseModelAdapter):
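Each of the three new adapters only narrows `match` to a path substring; loading is inherited from `BaseEmbeddingModelAdapter`. The lookup they participate in is first-match-wins over the registry, roughly as below (FastChat's real `get_model_adapter` also checks the path basename first; this simplified loop is illustrative):

```python
def get_adapter(model_path: str):
    # First registered match wins, so specific adapters must be
    # registered ahead of generic fallbacks.
    for adapter in model_adapters:
        if adapter.match(model_path):
            return adapter
    raise ValueError(f"No valid model adapter for {model_path}")

get_adapter("moka-ai/m3e-large")                 # -> M3EAdapter
get_adapter("shibing624/text2vec-base-chinese")  # -> Text2VecAdapter
get_adapter("Alibaba-NLP/gte-large-en-v1.5")     # -> GTEAdapter
```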
@@ -2562,6 +2575,9 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(AquilaChatAdapter)
register_model_adapter(BGEAdapter)
register_model_adapter(E5Adapter)
+register_model_adapter(Text2VecAdapter)
+register_model_adapter(M3EAdapter)
+register_model_adapter(GTEAdapter)
register_model_adapter(Lamma2ChineseAdapter)
register_model_adapter(Lamma2ChineseAlpacaAdapter)
register_model_adapter(VigogneAdapter)
@@ -2603,5 +2619,6 @@ def get_default_conv_template(self, model_path: str) -> Conversation:
register_model_adapter(GrokAdapter)
register_model_adapter(NoSystemAdapter)

+register_model_adapter(BaseEmbeddingModelAdapter)
# After all adapters, try the default base adapter.
register_model_adapter(BaseModelAdapter)
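Since registration order doubles as match priority, `BaseEmbeddingModelAdapter` (which matches any path containing "embedding") is registered after all the specific adapters and just before the catch-all `BaseModelAdapter`. Reusing the illustrative `get_adapter` from the sketch above, with a hypothetical model path:

```python
# Specific adapters shadow the generic embedding fallback...
assert isinstance(get_adapter("BAAI/bge-large-en-v1.5"), BGEAdapter)
# ...while an otherwise-unmatched path containing "embedding"
# (hypothetical name) resolves to the new base embedding adapter.
assert isinstance(get_adapter("my-org/demo-embedding"), BaseEmbeddingModelAdapter)
```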