diff --git a/tests/trace/test_evaluations.py b/tests/trace/test_evaluations.py index d20a82db8526..89fe25b7a6c5 100644 --- a/tests/trace/test_evaluations.py +++ b/tests/trace/test_evaluations.py @@ -973,3 +973,37 @@ def score(text, model_output) -> bool: id=list(score.calls())[0].id, ).uri() ) + + +@pytest.mark.asyncio +async def test_feedback_is_correctly_linked_with_scorer_subclass(client): + @weave.op + def predict(text: str) -> str: + return text + + class MyScorer(Scorer): + @weave.op + def score(self, text, output) -> bool: + return text == output + + scorer = MyScorer() + eval = weave.Evaluation( + dataset=[{"text": "hello"}], + scorers=[scorer], + ) + res = await eval.evaluate(predict) + calls = client.server.calls_query( + tsi.CallsQueryReq( + project_id=client._project_id(), + include_feedback=True, + filter=tsi.CallsFilter(op_names=[get_ref(predict).uri()]), + ) + ) + assert len(calls.calls) == 1 + assert calls.calls[0].summary["weave"]["feedback"] + feedbacks = calls.calls[0].summary["weave"]["feedback"] + assert len(feedbacks) == 1 + feedback = feedbacks[0] + assert feedback["feedback_type"] == "wandb.runnable.MyScorer" + assert feedback["payload"] == {"output": True} + assert feedback["runnable_ref"] == get_ref(scorer).uri() diff --git a/weave/flow/eval.py b/weave/flow/eval.py index 25a82b1c8f3d..7ae34c619784 100644 --- a/weave/flow/eval.py +++ b/weave/flow/eval.py @@ -372,7 +372,12 @@ async def predict_and_score( result, score_call = await async_call_op(score_fn, **score_args) wc = get_weave_client() if wc: - wc._send_score_call(model_call, score_call) + # Very important: if the score is generated from a Scorer subclass, + # then scorer_ref_uri will be None, and we will use the op_name from + # the score_call instead. + scorer_ref = get_ref(scorer_self) if scorer_self else None + scorer_ref_uri = scorer_ref.uri() if scorer_ref else None + wc._send_score_call(model_call, score_call, scorer_ref_uri) else: # I would not expect this path to be hit, but keeping it for diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index c0ed78c845c9..d0c5e6ae5341 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -42,6 +42,7 @@ from weave.trace.util import deprecated from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj from weave.trace_server.ids import generate_id +from weave.trace_server.interface.feedback_types import RUNNABLE_FEEDBACK_TYPE_PREFIX from weave.trace_server.trace_server_interface import ( CallEndReq, CallSchema, @@ -332,12 +333,11 @@ def _apply_scorer(self, scorer_op: Op) -> None: score_call_ref = get_ref(score_call) if score_call_ref is None: raise ValueError("Score call has no ref") - client._add_score( - call_ref_uri=self_ref.uri(), - score_name=score_name, - score_results=score_results, - scorer_call_ref_uri=score_call_ref.uri(), - scorer_op_ref_uri=scorer_op_ref.uri(), + client._add_runnable_feedback( + weave_ref_uri=self_ref.uri(), + output=score_results, + call_ref_uri=score_call_ref.uri(), + runnable_ref_uri=scorer_op_ref.uri(), ) @@ -1105,7 +1105,12 @@ def query_costs( return res.results @trace_sentry.global_trace_sentry.watch() - def _send_score_call(self, predict_call: Call, score_call: Call) -> Future[str]: + def _send_score_call( + self, + predict_call: Call, + score_call: Call, + scorer_object_ref_uri: Optional[str] = None, + ) -> Future[str]: """(Private) Adds a score to a call. This is particularly useful for adding evaluation metrics to a call. """ @@ -1114,37 +1119,35 @@ def send_score_call() -> str: call_ref = get_ref(predict_call) if call_ref is None: raise ValueError("Predict call must have a ref") - call_ref_uri = call_ref.uri() + weave_ref_uri = call_ref.uri() scorer_call_ref = get_ref(score_call) if scorer_call_ref is None: raise ValueError("Score call must have a ref") scorer_call_ref_uri = scorer_call_ref.uri() - scorer_op_ref_uri = score_call.op_name - scorer_op_ref = parse_uri(scorer_op_ref_uri) - if not isinstance(scorer_op_ref, OpRef): - raise TypeError(f"Invalid scorer op ref: {scorer_op_ref_uri}") - score_name = scorer_op_ref.name + + # If scorer_object_ref_uri is provided, it is used as the runnable_ref_uri + # Otherwise, we use the op_name from the score_call. This should happen + # when there is a Scorer subclass that is the source of the score call. + runnable_ref_uri = scorer_object_ref_uri or score_call.op_name score_results = score_call.output - return self._add_score( - call_ref_uri=call_ref_uri, - score_name=score_name, - score_results=score_results, - scorer_call_ref_uri=scorer_call_ref_uri, - scorer_op_ref_uri=scorer_op_ref_uri, + return self._add_runnable_feedback( + weave_ref_uri=weave_ref_uri, + output=score_results, + call_ref_uri=scorer_call_ref_uri, + runnable_ref_uri=runnable_ref_uri, ) return self.future_executor.defer(send_score_call) @trace_sentry.global_trace_sentry.watch() - def _add_score( + def _add_runnable_feedback( self, *, + weave_ref_uri: str, + output: Any, call_ref_uri: str, - score_name: str, - score_results: Any, - scorer_call_ref_uri: str, - scorer_op_ref_uri: str, + runnable_ref_uri: str, # , supervision: dict ) -> str: """(Private) Low-level, non object-oriented method for adding a score to a call. @@ -1153,25 +1156,19 @@ def _add_score( - Should we somehow include supervision (ie. the ground truth) in the payload? """ # Parse the refs (acts as validation) - call_ref = parse_uri(call_ref_uri) + call_ref = parse_uri(weave_ref_uri) if not isinstance(call_ref, CallRef): - raise TypeError(f"Invalid call ref: {call_ref_uri}") - scorer_call_ref = parse_uri(scorer_call_ref_uri) + raise TypeError(f"Invalid call ref: {weave_ref_uri}") + scorer_call_ref = parse_uri(call_ref_uri) if not isinstance(scorer_call_ref, CallRef): - raise TypeError(f"Invalid scorer call ref: {scorer_call_ref_uri}") - scorer_op_ref = parse_uri(scorer_op_ref_uri) - if not isinstance(scorer_op_ref, OpRef): - raise TypeError(f"Invalid scorer op ref: {scorer_op_ref_uri}") - - # Validate score_name (we might want to relax this in the future) - if score_name != scorer_op_ref.name: - raise ValueError( - f"Score name {score_name} does not match scorer op name {scorer_op_ref.name}" - ) + raise TypeError(f"Invalid scorer call ref: {call_ref_uri}") + runnable_ref = parse_uri(runnable_ref_uri) + if not isinstance(runnable_ref, (OpRef, ObjectRef)): + raise TypeError(f"Invalid scorer op ref: {runnable_ref_uri}") # Prepare the result payload - we purposely do not map to refs here # because we prefer to have the raw data. - results_json = to_json(score_results, self._project_id(), self) + results_json = to_json(output, self._project_id(), self) # # Prepare the supervision payload @@ -1181,11 +1178,11 @@ def _add_score( freq = FeedbackCreateReq( project_id=self._project_id(), - weave_ref=call_ref_uri, - feedback_type="wandb.runnable." + score_name, + weave_ref=weave_ref_uri, + feedback_type=RUNNABLE_FEEDBACK_TYPE_PREFIX + "." + runnable_ref.name, payload=payload, - runnable_ref=scorer_op_ref_uri, - call_ref=scorer_call_ref_uri, + runnable_ref=runnable_ref_uri, + call_ref=call_ref_uri, ) response = self.server.feedback_create(freq)