From d71718b1bb6b0fc0cf378cea3b16528091fdd8d7 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:28:53 +0300 Subject: [PATCH 1/7] add dotwrapper --- mteb/models/sentence_transformer_wrapper.py | 10 +++++++++- mteb/models/sentence_transformers_models.py | 9 +++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index e580ef8959..2b9da8105b 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -11,6 +11,7 @@ from mteb.encoder_interface import PromptType from .wrapper import Wrapper +from ..evaluation import dot_distance logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ def __init__( model: str | SentenceTransformer | CrossEncoder, revision: str | None = None, model_prompts: dict[str, str] | None = None, + use_model_similarity: bool = True, **kwargs, ) -> None: """Wrapper for SentenceTransformer models. @@ -32,6 +34,7 @@ def __init__( First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt, then to the composed prompt of task type + prompt type, then to the specific task type prompt, and finally to the specific prompt type. + use_model_similarity: Whether to use the model's similarity method. **kwargs: Additional arguments to pass to the SentenceTransformer model. """ if isinstance(model, str): @@ -59,7 +62,7 @@ def __init__( if isinstance(self.model, CrossEncoder): self.predict = self._predict - if hasattr(self.model, "similarity"): + if hasattr(self.model, "similarity") and use_model_similarity: self.similarity = self.model.similarity def encode( @@ -125,3 +128,8 @@ def _predict( convert_to_numpy=True, **kwargs, ) + + +class SentenceTransformerWrapperDotSimilarity(SentenceTransformerWrapper): + def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float: + return dot_distance(embedding1, embedding2) diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index 4e0ad6420f..32ce57bfc7 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -2,7 +2,10 @@ from __future__ import annotations +from functools import partial + from mteb.model_meta import ModelMeta +from mteb.models.sentence_transformer_wrapper import SentenceTransformerWrapperDotSimilarity paraphrase_langs = [ "ara_Arab", @@ -375,6 +378,12 @@ ) contriever = ModelMeta( + loader=partial( + SentenceTransformerWrapperDotSimilarity, + model="facebook/contriever-msmarco", + revision="abe8c1493371369031bcb1e02acb754cf4e162fa", + use_model_similarity=False, + ), name="facebook/contriever-msmarco", languages=["eng-Latn"], open_weights=True, From d50fd88796cb1df2b481ccd3a6701e720780ca06 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:29:51 +0300 Subject: [PATCH 2/7] lint --- mteb/models/sentence_transformer_wrapper.py | 2 +- mteb/models/sentence_transformers_models.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index 2b9da8105b..89888a892e 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -10,8 +10,8 @@ from mteb.encoder_interface import PromptType -from .wrapper import Wrapper from ..evaluation import dot_distance +from .wrapper import Wrapper logger = logging.getLogger(__name__) diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index 32ce57bfc7..c9b6ece897 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -5,7 +5,9 @@ from functools import partial from mteb.model_meta import ModelMeta -from mteb.models.sentence_transformer_wrapper import SentenceTransformerWrapperDotSimilarity +from mteb.models.sentence_transformer_wrapper import ( + SentenceTransformerWrapperDotSimilarity, +) paraphrase_langs = [ "ara_Arab", From 7d1e949f555066b08134ccacd89690e92554af30 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Fri, 10 Jan 2025 18:44:37 +0300 Subject: [PATCH 3/7] make cleaner --- mteb/models/sentence_transformer_wrapper.py | 4 +--- mteb/models/sentence_transformers_models.py | 1 - 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index 89888a892e..f6ec94235d 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -22,7 +22,6 @@ def __init__( model: str | SentenceTransformer | CrossEncoder, revision: str | None = None, model_prompts: dict[str, str] | None = None, - use_model_similarity: bool = True, **kwargs, ) -> None: """Wrapper for SentenceTransformer models. @@ -34,7 +33,6 @@ def __init__( First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt, then to the composed prompt of task type + prompt type, then to the specific task type prompt, and finally to the specific prompt type. - use_model_similarity: Whether to use the model's similarity method. **kwargs: Additional arguments to pass to the SentenceTransformer model. """ if isinstance(model, str): @@ -62,7 +60,7 @@ def __init__( if isinstance(self.model, CrossEncoder): self.predict = self._predict - if hasattr(self.model, "similarity") and use_model_similarity: + if hasattr(self.model, "similarity") and not hasattr(self, "similarity"): self.similarity = self.model.similarity def encode( diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index c9b6ece897..cb9a312b3c 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -384,7 +384,6 @@ SentenceTransformerWrapperDotSimilarity, model="facebook/contriever-msmarco", revision="abe8c1493371369031bcb1e02acb754cf4e162fa", - use_model_similarity=False, ), name="facebook/contriever-msmarco", languages=["eng-Latn"], From f99786dc1a87c2674e15cd874a95b14d3ee38ddf Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Sat, 11 Jan 2025 22:57:00 +0300 Subject: [PATCH 4/7] add similarity_fn --- mteb/models/sentence_transformer_wrapper.py | 15 +++++++-------- mteb/models/sentence_transformers_models.py | 6 ++++-- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index f6ec94235d..eeac485dc7 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import Any +from typing import Any, Callable import numpy as np import torch @@ -10,7 +10,6 @@ from mteb.encoder_interface import PromptType -from ..evaluation import dot_distance from .wrapper import Wrapper logger = logging.getLogger(__name__) @@ -22,6 +21,7 @@ def __init__( model: str | SentenceTransformer | CrossEncoder, revision: str | None = None, model_prompts: dict[str, str] | None = None, + similarity_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None, **kwargs, ) -> None: """Wrapper for SentenceTransformer models. @@ -33,6 +33,7 @@ def __init__( First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt, then to the composed prompt of task type + prompt type, then to the specific task type prompt, and finally to the specific prompt type. + similarity_fn: A similarity function to use. **kwargs: Additional arguments to pass to the SentenceTransformer model. """ if isinstance(model, str): @@ -60,7 +61,10 @@ def __init__( if isinstance(self.model, CrossEncoder): self.predict = self._predict - if hasattr(self.model, "similarity") and not hasattr(self, "similarity"): + if similarity_fn and callable(similarity_fn): + self.similarity = similarity_fn + + if hasattr(self.model, "similarity") and similarity_fn is None: self.similarity = self.model.similarity def encode( @@ -126,8 +130,3 @@ def _predict( convert_to_numpy=True, **kwargs, ) - - -class SentenceTransformerWrapperDotSimilarity(SentenceTransformerWrapper): - def similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float: - return dot_distance(embedding1, embedding2) diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index cb9a312b3c..782d698c63 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -4,9 +4,10 @@ from functools import partial +from mteb.evaluation import dot_distance from mteb.model_meta import ModelMeta from mteb.models.sentence_transformer_wrapper import ( - SentenceTransformerWrapperDotSimilarity, + SentenceTransformerWrapper, ) paraphrase_langs = [ @@ -381,9 +382,10 @@ contriever = ModelMeta( loader=partial( - SentenceTransformerWrapperDotSimilarity, + SentenceTransformerWrapper, model="facebook/contriever-msmarco", revision="abe8c1493371369031bcb1e02acb754cf4e162fa", + similarity_fn=dot_distance, ), name="facebook/contriever-msmarco", languages=["eng-Latn"], From d13fe99d52cd34c904b1cfde834714994f84d771 Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Sat, 11 Jan 2025 23:13:04 +0300 Subject: [PATCH 5/7] update to similarity_fn_name --- mteb/models/sentence_transformer_wrapper.py | 11 +++++------ mteb/models/sentence_transformers_models.py | 2 +- mteb/models/wrapper.py | 11 +++++++++++ 3 files changed, 17 insertions(+), 7 deletions(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index eeac485dc7..1c70f50fb6 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -21,7 +21,7 @@ def __init__( model: str | SentenceTransformer | CrossEncoder, revision: str | None = None, model_prompts: dict[str, str] | None = None, - similarity_fn: Callable[[np.ndarray, np.ndarray], np.ndarray] | None = None, + similarity_fn_name: str | None = None, **kwargs, ) -> None: """Wrapper for SentenceTransformer models. @@ -33,7 +33,7 @@ def __init__( First priority is given to the composed prompt of task name + prompt type (query or passage), then to the specific task prompt, then to the composed prompt of task type + prompt type, then to the specific task type prompt, and finally to the specific prompt type. - similarity_fn: A similarity function to use. + similarity_fn_name: A similarity function to use. **kwargs: Additional arguments to pass to the SentenceTransformer model. """ if isinstance(model, str): @@ -61,10 +61,9 @@ def __init__( if isinstance(self.model, CrossEncoder): self.predict = self._predict - if similarity_fn and callable(similarity_fn): - self.similarity = similarity_fn - - if hasattr(self.model, "similarity") and similarity_fn is None: + if similarity_fn_name: + self.similarity = self.get_similarity_function(similarity_fn_name) + elif hasattr(self.model, "similarity"): self.similarity = self.model.similarity def encode( diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index 782d698c63..af19be1f24 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -385,7 +385,7 @@ SentenceTransformerWrapper, model="facebook/contriever-msmarco", revision="abe8c1493371369031bcb1e02acb754cf4e162fa", - similarity_fn=dot_distance, + similarity_fn="dot", ), name="facebook/contriever-msmarco", languages=["eng-Latn"], diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py index 956071d3dc..7d6d8f66e8 100644 --- a/mteb/models/wrapper.py +++ b/mteb/models/wrapper.py @@ -3,9 +3,12 @@ import logging from typing import Callable, get_args +import numpy as np + import mteb from mteb.abstasks.TaskMetadata import TASK_TYPE from mteb.encoder_interface import PromptType +from mteb.evaluation.evaluators.utils import cos_sim, dot_score logger = logging.getLogger(__name__) @@ -64,6 +67,14 @@ def get_prompt_name( ) return None + @staticmethod + def get_similarity_function(similarity_fn_name: str) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: + if similarity_fn_name == "cosine": + return cos_sim + if similarity_fn_name == "dot": + return dot_score + raise ValueError("Invalid similarity function. Should be one of ['cosine', 'dot']") + @staticmethod def validate_task_to_prompt_name( task_to_prompt_name: dict[str, str] | None, From 3cf2168460d39d49d45a2a410fd34ed5541b6b9d Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Sat, 11 Jan 2025 23:17:12 +0300 Subject: [PATCH 6/7] lint --- mteb/models/sentence_transformer_wrapper.py | 2 +- mteb/models/sentence_transformers_models.py | 1 - mteb/models/wrapper.py | 8 ++++++-- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/mteb/models/sentence_transformer_wrapper.py b/mteb/models/sentence_transformer_wrapper.py index 1c70f50fb6..8c133e125c 100644 --- a/mteb/models/sentence_transformer_wrapper.py +++ b/mteb/models/sentence_transformer_wrapper.py @@ -2,7 +2,7 @@ import logging from collections.abc import Sequence -from typing import Any, Callable +from typing import Any import numpy as np import torch diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index af19be1f24..c1f3fba4c0 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -4,7 +4,6 @@ from functools import partial -from mteb.evaluation import dot_distance from mteb.model_meta import ModelMeta from mteb.models.sentence_transformer_wrapper import ( SentenceTransformerWrapper, diff --git a/mteb/models/wrapper.py b/mteb/models/wrapper.py index 7d6d8f66e8..76b31ba529 100644 --- a/mteb/models/wrapper.py +++ b/mteb/models/wrapper.py @@ -68,12 +68,16 @@ def get_prompt_name( return None @staticmethod - def get_similarity_function(similarity_fn_name: str) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: + def get_similarity_function( + similarity_fn_name: str, + ) -> Callable[[np.ndarray, np.ndarray], np.ndarray]: if similarity_fn_name == "cosine": return cos_sim if similarity_fn_name == "dot": return dot_score - raise ValueError("Invalid similarity function. Should be one of ['cosine', 'dot']") + raise ValueError( + "Invalid similarity function. Should be one of ['cosine', 'dot']" + ) @staticmethod def validate_task_to_prompt_name( From fabe9fb19a46963f15a890040116f157caea555f Mon Sep 17 00:00:00 2001 From: Roman Solomatin <36135455+Samoed@users.noreply.github.com> Date: Sat, 11 Jan 2025 23:17:35 +0300 Subject: [PATCH 7/7] fix name parameter --- mteb/models/sentence_transformers_models.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mteb/models/sentence_transformers_models.py b/mteb/models/sentence_transformers_models.py index c1f3fba4c0..1d3a0c5f7e 100644 --- a/mteb/models/sentence_transformers_models.py +++ b/mteb/models/sentence_transformers_models.py @@ -384,7 +384,7 @@ SentenceTransformerWrapper, model="facebook/contriever-msmarco", revision="abe8c1493371369031bcb1e02acb754cf4e162fa", - similarity_fn="dot", + similarity_fn_name="dot", ), name="facebook/contriever-msmarco", languages=["eng-Latn"],