Skip to content

Commit

Permalink
chore(weave): Class-based scorer runnable ref fix (#2922)
Browse files Browse the repository at this point in the history
* init

* Done except tests

* fixed params

* added tests

* comments
  • Loading branch information
tssweeney authored Nov 7, 2024
1 parent 12222ac commit 6822779
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 43 deletions.
34 changes: 34 additions & 0 deletions tests/trace/test_evaluations.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,3 +973,37 @@ def score(text, model_output) -> bool:
id=list(score.calls())[0].id,
).uri()
)


@pytest.mark.asyncio
async def test_feedback_is_correctly_linked_with_scorer_subclass(client):
    """Check that feedback from a ``Scorer`` subclass links to the scorer object.

    Runs a one-row evaluation with a class-based scorer, queries the
    ``predict`` call with feedback included, and asserts that the single
    feedback entry is typed after the scorer class and that its
    ``runnable_ref`` is the ref of the scorer instance.
    """

    @weave.op
    def predict(text: str) -> str:
        return text

    class MyScorer(Scorer):
        @weave.op
        def score(self, text, output) -> bool:
            return text == output

    scorer = MyScorer()
    evaluation = weave.Evaluation(dataset=[{"text": "hello"}], scorers=[scorer])
    await evaluation.evaluate(predict)

    # Fetch only the calls made to `predict`, with their feedback attached.
    predict_filter = tsi.CallsFilter(op_names=[get_ref(predict).uri()])
    request = tsi.CallsQueryReq(
        project_id=client._project_id(),
        include_feedback=True,
        filter=predict_filter,
    )
    query_result = client.server.calls_query(request)

    # The dataset has a single row, so exactly one predict call is expected.
    assert len(query_result.calls) == 1
    predict_call = query_result.calls[0]
    assert predict_call.summary["weave"]["feedback"]

    feedback_items = predict_call.summary["weave"]["feedback"]
    assert len(feedback_items) == 1

    entry = feedback_items[0]
    assert entry["feedback_type"] == "wandb.runnable.MyScorer"
    assert entry["payload"] == {"output": True}
    # The runnable ref must point at the scorer object, not at its score op.
    assert entry["runnable_ref"] == get_ref(scorer).uri()
