Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(weave): Add option to name weave evals (and give a memorable name if not specified) #3135

Merged
merged 12 commits into from
Dec 9, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 20 additions & 2 deletions tests/trace/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from PIL import Image

import weave
from tests.trace.util import AnyIntMatcher
from tests.trace.util import AnyIntMatcher, AnyStrMatcher
from weave import Evaluation, Model
from weave.scorers import Scorer
from weave.trace.refs import CallRef
Expand Down Expand Up @@ -504,8 +504,8 @@ async def test_evaluation_data_topology(client):
}
},
"weave": {
"display_name": AnyStrMatcher(),
"latency_ms": AnyIntMatcher(),
"trace_name": "Evaluation.evaluate",
"status": "success",
},
}
Expand Down Expand Up @@ -1029,3 +1029,21 @@ def my_second_scorer(text, output, model_output):

with pytest.raises(ValueError, match="Both 'output' and 'model_output'"):
evaluation = weave.Evaluation(dataset=ds, scorers=[my_second_scorer])


@pytest.mark.asyncio
async def test_evaluation_with_custom_name(client):
dataset = weave.Dataset(rows=[{"input": "hi", "output": "hello"}])
evaluation = weave.Evaluation(dataset=dataset, evaluation_name="wow-custom!")

@weave.op()
def model(input: str) -> str:
return "hmmm"

await evaluation.evaluate(model)

calls = list(client.get_calls(filter=tsi.CallsFilter(trace_roots_only=True)))
assert len(calls) == 1

call = calls[0]
assert call.display_name == "wow-custom!"
7 changes: 7 additions & 0 deletions tests/trace/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@ def client_is_sqlite(client):
return isinstance(client.server._internal_trace_server, SqliteTraceServer)


class AnyStrMatcher:
"""Matches any string."""

def __eq__(self, other):
return isinstance(other, str)


class AnyIntMatcher:
"""Matches any integer."""

Expand Down
26 changes: 23 additions & 3 deletions weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import time
import traceback
from collections.abc import Coroutine
from datetime import datetime
from typing import Any, Callable, Literal, Optional, Union, cast

from pydantic import PrivateAttr
from pydantic import PrivateAttr, model_validator
from rich import print
from rich.console import Console

Expand All @@ -16,6 +17,7 @@
from weave.flow.dataset import Dataset
from weave.flow.model import Model, get_infer_method
from weave.flow.obj import Object
from weave.flow.util import make_memorable_name
from weave.scorers import (
Scorer,
_has_oldstyle_scorers,
Expand All @@ -28,7 +30,7 @@
from weave.trace.env import get_weave_parallelism
from weave.trace.errors import OpCallError
from weave.trace.isinstance import weave_isinstance
from weave.trace.op import Op, as_op, is_op
from weave.trace.op import CallDisplayNameFunc, Op, as_op, is_op
from weave.trace.vals import WeaveObject
from weave.trace.weave_client import Call, get_ref

Expand All @@ -41,6 +43,12 @@
)


def default_evaluation_display_name(call: Call) -> str:
date = datetime.now().strftime("%Y-%m-%d")
unique_name = make_memorable_name()
return f"eval-{date}-{unique_name}"


def async_call(func: Union[Callable, Op], *args: Any, **kwargs: Any) -> Coroutine:
is_async = False
if is_op(func):
Expand Down Expand Up @@ -116,9 +124,21 @@ def function_to_evaluate(question: str):
preprocess_model_input: Optional[Callable] = None
trials: int = 1

# Custom evaluation name for display in the UI. This is the same API as passing a
# custom `call_display_name` to `weave.op` (see that for more details).
evaluation_name: Optional[Union[str, CallDisplayNameFunc]] = None

# internal attr to track whether to use the new `output` or old `model_output` key for outputs
_output_key: Literal["output", "model_output"] = PrivateAttr("output")

@model_validator(mode="after")
def _update_display_name(self) -> "Evaluation":
if self.evaluation_name:
# Treat user-specified `evaluation_name` as the name for `Evaluation.evaluate`
eval_op = as_op(self.evaluate)
eval_op.call_display_name = self.evaluation_name
return self

