diff --git a/weave/scorers/llm_utils.py b/weave/scorers/llm_utils.py
index b95e88c6281c..735a5c9baa8f 100644
--- a/weave/scorers/llm_utils.py
+++ b/weave/scorers/llm_utils.py
@@ -166,6 +166,7 @@ def download_model(model_name_or_path: str, local_dir: str = "weave_models") ->
     "toxicity_scorer": "c-metrics/weave-scorers/toxicity_scorer:v0",
     "bias_scorer": "c-metrics/weave-scorers/bias_scorer:v0",
     "relevance_scorer": "c-metrics/context-relevance-scorer/relevance_scorer:v0",
+    "robustness_scorer": "c-metrics/weave-scorers/robustness_scorer_embedding_model:v0",
     "llamaguard": "c-metrics/weave-scorers/llamaguard:v0",
 }
 
diff --git a/weave/scorers/robustness_scorer.py b/weave/scorers/robustness_scorer.py
index ce4ac533a36a..b39c8dca42f8 100644
--- a/weave/scorers/robustness_scorer.py
+++ b/weave/scorers/robustness_scorer.py
@@ -1,15 +1,18 @@
+import os
 import math
 import random
 import string
+from importlib.util import find_spec
 from typing import Any, Optional, Union
 
 import numpy as np
 
 import weave
-from weave.scorers.base_scorer import Scorer
+from weave.scorers.llm_scorer import HuggingFaceScorer
+from weave.scorers.llm_utils import MODEL_PATHS, download_model
 
 
-class RobustnessScorer(Scorer):
+class RobustnessScorer(HuggingFaceScorer):
     """
     RobustnessScorer evaluates the robustness of a language model's outputs against input perturbations.
 
@@ -50,31 +53,45 @@ class RobustnessScorer(Scorer):
     use_exact_match: bool = True
     use_ground_truths: bool = False
     return_interpretation: bool = True
-    embedding_model_name: str = "all-MiniLM-L6-v2"
+    similarity_metric: str = "cosine"
     embedding_model: Optional[Any] = (
         None  # Delay type hinting to avoid dependency on SentenceTransformer
     )
     cohen_d_threshold: float = 1e-2
 
-    def model_post_init(self, __context: Any) -> None:
-        """
-        Post-initialization method to load the embedding model if required.
-
-        Args:
-            __context (Any): Contextual information (not used in this implementation).
-        """
-        # Load an embedding model for semantic similarity scoring
-        if not self.use_exact_match:
-            try:
-                from sentence_transformers import SentenceTransformer
-            except ImportError as e:
-                raise ImportError(
-                    "The `SentenceTransformer` and `torch` packages are required to use `RobustnessScorer` with semantic similarity scoring. (`use_exact_match=False`)"
-                    "Please install them by running `pip install sentence-transformers torch`."
-                ) from e
+    def load_model(self) -> None:
+        """
+        Load the sentence-transformers embedding model into `embedding_model`.
+
+        The model location is resolved from `model_name_or_path`: an existing
+        local directory is used as-is, any other non-empty value is passed to
+        `download_model`, and an empty string falls back to
+        `MODEL_PATHS["robustness_scorer"]`.
+
+        Raises:
+            ImportError: If the `sentence_transformers` package is not installed.
+            ValueError: If no local model path could be resolved.
+        """
+        if find_spec("sentence_transformers") is None:
+            # Fail fast: continuing without the package would only surface
+            # later as an unbound `SentenceTransformer` name.
+            raise ImportError(
+                "The `sentence_transformers` package is required to use the RobustnessScorer, please run `pip install sentence-transformers`"
+            )
+        from sentence_transformers import SentenceTransformer
+
+        if os.path.isdir(self.model_name_or_path):
+            self._local_model_path = self.model_name_or_path
+        elif self.model_name_or_path != "":
+            self._local_model_path = download_model(self.model_name_or_path)
+        else:
+            self._local_model_path = download_model(MODEL_PATHS["robustness_scorer"])
+        if not self._local_model_path:
+            # Explicit raise instead of `assert`, which is stripped under `python -O`.
+            raise ValueError("model_name_or_path local path or artifact path not found")
 
-            self.embedding_model = SentenceTransformer(self.embedding_model_name)
+        self.embedding_model = SentenceTransformer(self._local_model_path)
 
     @weave.op
     def score(