7 changes: 6 additions & 1 deletion weave/flow/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,12 @@ async def predict_and_score(
result, score_call = await async_call_op(score_fn, **score_args)
wc = get_weave_client()
if wc:
wc._send_score_call(model_call, score_call)
# Very important: if the score is generated from a Scorer subclass,
# then scorer_ref_uri will be the ref of that Scorer object. For a plain
# op scorer it will be None, and we will use the op_name from the
# score_call instead.
scorer_ref = get_ref(scorer_self) if scorer_self else None
scorer_ref_uri = scorer_ref.uri() if scorer_ref else None
wc._send_score_call(model_call, score_call, scorer_ref_uri)

else:
# I would not expect this path to be hit, but keeping it for
Expand Down
81 changes: 39 additions & 42 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from weave.trace.util import deprecated
from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj
from weave.trace_server.ids import generate_id
from weave.trace_server.interface.feedback_types import RUNNABLE_FEEDBACK_TYPE_PREFIX
from weave.trace_server.trace_server_interface import (
CallEndReq,
CallSchema,
Expand Down Expand Up @@ -332,12 +333,11 @@ def _apply_scorer(self, scorer_op: Op) -> None:
score_call_ref = get_ref(score_call)
if score_call_ref is None:
raise ValueError("Score call has no ref")
client._add_score(
call_ref_uri=self_ref.uri(),
score_name=score_name,
score_results=score_results,
scorer_call_ref_uri=score_call_ref.uri(),
scorer_op_ref_uri=scorer_op_ref.uri(),
client._add_runnable_feedback(
weave_ref_uri=self_ref.uri(),
output=score_results,
call_ref_uri=score_call_ref.uri(),
runnable_ref_uri=scorer_op_ref.uri(),
)


Expand Down Expand Up @@ -1105,7 +1105,12 @@ def query_costs(
return res.results

@trace_sentry.global_trace_sentry.watch()
def _send_score_call(self, predict_call: Call, score_call: Call) -> Future[str]:
def _send_score_call(
self,
predict_call: Call,
score_call: Call,
scorer_object_ref_uri: Optional[str] = None,
) -> Future[str]:
"""(Private) Adds a score to a call. This is particularly useful
for adding evaluation metrics to a call.
"""
Expand All @@ -1114,37 +1119,35 @@ def send_score_call() -> str:
call_ref = get_ref(predict_call)
if call_ref is None:
raise ValueError("Predict call must have a ref")
call_ref_uri = call_ref.uri()
weave_ref_uri = call_ref.uri()
scorer_call_ref = get_ref(score_call)
if scorer_call_ref is None:
raise ValueError("Score call must have a ref")
scorer_call_ref_uri = scorer_call_ref.uri()
scorer_op_ref_uri = score_call.op_name
scorer_op_ref = parse_uri(scorer_op_ref_uri)
if not isinstance(scorer_op_ref, OpRef):
raise TypeError(f"Invalid scorer op ref: {scorer_op_ref_uri}")
score_name = scorer_op_ref.name

# If scorer_object_ref_uri is provided, it is used as the runnable_ref_uri;
# this should happen when a Scorer subclass is the source of the score call.
# Otherwise, we fall back to the op_name from the score_call.
runnable_ref_uri = scorer_object_ref_uri or score_call.op_name
score_results = score_call.output

return self._add_score(
call_ref_uri=call_ref_uri,
score_name=score_name,
score_results=score_results,
scorer_call_ref_uri=scorer_call_ref_uri,
scorer_op_ref_uri=scorer_op_ref_uri,
return self._add_runnable_feedback(
weave_ref_uri=weave_ref_uri,
output=score_results,
call_ref_uri=scorer_call_ref_uri,
runnable_ref_uri=runnable_ref_uri,
)

return self.future_executor.defer(send_score_call)

@trace_sentry.global_trace_sentry.watch()
def _add_score(
def _add_runnable_feedback(
self,
*,
weave_ref_uri: str,
output: Any,
call_ref_uri: str,
score_name: str,
score_results: Any,
scorer_call_ref_uri: str,
scorer_op_ref_uri: str,
runnable_ref_uri: str,
# , supervision: dict
) -> str:
"""(Private) Low-level, non object-oriented method for adding a score to a call.
Expand All @@ -1153,25 +1156,19 @@ def _add_score(
- Should we somehow include supervision (ie. the ground truth) in the payload?
"""
# Parse the refs (acts as validation)
call_ref = parse_uri(call_ref_uri)
call_ref = parse_uri(weave_ref_uri)
if not isinstance(call_ref, CallRef):
raise TypeError(f"Invalid call ref: {call_ref_uri}")
scorer_call_ref = parse_uri(scorer_call_ref_uri)
raise TypeError(f"Invalid call ref: {weave_ref_uri}")
scorer_call_ref = parse_uri(call_ref_uri)
if not isinstance(scorer_call_ref, CallRef):
raise TypeError(f"Invalid scorer call ref: {scorer_call_ref_uri}")
scorer_op_ref = parse_uri(scorer_op_ref_uri)
if not isinstance(scorer_op_ref, OpRef):
raise TypeError(f"Invalid scorer op ref: {scorer_op_ref_uri}")

# Validate score_name (we might want to relax this in the future)
if score_name != scorer_op_ref.name:
raise ValueError(
f"Score name {score_name} does not match scorer op name {scorer_op_ref.name}"
)
raise TypeError(f"Invalid scorer call ref: {call_ref_uri}")
runnable_ref = parse_uri(runnable_ref_uri)
if not isinstance(runnable_ref, (OpRef, ObjectRef)):
raise TypeError(f"Invalid scorer op ref: {runnable_ref_uri}")

# Prepare the result payload - we purposely do not map to refs here
# because we prefer to have the raw data.
results_json = to_json(score_results, self._project_id(), self)
results_json = to_json(output, self._project_id(), self)

# # Prepare the supervision payload

Expand All @@ -1181,11 +1178,11 @@ def _add_score(

freq = FeedbackCreateReq(
project_id=self._project_id(),
weave_ref=call_ref_uri,
feedback_type="wandb.runnable." + score_name,
weave_ref=weave_ref_uri,
feedback_type=RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + runnable_ref.name,
payload=payload,
runnable_ref=scorer_op_ref_uri,
call_ref=scorer_call_ref_uri,
runnable_ref=runnable_ref_uri,
call_ref=call_ref_uri,
)
response = self.server.feedback_create(freq)

Expand Down

0 comments on commit 6822779

Please sign in to comment.