From cbb592449bcdcb8d308dacd9ee4259f0e80ecc53 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 14 Jan 2025 14:38:47 -0800 Subject: [PATCH 1/3] Change Weave to W&B Weave on docs site --- docs/docusaurus.config.ts | 4 ++-- weave/trace/weave_client.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/docs/docusaurus.config.ts b/docs/docusaurus.config.ts index 9616064c73e9..4814856ce5de 100644 --- a/docs/docusaurus.config.ts +++ b/docs/docusaurus.config.ts @@ -125,7 +125,7 @@ const config: Config = { // Replace with your project's social card image: "img/logo-large-padded.png", navbar: { - title: "Weave", + title: "W&B Weave", logo: { alt: "My Site Logo", src: "img/logo.svg", @@ -227,7 +227,7 @@ const config: Config = { ], }, ], - copyright: `Weave by W&B`, + copyright: `Made with ❤️ by Weights & Biases`, }, prism: { // theme: prismThemes.nightOwl, diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 1d5d54b9b23c..f665beb7b871 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -792,6 +792,7 @@ def create_call( attributes._set_weave_item("os_version", platform.version()) attributes._set_weave_item("os_release", platform.release()) attributes._set_weave_item("sys_version", sys.version) + attributes._set_weave_item("tracing_sample_rate", op.tracing_sample_rate) op_name_future = self.future_executor.defer(lambda: op_def_ref.uri()) From 803fa13b9a8bec56f694d891b911466c108e2ad4 Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 14 Jan 2025 14:42:42 -0800 Subject: [PATCH 2/3] revert weave_client changes --- weave/trace/weave_client.py | 179 +++++++++--------------------------- 1 file changed, 43 insertions(+), 136 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index e9ba03aca801..1d5d54b9b23c 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -2,7 +2,7 @@ import dataclasses import datetime -import json +import inspect import logging import platform import re @@ -10,16 +10,7 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Generic, - Protocol, - TypeVar, - cast, - overload, -) +from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload import pydantic from requests import HTTPError @@ -31,7 +22,6 @@ from weave.trace.context import weave_client_context as weave_client_context from weave.trace.exception import exception_to_json_str from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery -from weave.trace.isinstance import weave_isinstance from weave.trace.object_record import ( ObjectRecord, dataclass_object_record, @@ -48,7 +38,7 @@ parse_op_uri, parse_uri, ) -from weave.trace.sanitize import REDACTED_VALUE, should_redact +from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE from weave.trace.serialize import from_json, isinstance_namedtuple, to_json from weave.trace.serializer import get_serializer_for_obj from weave.trace.settings import client_parallelism @@ -64,7 +54,6 @@ CallsDeleteReq, CallsFilter, CallsQueryReq, - CallsQueryStatsReq, CallStartReq, CallUpdateReq, CostCreateInput, @@ -79,7 +68,6 @@ FileCreateRes, ObjCreateReq, ObjCreateRes, - ObjDeleteReq, ObjectVersionFilter, ObjQueryReq, ObjReadReq, @@ -95,10 +83,6 @@ ) from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer -if TYPE_CHECKING: - from weave.scorers.base_scorer import ApplyScorerResult, Scorer - - # Controls if objects can have refs to projects not the WeaveClient project. # If False, object refs with with mismatching projects will be recreated. # If True, use existing ref to object in other project. @@ -115,7 +99,6 @@ def __call__(self, offset: int, limit: int) -> list[T]: ... TransformFunc = Callable[[T], R] -SizeFunc = Callable[[], int] class PaginatedIterator(Generic[T, R]): @@ -127,12 +110,10 @@ def __init__( fetch_func: FetchFunc[T], page_size: int = 1000, transform_func: TransformFunc[T, R] | None = None, - size_func: SizeFunc | None = None, ) -> None: self.fetch_func = fetch_func self.page_size = page_size self.transform_func = transform_func - self.size_func = size_func if page_size <= 0: raise ValueError("page_size must be greater than 0") @@ -201,13 +182,6 @@ def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... def __iter__(self) -> Iterator[T] | Iterator[R]: return self._get_slice(slice(0, None, 1)) - def __len__(self) -> int: - """This method is included for convenience. It includes a network call, which - is typically slower than most other len() operations!""" - if not self.size_func: - raise TypeError("This iterator does not support len()") - return self.size_func() - # TODO: should be Call, not WeaveObject CallsIter = PaginatedIterator[CallSchema, WeaveObject] @@ -236,17 +210,7 @@ def transform_func(call: CallSchema) -> WeaveObject: entity, project = project_id.split("/") return make_client_call(entity, project, call, server) - def size_func() -> int: - response = server.calls_query_stats( - CallsQueryStatsReq(project_id=project_id, filter=filter) - ) - return response.count - - return PaginatedIterator( - fetch_func, - transform_func=transform_func, - size_func=size_func, - ) + return PaginatedIterator(fetch_func, transform_func=transform_func) class OpNameError(ValueError): @@ -481,60 +445,43 @@ def set_display_name(self, name: str | None) -> None: def remove_display_name(self) -> None: self.set_display_name(None) - async def apply_scorer( - self, scorer: Op | Scorer, additional_scorer_kwargs: dict | None = None - ) -> ApplyScorerResult: + def _apply_scorer(self, scorer_op: Op) -> None: """ - `apply_scorer` is a method that applies a Scorer to a Call. This is useful - for guarding application logic with a scorer and/or monitoring the quality - of critical ops. Scorers are automatically logged to Weave as Feedback and - can be used in queries & analysis. - - Args: - scorer: The Scorer to apply. - additional_scorer_kwargs: Additional kwargs to pass to the scorer. This is - useful for passing in additional context that is not part of the call - inputs.useful for passing in additional context that is not part of the call - inputs. + This is a private method that applies a scorer to a call and records the feedback. + In the near future, this will be made public, but for now it is only used internally + for testing. - Returns: - The result of the scorer application in the form of an `ApplyScorerResult`. - - ```python - class ApplyScorerSuccess: - result: Any - score_call: Call - ``` - - Example usage: + Before making this public, we should refactor such that the `predict_and_score` method + inside `eval.py` uses this method inside the scorer block. - ```python - my_scorer = ... # construct a scorer - prediction, prediction_call = my_op.call(input_data) - result, score_call = prediction.apply_scorer(my_scorer) - ``` + Current limitations: + - only works for ops (not Scorer class) + - no async support + - no context yet (ie. ground truth) """ - from weave.scorers.base_scorer import Scorer, apply_scorer_async - - model_inputs = {k: v for k, v in self.inputs.items() if k != "self"} - example = {**model_inputs, **(additional_scorer_kwargs or {})} - output = self.output - if isinstance(output, ObjectRef): - output = output.get() - apply_scorer_result = await apply_scorer_async(scorer, example, output) - score_call = apply_scorer_result.score_call - - wc = weave_client_context.get_weave_client() - if wc: - scorer_ref_uri = None - if weave_isinstance(scorer, Scorer): - # Very important: if the score is generated from a Scorer subclass, - # then scorer_ref_uri will be None, and we will use the op_name from - # the score_call instead. - scorer_ref = get_ref(scorer) - scorer_ref_uri = scorer_ref.uri() if scorer_ref else None - wc._send_score_call(self, score_call, scorer_ref_uri) - return apply_scorer_result + client = weave_client_context.require_weave_client() + scorer_signature = inspect.signature(scorer_op) + scorer_arg_names = list(scorer_signature.parameters.keys()) + score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} + if "output" in scorer_arg_names: + score_args["output"] = self.output + _, score_call = scorer_op.call(**score_args) + scorer_op_ref = get_ref(scorer_op) + if scorer_op_ref is None: + raise ValueError("Scorer op has no ref") + self_ref = get_ref(self) + if self_ref is None: + raise ValueError("Call has no ref") + score_results = score_call.output + score_call_ref = get_ref(score_call) + if score_call_ref is None: + raise ValueError("Score call has no ref") + client._add_runnable_feedback( + weave_ref_uri=self_ref.uri(), + output=score_results, + call_ref_uri=score_call_ref.uri(), + runnable_ref_uri=scorer_op_ref.uri(), + ) def make_client_call( @@ -697,15 +644,8 @@ def get(self, ref: ObjectRef) -> Any: ) ) except HTTPError as e: - if e.response is not None: - if e.response.content: - try: - reason = json.loads(e.response.content).get("reason") - raise ValueError(reason) - except json.JSONDecodeError: - raise ValueError(e.response.content) - if e.response.status_code == 404: - raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") + if e.response is not None and e.response.status_code == 404: + raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") raise # At this point, `ref.digest` is one of three things: @@ -815,8 +755,6 @@ def create_call( Returns: The created Call object. """ - from weave.trace.api import _global_postprocess_inputs - if isinstance(op, str): if op not in self._anonymous_ops: self._anonymous_ops[op] = _build_anonymous_op(op) @@ -831,9 +769,6 @@ def create_call( else: inputs_postprocessed = inputs_redacted - if _global_postprocess_inputs: - inputs_postprocessed = _global_postprocess_inputs(inputs_postprocessed) - self._save_nested_objects(inputs_postprocessed) inputs_with_refs = map_to_refs(inputs_postprocessed) @@ -857,7 +792,6 @@ def create_call( attributes._set_weave_item("os_version", platform.version()) attributes._set_weave_item("os_release", platform.release()) attributes._set_weave_item("sys_version", sys.version) - attributes._set_weave_item("tracing_sample_rate", op.tracing_sample_rate) op_name_future = self.future_executor.defer(lambda: op_def_ref.uri()) @@ -868,7 +802,6 @@ def create_call( trace_id=trace_id, parent_id=parent_id, id=call_id, - # It feels like this should be inputs_postprocessed, not the refs. inputs=inputs_with_refs, attributes=attributes, ) @@ -922,8 +855,6 @@ def finish_call( *, op: Op | None = None, ) -> None: - from weave.trace.api import _global_postprocess_output - ended_at = datetime.datetime.now(tz=datetime.timezone.utc) call.ended_at = ended_at original_output = output @@ -932,13 +863,9 @@ def finish_call( postprocessed_output = op.postprocess_output(original_output) else: postprocessed_output = original_output - - if _global_postprocess_output: - postprocessed_output = _global_postprocess_output(postprocessed_output) - self._save_nested_objects(postprocessed_output) - output_as_refs = map_to_refs(postprocessed_output) - call.output = postprocessed_output + + call.output = map_to_refs(postprocessed_output) # Summary handling summary = {} @@ -993,7 +920,7 @@ def finish_call( op._on_finish_handler(call, original_output, exception) def send_end_call() -> None: - output_json = to_json(output_as_refs, project_id, self, use_dictify=False) + output_json = to_json(call.output, project_id, self, use_dictify=False) self.server.call_end( CallEndReq( end=EndedCallSchemaForInsert( @@ -1025,26 +952,6 @@ def delete_call(self, call: Call) -> None: ) ) - @trace_sentry.global_trace_sentry.watch() - def delete_object_version(self, object: ObjectRef) -> None: - self.server.obj_delete( - ObjDeleteReq( - project_id=self._project_id(), - object_id=object.name, - digests=[object.digest], - ) - ) - - @trace_sentry.global_trace_sentry.watch() - def delete_op_version(self, op: OpRef) -> None: - self.server.obj_delete( - ObjDeleteReq( - project_id=self._project_id(), - object_id=op.name, - digests=[op.digest], - ) - ) - def get_feedback( self, query: Query | str | None = None, @@ -1741,7 +1648,7 @@ def redact_sensitive_keys(obj: Any) -> Any: if isinstance(obj, dict): dict_res = {} for k, v in obj.items(): - if isinstance(k, str) and should_redact(k): + if k in REDACT_KEYS: dict_res[k] = REDACTED_VALUE else: dict_res[k] = redact_sensitive_keys(v) From 649bfd339930ba215940851e603aeaebbca43d3e Mon Sep 17 00:00:00 2001 From: Adrian Swanberg Date: Tue, 14 Jan 2025 14:46:31 -0800 Subject: [PATCH 3/3] actually revert weave_client changes --- weave/trace/weave_client.py | 178 +++++++++++++++++++++++++++--------- 1 file changed, 135 insertions(+), 43 deletions(-) diff --git a/weave/trace/weave_client.py b/weave/trace/weave_client.py index 1d5d54b9b23c..6e5d36695936 100644 --- a/weave/trace/weave_client.py +++ b/weave/trace/weave_client.py @@ -2,7 +2,7 @@ import dataclasses import datetime -import inspect +import json import logging import platform import re @@ -10,7 +10,16 @@ from collections.abc import Iterator, Sequence from concurrent.futures import Future from functools import lru_cache -from typing import Any, Callable, Generic, Protocol, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Protocol, + TypeVar, + cast, + overload, +) import pydantic from requests import HTTPError @@ -22,6 +31,7 @@ from weave.trace.context import weave_client_context as weave_client_context from weave.trace.exception import exception_to_json_str from weave.trace.feedback import FeedbackQuery, RefFeedbackQuery +from weave.trace.isinstance import weave_isinstance from weave.trace.object_record import ( ObjectRecord, dataclass_object_record, @@ -38,7 +48,7 @@ parse_op_uri, parse_uri, ) -from weave.trace.sanitize import REDACT_KEYS, REDACTED_VALUE +from weave.trace.sanitize import REDACTED_VALUE, should_redact from weave.trace.serialize import from_json, isinstance_namedtuple, to_json from weave.trace.serializer import get_serializer_for_obj from weave.trace.settings import client_parallelism @@ -54,6 +64,7 @@ CallsDeleteReq, CallsFilter, CallsQueryReq, + CallsQueryStatsReq, CallStartReq, CallUpdateReq, CostCreateInput, @@ -68,6 +79,7 @@ FileCreateRes, ObjCreateReq, ObjCreateRes, + ObjDeleteReq, ObjectVersionFilter, ObjQueryReq, ObjReadReq, @@ -83,6 +95,10 @@ ) from weave.trace_server_bindings.remote_http_trace_server import RemoteHTTPTraceServer +if TYPE_CHECKING: + from weave.scorers.base_scorer import ApplyScorerResult, Scorer + + # Controls if objects can have refs to projects not the WeaveClient project. # If False, object refs with with mismatching projects will be recreated. # If True, use existing ref to object in other project. @@ -99,6 +115,7 @@ def __call__(self, offset: int, limit: int) -> list[T]: ... TransformFunc = Callable[[T], R] +SizeFunc = Callable[[], int] class PaginatedIterator(Generic[T, R]): @@ -110,10 +127,12 @@ def __init__( fetch_func: FetchFunc[T], page_size: int = 1000, transform_func: TransformFunc[T, R] | None = None, + size_func: SizeFunc | None = None, ) -> None: self.fetch_func = fetch_func self.page_size = page_size self.transform_func = transform_func + self.size_func = size_func if page_size <= 0: raise ValueError("page_size must be greater than 0") @@ -182,6 +201,13 @@ def __iter__(self: PaginatedIterator[T, R]) -> Iterator[R]: ... def __iter__(self) -> Iterator[T] | Iterator[R]: return self._get_slice(slice(0, None, 1)) + def __len__(self) -> int: + """This method is included for convenience. It includes a network call, which + is typically slower than most other len() operations!""" + if not self.size_func: + raise TypeError("This iterator does not support len()") + return self.size_func() + # TODO: should be Call, not WeaveObject CallsIter = PaginatedIterator[CallSchema, WeaveObject] @@ -210,7 +236,17 @@ def transform_func(call: CallSchema) -> WeaveObject: entity, project = project_id.split("/") return make_client_call(entity, project, call, server) - return PaginatedIterator(fetch_func, transform_func=transform_func) + def size_func() -> int: + response = server.calls_query_stats( + CallsQueryStatsReq(project_id=project_id, filter=filter) + ) + return response.count + + return PaginatedIterator( + fetch_func, + transform_func=transform_func, + size_func=size_func, + ) class OpNameError(ValueError): @@ -445,43 +481,60 @@ def set_display_name(self, name: str | None) -> None: def remove_display_name(self) -> None: self.set_display_name(None) - def _apply_scorer(self, scorer_op: Op) -> None: + async def apply_scorer( + self, scorer: Op | Scorer, additional_scorer_kwargs: dict | None = None + ) -> ApplyScorerResult: """ - This is a private method that applies a scorer to a call and records the feedback. - In the near future, this will be made public, but for now it is only used internally - for testing. + `apply_scorer` is a method that applies a Scorer to a Call. This is useful + for guarding application logic with a scorer and/or monitoring the quality + of critical ops. Scorers are automatically logged to Weave as Feedback and + can be used in queries & analysis. + + Args: + scorer: The Scorer to apply. + additional_scorer_kwargs: Additional kwargs to pass to the scorer. This is + useful for passing in additional context that is not part of the call + inputs.useful for passing in additional context that is not part of the call + inputs. - Before making this public, we should refactor such that the `predict_and_score` method - inside `eval.py` uses this method inside the scorer block. + Returns: + The result of the scorer application in the form of an `ApplyScorerResult`. + + ```python + class ApplyScorerSuccess: + result: Any + score_call: Call + ``` + + Example usage: - Current limitations: - - only works for ops (not Scorer class) - - no async support - - no context yet (ie. ground truth) + ```python + my_scorer = ... # construct a scorer + prediction, prediction_call = my_op.call(input_data) + result, score_call = prediction.apply_scorer(my_scorer) + ``` """ - client = weave_client_context.require_weave_client() - scorer_signature = inspect.signature(scorer_op) - scorer_arg_names = list(scorer_signature.parameters.keys()) - score_args = {k: v for k, v in self.inputs.items() if k in scorer_arg_names} - if "output" in scorer_arg_names: - score_args["output"] = self.output - _, score_call = scorer_op.call(**score_args) - scorer_op_ref = get_ref(scorer_op) - if scorer_op_ref is None: - raise ValueError("Scorer op has no ref") - self_ref = get_ref(self) - if self_ref is None: - raise ValueError("Call has no ref") - score_results = score_call.output - score_call_ref = get_ref(score_call) - if score_call_ref is None: - raise ValueError("Score call has no ref") - client._add_runnable_feedback( - weave_ref_uri=self_ref.uri(), - output=score_results, - call_ref_uri=score_call_ref.uri(), - runnable_ref_uri=scorer_op_ref.uri(), - ) + from weave.scorers.base_scorer import Scorer, apply_scorer_async + + model_inputs = {k: v for k, v in self.inputs.items() if k != "self"} + example = {**model_inputs, **(additional_scorer_kwargs or {})} + output = self.output + if isinstance(output, ObjectRef): + output = output.get() + apply_scorer_result = await apply_scorer_async(scorer, example, output) + score_call = apply_scorer_result.score_call + + wc = weave_client_context.get_weave_client() + if wc: + scorer_ref_uri = None + if weave_isinstance(scorer, Scorer): + # Very important: if the score is generated from a Scorer subclass, + # then scorer_ref_uri will be None, and we will use the op_name from + # the score_call instead. + scorer_ref = get_ref(scorer) + scorer_ref_uri = scorer_ref.uri() if scorer_ref else None + wc._send_score_call(self, score_call, scorer_ref_uri) + return apply_scorer_result def make_client_call( @@ -644,8 +697,15 @@ def get(self, ref: ObjectRef) -> Any: ) ) except HTTPError as e: - if e.response is not None and e.response.status_code == 404: - raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") + if e.response is not None: + if e.response.content: + try: + reason = json.loads(e.response.content).get("reason") + raise ValueError(reason) + except json.JSONDecodeError: + raise ValueError(e.response.content) + if e.response.status_code == 404: + raise ValueError(f"Unable to find object for ref uri: {ref.uri()}") raise # At this point, `ref.digest` is one of three things: @@ -755,6 +815,8 @@ def create_call( Returns: The created Call object. """ + from weave.trace.api import _global_postprocess_inputs + if isinstance(op, str): if op not in self._anonymous_ops: self._anonymous_ops[op] = _build_anonymous_op(op) @@ -769,6 +831,9 @@ def create_call( else: inputs_postprocessed = inputs_redacted + if _global_postprocess_inputs: + inputs_postprocessed = _global_postprocess_inputs(inputs_postprocessed) + self._save_nested_objects(inputs_postprocessed) inputs_with_refs = map_to_refs(inputs_postprocessed) @@ -802,6 +867,7 @@ def create_call( trace_id=trace_id, parent_id=parent_id, id=call_id, + # It feels like this should be inputs_postprocessed, not the refs. inputs=inputs_with_refs, attributes=attributes, ) @@ -855,6 +921,8 @@ def finish_call( *, op: Op | None = None, ) -> None: + from weave.trace.api import _global_postprocess_output + ended_at = datetime.datetime.now(tz=datetime.timezone.utc) call.ended_at = ended_at original_output = output @@ -863,9 +931,13 @@ def finish_call( postprocessed_output = op.postprocess_output(original_output) else: postprocessed_output = original_output - self._save_nested_objects(postprocessed_output) - call.output = map_to_refs(postprocessed_output) + if _global_postprocess_output: + postprocessed_output = _global_postprocess_output(postprocessed_output) + + self._save_nested_objects(postprocessed_output) + output_as_refs = map_to_refs(postprocessed_output) + call.output = postprocessed_output # Summary handling summary = {} @@ -920,7 +992,7 @@ def finish_call( op._on_finish_handler(call, original_output, exception) def send_end_call() -> None: - output_json = to_json(call.output, project_id, self, use_dictify=False) + output_json = to_json(output_as_refs, project_id, self, use_dictify=False) self.server.call_end( CallEndReq( end=EndedCallSchemaForInsert( @@ -952,6 +1024,26 @@ def delete_call(self, call: Call) -> None: ) ) + @trace_sentry.global_trace_sentry.watch() + def delete_object_version(self, object: ObjectRef) -> None: + self.server.obj_delete( + ObjDeleteReq( + project_id=self._project_id(), + object_id=object.name, + digests=[object.digest], + ) + ) + + @trace_sentry.global_trace_sentry.watch() + def delete_op_version(self, op: OpRef) -> None: + self.server.obj_delete( + ObjDeleteReq( + project_id=self._project_id(), + object_id=op.name, + digests=[op.digest], + ) + ) + def get_feedback( self, query: Query | str | None = None, @@ -1648,7 +1740,7 @@ def redact_sensitive_keys(obj: Any) -> Any: if isinstance(obj, dict): dict_res = {} for k, v in obj.items(): - if k in REDACT_KEYS: + if isinstance(k, str) and should_redact(k): dict_res[k] = REDACTED_VALUE else: dict_res[k] = redact_sensitive_keys(v)