def model_post_init(self, __context: Any) -> None:
scorers: list[Union[Callable, Scorer, Op]] = []
for scorer in self.scorers or []:
Expand Down Expand Up @@ -486,7 +506,7 @@ async def eval_example(example: dict) -> dict:
eval_rows.append(eval_row)
return EvaluationResults(rows=weave.Table(eval_rows))

@weave.op()
@weave.op(call_display_name=default_evaluation_display_name)
async def evaluate(self, model: Union[Callable, Model]) -> dict:
# The need for this pattern is quite unfortunate and highlights a gap in our
# data model. As a user, I just want to pass a list of data `eval_rows` to
Expand Down
87 changes: 87 additions & 0 deletions weave/flow/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import multiprocessing
import random
from collections.abc import AsyncIterator, Awaitable, Iterable
from typing import Any, Callable, TypeVar

Expand Down Expand Up @@ -81,3 +82,89 @@ def warn_once(logger: logging.Logger, message: str) -> None:
if message not in _shown_warnings:
logger.warning(message)
_shown_warnings.add(message)


def make_memorable_name() -> str:
adjectives = [
"jubilant",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't really matter but I'd alphabetize

"eager",
"calm",
"bright",
"clever",
"dazzling",
"elegant",
"fierce",
"gentle",
"happy",
"innocent",
"kind",
"lively",
"merry",
"nice",
"proud",
"quiet",
"rich",
"sweet",
"tender",
"unique",
"wise",
"zealous",
"brave",
"charming",
"daring",
"eloquent",
"friendly",
"graceful",
"honest",
"imaginative",
"joyful",
"keen",
"loyal",
"noble",
"optimistic",
]

nouns = [
"sun",
"moon",
"star",
"cloud",
"rain",
"wind",
"tree",
"flower",
"river",
"mountain",
"ocean",
"forest",
"meadow",
"bird",
"wolf",
"bear",
"tiger",
"lion",
"eagle",
"fish",
"whale",
"dolphin",
"rose",
"daisy",
"oak",
"pine",
"maple",
"cedar",
"valley",
"hill",
"lake",
"stream",
"breeze",
"dawn",
"dusk",
"horizon",
"island",
"plateau",
]

adj = random.choice(adjectives)
noun = random.choice(nouns)
return f"{adj}-{noun}"
24 changes: 13 additions & 11 deletions weave/trace/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,15 @@ def _apply_fn_defaults_to_inputs(
) -> dict[str, Any]:
inputs = {**inputs}
sig = inspect.signature(fn)
for param_name, param in sig.parameters.items():
if param_name not in inputs:
if param.default != inspect.Parameter.empty and not _value_is_sentinel(
param
):
inputs[param_name] = param.default
if param.kind == inspect.Parameter.VAR_POSITIONAL:
inputs[param_name] = ()
elif param.kind == inspect.Parameter.VAR_KEYWORD:
inputs[param_name] = {}
for name, param in sig.parameters.items():
if name in inputs:
continue
if param.default != inspect.Parameter.empty and not _value_is_sentinel(param):
inputs[name] = param.default
if param.kind == inspect.Parameter.VAR_POSITIONAL:
inputs[name] = ()
if param.kind == inspect.Parameter.VAR_KEYWORD:
inputs[name] = {}
return inputs


Expand Down Expand Up @@ -216,6 +215,7 @@ def _default_on_input_handler(func: Op, args: tuple, kwargs: dict) -> ProcessedI
inputs = sig.bind(*args, **kwargs).arguments
except TypeError as e:
raise OpCallError(f"Error calling {func.name}: {e}")

inputs_with_defaults = _apply_fn_defaults_to_inputs(func, inputs)
return ProcessedInputs(
original_args=args,
Expand Down Expand Up @@ -736,7 +736,9 @@ def as_op(fn: Callable) -> Op:
if not is_op(fn):
raise ValueError("fn must be a weave.op() decorated function")

return cast(Op, fn)
# The unbinding is necessary for methods because `MethodType` is applied after the
# func is decorated into an Op.
return maybe_unbind_method(cast(Op, fn))


__docspec__ = [call, calls]
Loading