Skip to content

Commit

Permalink
add fluency
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Jan 27, 2025
1 parent 21294c3 commit a1b01f8
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 8 deletions.
37 changes: 37 additions & 0 deletions weave/scorers/fluency_scorer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import weave
from weave.scorers.llm_scorer import HuggingFacePipelineScorer
from weave.llm_utils import set_device


class FluencyScorer(HuggingFacePipelineScorer):
"""
W&B Fluency Scorer
This scorer uses an in-house model to score fluency based on ModernBert.
Example:
>>> from weave.scorers.fluency_scorer import FluencyScorer
>>> scorer = FluencyScorer()
>>> result = scorer.score("This text is fluent.")
>>> print(result)
{
'flagged': True,
}
"""
model_name_or_path = "tcapelle/fluency-scorer" # TODO: replace with and artifact
device = "auto"


def load_pipeline(self) -> None:
from transformers import pipeline
self.device = set_device(self.device)
self._pipeline = pipeline(
"text-classification",
model=self.model_name_or_path,
device=self.device
)

@weave.op
async def score(self, output: str):
pipeline_output = self._pipeline(output)
return {"flagged": pipeline_output[0]["label"] == "fluent"}
11 changes: 3 additions & 8 deletions weave/scorers/llm_scorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,6 @@ class HuggingFacePipelineScorer(Scorer):
task (str): The pipeline task type (e.g., `"text-classification"`).
model_name_or_path (str): The name or path of the model to use.
device (str): The device to use for inference. Defaults to `"cpu"`.
pipeline_kwargs (dict[str, Any]): Additional keyword arguments for the pipeline. Defaults to `{}`.
Returns:
list[dict[str, Any]]: The pipeline's output after processing the input text.
Example:
>>> from weave.scorers.moderation_scorer import PipelineScorer
Expand All @@ -109,8 +105,7 @@ class HuggingFacePipelineScorer(Scorer):
)
model_name_or_path: str = Field(default="", description="The path to the model")
device: str = Field(default="auto", description="The device to use for the model")
pipeline_kwargs: dict[str, Any] = Field(default_factory=dict)
pipeline: Optional[Any] = None
_pipeline: Optional[Any] = None

def model_post_init(self, __context: Any) -> None:
self.device = set_device(self.device)
Expand All @@ -123,8 +118,8 @@ def model_post_init(self, __context: Any) -> None:
print(
"The `transformers` package is required to use PipelineScorer, please run `pip install transformers`"
)
if self.pipeline is None:
self.set_pipeline()
if self._pipeline is None:
self.load_pipeline()

def load_pipeline(self) -> None:
raise NotImplementedError(
Expand Down

0 comments on commit a1b01f8

Please sign in to comment.