From cc67b877fd48637e841e3f3191c2bcf44aa8771c Mon Sep 17 00:00:00 2001 From: jiaolab-tianhao Date: Tue, 9 Jul 2024 07:43:21 +0000 Subject: [PATCH 1/3] precompute collapsed linear transformation --- .../routers/matrix_factorization/model.py | 35 +++++++++++++++---- routellm/routers/routers.py | 1 + 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py index 09fbb25..c6e1551 100644 --- a/routellm/routers/matrix_factorization/model.py +++ b/routellm/routers/matrix_factorization/model.py @@ -79,10 +79,11 @@ def __init__( text_dim, num_classes, use_proj, + collapse_linear=False, ): super().__init__() - self._name = "TextMF" self.use_proj = use_proj + self.collapse_linear = collapse_linear # collapse the linear transformations into a single linear layer self.P = torch.nn.Embedding(num_models, dim) self.embedding_model = "text-embedding-3-small" @@ -104,19 +105,20 @@ def get_device(self): return self.P.weight.device def forward(self, model_id, prompt): - model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) - - model_embed = self.P(model_id) - model_embed = torch.nn.functional.normalize(model_embed, p=2, dim=1) - prompt_embed = ( OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model) .data[0] .embedding ) prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) - prompt_embed = self.text_proj(prompt_embed) + model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) + + if self.collapse_linear: + upscaled_model_embed = self.precompute_upscaled_embedding(model_id) + return upscaled_model_embed @ prompt_embed.squeeze(-1) + model_embed = self.P(model_id) + prompt_embed = self.text_proj(prompt_embed) return self.classifier(model_embed * prompt_embed).squeeze() @torch.no_grad() @@ -127,3 +129,22 @@ def pred_win_rate(self, model_a, model_b, prompt): def load(self, path): self.load_state_dict(torch.load(path)) + + def post_process_weight(self): + # since the current model consist of only linear transformations + # we can collapse the linear transformations into a single linear layer + # https://github.com/lm-sys/RouteLLM/issues/9 + num_models = self.P.weight.shape[0] + text_dim = self.text_proj[0].weight.shape[1] + + self.P.weight.data = torch.nn.functional.normalize( + self.P.weight.data, p=2, dim=1 + ) + + if self.collapse_linear: + self.precompute_upscaled_embedding = torch.nn.Embedding( + num_models, text_dim + ) + self.precompute_upscaled_embedding.weight.data = ( + self.P.weight * self.classifier[0].weight.data + ) @ self.text_proj[0].weight.data diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py index 0096c0a..a1efeb3 100644 --- a/routellm/routers/routers.py +++ b/routellm/routers/routers.py @@ -231,6 +231,7 @@ def __init__( num_classes=num_classes, use_proj=use_proj, ) + self.model.post_process_weight() self.model = self.model.eval().to(device) self.strong_model_id = MODEL_IDS[strong_model] self.weak_model_id = MODEL_IDS[weak_model] From 2a4328d14790ff6187159e1ae02d943ebef8fe1b Mon Sep 17 00:00:00 2001 From: jiaolab-tianhao Date: Sun, 14 Jul 2024 01:06:26 +0000 Subject: [PATCH 2/3] support local embedding model --- .../routers/matrix_factorization/model.py | 59 +++++++++++++++---- routellm/routers/routers.py | 5 +- 2 files changed, 52 insertions(+), 12 deletions(-) diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py index c6e1551..1a72c6c 100644 --- a/routellm/routers/matrix_factorization/model.py +++ b/routellm/routers/matrix_factorization/model.py @@ -74,19 +74,45 @@ class MFModel(torch.nn.Module, PyTorchModelHubMixin): def __init__( self, - dim, - num_models, - text_dim, - num_classes, - use_proj, + dim=128, + num_models=64, + text_dim=768, + num_classes=1, + use_proj=True, collapse_linear=False, + embedding_model="all-mpnet-base-v2", ): + """ + Args: + dim: + Dimension of the model embeddings, default to 128 + num_models: + Number of models, default to 64 + text_dim: + Dimension of the text embeddings + 1536 for OpenAI's text-embedding-3-small + 768 for all-mpnet-base-v2 + num_classes: + Number of classes, default to 1, output a scalar + use_proj: + Whether to use projection for the text embeddings + This is set to be True in our pretrained models for better performance + collapse_linear: + Whether to collapse the linear transformations into a single linear layer + Since the current pretrained models only consist of Linear layers, + we can collapse them into a single layer for faster inference + See https://github.com/lm-sys/RouteLLM/issues/9 + embedding_model: + Text embedding model for the prompt, should be the same as the one used in training + Use all-mpnet-base-v2 to avoid OpenAI's key, however, slightly worse performance + Use OpenAI's text-embedding-3-small for better performance + """ super().__init__() self.use_proj = use_proj self.collapse_linear = collapse_linear # collapse the linear transformations into a single linear layer self.P = torch.nn.Embedding(num_models, dim) - self.embedding_model = "text-embedding-3-small" + self.embedding_model = embedding_model if self.use_proj: self.text_proj = torch.nn.Sequential( @@ -105,11 +131,17 @@ def get_device(self): return self.P.weight.device def forward(self, model_id, prompt): - prompt_embed = ( - OPENAI_CLIENT.embeddings.create(input=[prompt], model=self.embedding_model) - .data[0] - .embedding - ) + if self.embedding_model == "text-embedding-3-small": + prompt_embed = ( + OPENAI_CLIENT.embeddings.create( + input=[prompt], model=self.embedding_model + ) + .data[0] + .embedding + ) + elif self.embedding_model == "all-mpnet-base-v2": + prompt_embed = self._embedding_model.encode([prompt]) + prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) @@ -141,6 +173,11 @@ def post_process_weight(self): self.P.weight.data, p=2, dim=1 ) + if self.embedding_model == "all-mpnet-base-v2": + from sentence_transformers import SentenceTransformer + + self._embedding_model = SentenceTransformer(self.embedding_model) + if self.collapse_linear: self.precompute_upscaled_embedding = torch.nn.Embedding( num_models, text_dim diff --git a/routellm/routers/routers.py b/routellm/routers/routers.py index a1efeb3..2212bf2 100644 --- a/routellm/routers/routers.py +++ b/routellm/routers/routers.py @@ -217,9 +217,10 @@ def __init__( weak_model="mixtral-8x7b-instruct-v0.1", hidden_size=128, num_models=64, - text_dim=1536, + text_dim=768, num_classes=1, use_proj=True, + embedding_model="all-mpnet-base-v2", ): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -230,7 +231,9 @@ def __init__( text_dim=text_dim, num_classes=num_classes, use_proj=use_proj, + embedding_model=embedding_model, ) + self.model.post_process_weight() self.model = self.model.eval().to(device) self.strong_model_id = MODEL_IDS[strong_model] From 1316690e48715d6e7dd26344b3a1c2584ef533b8 Mon Sep 17 00:00:00 2001 From: jiaolab-tianhao Date: Sun, 14 Jul 2024 05:30:48 +0000 Subject: [PATCH 3/3] add stella_en_400M_v5 support --- .../routers/matrix_factorization/model.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/routellm/routers/matrix_factorization/model.py b/routellm/routers/matrix_factorization/model.py index 1a72c6c..9e667f9 100644 --- a/routellm/routers/matrix_factorization/model.py +++ b/routellm/routers/matrix_factorization/model.py @@ -92,6 +92,7 @@ def __init__( Dimension of the text embeddings 1536 for OpenAI's text-embedding-3-small 768 for all-mpnet-base-v2 + 1024 for infgrad/stella_en_400M_v5 num_classes: Number of classes, default to 1, output a scalar use_proj: @@ -141,6 +142,15 @@ def forward(self, model_id, prompt): ) elif self.embedding_model == "all-mpnet-base-v2": prompt_embed = self._embedding_model.encode([prompt]) + elif self.embedding_model == "infgrad/stella_en_400M_v5": + prompt_embed = self._embedding_model.encode( + [prompt], prompt_name="s2s_query" + ) + else: + raise ValueError( + f"Unsupported embedding model {self.embedding_model}, " + "should be one of text-embedding-3-small, all-mpnet-base-v2, infgrad/stella_en_400M_v5" + ) prompt_embed = torch.tensor(prompt_embed, device=self.get_device()) model_id = torch.tensor(model_id, dtype=torch.long).to(self.get_device()) @@ -173,10 +183,15 @@ def post_process_weight(self): self.P.weight.data, p=2, dim=1 ) - if self.embedding_model == "all-mpnet-base-v2": + if ( + self.embedding_model == "all-mpnet-base-v2" + or self.embedding_model == "infgrad/stella_en_400M_v5" + ): from sentence_transformers import SentenceTransformer - self._embedding_model = SentenceTransformer(self.embedding_model) + self._embedding_model = SentenceTransformer( + self.embedding_model, trust_remote_code=True + ).to("cuda") if self.collapse_linear: self.precompute_upscaled_embedding = torch.nn.Embedding(