Skip to content

Commit

Permalink
Upload embedding model weights for RobustnessScorer
Browse files Browse the repository at this point in the history
  • Loading branch information
morganmcg1 committed Feb 5, 2025
1 parent a2a8cf4 commit c6e981b
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 20 deletions.
1 change: 1 addition & 0 deletions weave/scorers/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def download_model(model_name_or_path: str, local_dir: str = "weave_models") ->
"toxicity_scorer": "c-metrics/weave-scorers/toxicity_scorer:v0",
"bias_scorer": "c-metrics/weave-scorers/bias_scorer:v0",
"relevance_scorer": "c-metrics/context-relevance-scorer/relevance_scorer:v0",
"robustness_scorer": "c-metrics/weave-scorers/robustness_scorer_embedding_model:v0",
"llamaguard": "c-metrics/weave-scorers/llamaguard:v0",
}

Expand Down
44 changes: 24 additions & 20 deletions weave/scorers/robustness_scorer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import os
import math
import random
import string
from importlib.util import find_spec
from typing import Any, Optional, Union

import numpy as np

import weave
from weave.scorers.base_scorer import Scorer
from weave.scorers.llm_scorer import HuggingFaceScorer
from weave.scorers.llm_utils import MODEL_PATHS, download_model


class RobustnessScorer(Scorer):
class RobustnessScorer(HuggingFaceScorer):
"""
RobustnessScorer evaluates the robustness of a language model's outputs against input perturbations.
Expand Down Expand Up @@ -50,31 +53,32 @@ class RobustnessScorer(Scorer):
use_exact_match: bool = True
use_ground_truths: bool = False
return_interpretation: bool = True
embedding_model_name: str = "all-MiniLM-L6-v2"

similarity_metric: str = "cosine"
embedding_model: Optional[Any] = (
None # Delay type hinting to avoid dependency on SentenceTransformer
)
cohen_d_threshold: float = 1e-2

def model_post_init(self, __context: Any) -> None:
"""
Post-initialization method to load the embedding model if required.
Args:
__context (Any): Contextual information (not used in this implementation).
"""
# Load an embedding model for semantic similarity scoring
if not self.use_exact_match:
try:
from sentence_transformers import SentenceTransformer
except ImportError as e:
raise ImportError(
"The `SentenceTransformer` and `torch` packages are required to use `RobustnessScorer` with semantic similarity scoring. (`use_exact_match=False`)"
"Please install them by running `pip install sentence-transformers torch`."
) from e
def load_model(self) -> None:
try:
if find_spec("sentence_transformers") is None:
raise ImportError("sentence_transformers is required but not installed")
from sentence_transformers import SentenceTransformer
except ImportError:
print(
"The `sentence_transformers` package is required to use the RobustnessScorer, please run `pip install sentence-transformers`"
)
"""Initialize the model, tokenizer and device after pydantic initialization."""
if os.path.isdir(self.model_name_or_path):
self._local_model_path = self.model_name_or_path
elif self.model_name_or_path != "":
self._local_model_path = download_model(self.model_name_or_path)
else:
self._local_model_path = download_model(MODEL_PATHS["robustness_scorer"])
assert self._local_model_path, "model_name_or_path local path or artifact path not found"

self.embedding_model = SentenceTransformer(self.embedding_model_name)
self.embedding_model = SentenceTransformer(self._local_model_path)

@weave.op
def score(
Expand Down

0 comments on commit c6e981b

Please sign in to comment.