diff --git a/.changes/unreleased/Features-20250129-161242.yaml b/.changes/unreleased/Features-20250129-161242.yaml new file mode 100644 index 00000000..29ea3b32 --- /dev/null +++ b/.changes/unreleased/Features-20250129-161242.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Add override and serialization capabilities to record/replay +time: 2025-01-29T16:12:42.914913-05:00 +custom: + Author: peterallenwebb + Issue: "241" diff --git a/dbt_common/record.py b/dbt_common/record.py index d5ddcb31..be54a0b6 100644 --- a/dbt_common/record.py +++ b/dbt_common/record.py @@ -1,20 +1,25 @@ -"""The record module provides a mechanism for recording dbt's interaction with -external systems during a command invocation, so that the command can be re-run -later with the recording 'replayed' to dbt. +"""The record module provides a record/replay mechanism for recording dbt's +interactions with external systems during a command invocation, so that the +command can be re-run later with the recording 'replayed' to dbt. The rationale for and architecture of this module are described in detail in the docs/guides/record_replay.md document in this repository. """ import functools import dataclasses +import inspect import json import os +import threading from enum import Enum -from inspect import getfullargspec, signature, FullArgSpec -from typing import Any, Callable, Dict, List, Mapping, Optional, Type +from typing import Any, Callable, Dict, List, Mapping, Optional, TextIO, Tuple, Type import contextvars +from mashumaro import field_options +from mashumaro.mixins.json import DataClassJSONMixin +from mashumaro.types import SerializationStrategy + RECORDED_BY_HIGHER_FUNCTION = contextvars.ContextVar("RECORDED_BY_HIGHER_FUNCTION", default=False) @@ -133,6 +138,7 @@ class RecorderMode(Enum): class Recorder: _record_cls_by_name: Dict[str, Type] = {} _record_name_by_params_name: Dict[str, str] = {} + _auto_serialization_strategies: Dict[Type, SerializationStrategy] = {} def __init__( self, @@ -193,9 +199,13 @@ def pop_matching_record(self, params: Any) -> Optional[Record]: return match + def write_json(self, out_stream: TextIO): + d = self._to_dict() + json.dump(d, out_stream) + def write(self) -> None: with open(self.current_recording_path, "w") as file: - json.dump(self._to_dict(), file) + self.write_json(file) def _to_dict(self) -> Dict: dct: Dict[str, Any] = {} @@ -209,7 +219,11 @@ def _to_dict(self) -> Dict: @classmethod def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]: with open(file_name) as file: - return json.load(file) + return cls.load_json(file) + + @classmethod + def load_json(cls, in_stream: TextIO) -> Dict[str, List[Dict[str, Any]]]: + return json.load(in_stream) def _ensure_records_processed(self, record_type_name: str) -> None: if record_type_name in self._records_by_type: @@ -243,6 +257,12 @@ def print_diffs(self) -> None: assert self.diff is not None print(repr(self.diff.calculate_diff())) + @classmethod + def register_serialization_strategy( + cls, t: Type, serialization_strategy: SerializationStrategy + ) -> None: + cls._auto_serialization_strategies[t] = serialization_strategy + def get_record_mode_from_env() -> Optional[RecorderMode]: """ @@ -275,7 +295,7 @@ def get_record_types_from_env() -> Optional[List]: If no types are provided, there will be no filtering. Invalid types will be ignored. - Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord' + Expected format: 'DBT_RECORDER_TYPES=Database,FileLoadRecord' """ record_types_str = os.environ.get("DBT_RECORDER_TYPES") @@ -287,18 +307,26 @@ def get_record_types_from_env() -> Optional[List]: def get_record_types_from_dict(fp: str) -> List: - """ - Get the record subset from the dict. - """ + """Get the record subset from the dict.""" with open(fp) as file: loaded_dct = json.load(file) return list(loaded_dct.keys()) def auto_record_function( - record_name: str, method: bool = False, group_name: Optional[str] = None + record_name: str, + method: bool = True, + group: Optional[str] = None, + index_on_thread_name: bool = True, ) -> Callable: - return functools.partial(_record_function_inner, record_name, method, False, None) + """This is the @auto_record_function decorator. It works in a similar way to + the @record_function decorator, except automatically generates boilerplate + classes for the Record, Params, and Result classes which would otherwise be + needed. That makes it suitable for quickly adding record support to simple + functions with simple parameters.""" + return functools.partial( + _record_function_inner, record_name, method, False, None, group, index_on_thread_name + ) def record_function( @@ -307,53 +335,87 @@ def record_function( tuple_result: bool = False, id_field_name: Optional[str] = None, ) -> Callable: + """This is the @record_function decorator, which marks functions which will + have their function calls recorded during record mode, and mocked out with + previously recorded replay data during replay.""" return functools.partial( - _record_function_inner, record_type, method, tuple_result, id_field_name + _record_function_inner, record_type, method, tuple_result, id_field_name, None, False ) -def get_arg_fields(spec: FullArgSpec): +def _get_arg_fields( + spec: inspect.FullArgSpec, + skip_first: bool = False, +) -> List[Tuple[str, Optional[Type], dataclasses.Field]]: arg_fields = [] defaults = len(spec.defaults) if spec.defaults else 0 - for i, arg in enumerate(spec.args): - annotation = spec.annotations.get(arg) + for i, arg_name in enumerate(spec.args): + if skip_first and i == 0: + continue + annotation = spec.annotations.get(arg_name) + if annotation is None: + raise Exception("Recorded functions must have type annotations.") + field = _get_field(arg_name, annotation) if i >= len(spec.args) - defaults: - arg_fields.append( - ( - arg, - annotation, - dataclasses.field( - default=spec.defaults[i - len(spec.args) + defaults] - if spec.defaults - else None - ), # type: ignore - ) + field[2].default = ( + spec.defaults[i - len(spec.args) + defaults] if spec.defaults else None ) - else: - arg_fields.append((arg, annotation, None)) - + arg_fields.append(field) return arg_fields -def _record_function_inner(record_type, method, tuple_result, id_field_name, func_to_record): - # To avoid runtime overhead and other unpleasantness, we only apply the - # record/replay decorator if a relevant env var is set. +def _get_field(field_name: str, t: Type) -> Tuple[str, Optional[Type], dataclasses.Field]: + dc_field: dataclasses.Field = dataclasses.field() + strat = Recorder._auto_serialization_strategies.get(t) + if strat is not None: + dc_field.metadata = field_options(serialization_strategy=Recorder._auto_serialization_strategies[t]) # type: ignore + + return field_name, t, dc_field + + +@dataclasses.dataclass +class AutoValues(DataClassJSONMixin): + def _to_dict(self): + return self.to_dict() + + def _from_dict(self, data): + return self.from_dict(data) + + +def _record_function_inner( + record_type, method, tuple_result, id_field_name, group, index_on_thread_id, func_to_record +): + # When record/replay is not active, do nothing. if get_record_mode_from_env() is None: return func_to_record if isinstance(record_type, str): - return_type = signature(func_to_record).return_annotation + return_type = inspect.signature(func_to_record).return_annotation + fields = _get_arg_fields(inspect.getfullargspec(func_to_record), method) + if index_on_thread_id: + id_field_name = "thread_id" + fields.insert(0, _get_field("thread_id", str)) params_cls = dataclasses.make_dataclass( - f"{record_type}Params", get_arg_fields(getfullargspec(func_to_record)) + f"{record_type}Params", fields, bases=(AutoValues,) ) - result_cls = dataclasses.make_dataclass( - f"{record_type}Result", [("return_val", return_type)] + result_cls = ( + None + if return_type is None or return_type == inspect._empty + else dataclasses.make_dataclass( + f"{record_type}Result", + [_get_field("return_val", return_type)], + bases=(AutoValues,), + ) ) record_type = type( - f"{record_type}Record", (Record,), {"params_cls": params_cls, "result_cls": result_cls} + f"{record_type}Record", + (Record,), + {"params_cls": params_cls, "result_cls": result_cls, "group": group}, ) + Recorder.register_record_type(record_type) + @functools.wraps(func_to_record) def record_replay_wrapper(*args, **kwargs) -> Any: recorder: Optional[Recorder] = None @@ -377,7 +439,10 @@ def record_replay_wrapper(*args, **kwargs) -> Any: # params constructor. param_args = args[1:] if method else args if method and id_field_name is not None: - param_args = (getattr(args[0], id_field_name),) + param_args + if index_on_thread_id: + param_args = (threading.current_thread().name,) + param_args + else: + param_args = (getattr(args[0], id_field_name),) + param_args params = record_type.params_cls(*param_args, **kwargs) @@ -406,4 +471,59 @@ def record_replay_wrapper(*args, **kwargs) -> Any: recorder.add_record(record_type(params=params, result=result)) return r + setattr( + record_replay_wrapper, + "_record_metadata", + { + "record_type": record_type, + "method": method, + "tuple_result": tuple_result, + "id_field_name": id_field_name, + "group": group, + "index_on_thread_id": index_on_thread_id, + }, + ) + return record_replay_wrapper + + +def supports_replay(cls): + """Class decorator which adds record/replay support for a class. In particular, + this decorator ensures that calls to overriden functions are still recorded.""" + + # When record/replay is inactive, do nothing. + if get_record_mode_from_env() is None: + return cls + + # Replace the __init_subclass__ method of this class so that when it + # is subclassed, methods on the new subclass which override recorded + # functions are modified to be recorded as well. + original_init_subclass = cls.__init_subclass__ + + @classmethod + def wrapping_init_subclass(sub_cls): + for method_name in dir(cls): + method = getattr(cls, method_name) + metadata = getattr(method, "_record_metadata", None) + if method and getattr(method, "_record_metadata", None): + sub_method = getattr(sub_cls, method_name, None) + if sub_method is not None: + setattr( + sub_cls, + method_name, + _record_function_inner( + metadata["record_type"], + metadata["method"], + metadata["tuple_result"], + metadata["id_field_name"], + metadata["group"], + metadata["index_on_thread_id"], + sub_method, + ), + ) + + original_init_subclass() + + cls.__init_subclass__ = wrapping_init_subclass + + return cls diff --git a/docs/guides/record_replay.md b/docs/guides/record_replay.md index b6dfc7b8..98f6939f 100644 --- a/docs/guides/record_replay.md +++ b/docs/guides/record_replay.md @@ -36,11 +36,11 @@ Record/replay behavior is activated and configured via environment variables. Wh The record/replay subsystem is activated by setting the `DBT_RECORDER_MODE` variable to `replay`, `record`, or `diff`, case insensitive. Invalid values are ignored and do not throw exceptions. -`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable. +`DBT_RECORDER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable. ```bash -DBT_RECORDER_MODE=record DBT_RECODER_TYPES=Database dbt run +DBT_RECORDER_MODE=record DBT_RECORDER_TYPES=Database dbt run ``` replay need the file to replay diff --git a/tests/unit/test_record.py b/tests/unit/test_record.py index 8a4656f1..9f2152bb 100644 --- a/tests/unit/test_record.py +++ b/tests/unit/test_record.py @@ -1,10 +1,21 @@ import dataclasses import os +from io import StringIO + import pytest from typing import Optional +from mashumaro.types import SerializationStrategy + from dbt_common.context import set_invocation_context, get_invocation_context -from dbt_common.record import record_function, Record, Recorder, RecorderMode, auto_record_function +from dbt_common.record import ( + record_function, + Record, + Recorder, + RecorderMode, + auto_record_function, + supports_replay, +) @dataclasses.dataclass @@ -207,7 +218,7 @@ def test_auto_decorator_records(setup) -> None: set_invocation_context({}) get_invocation_context().recorder = recorder - @auto_record_function("TestAuto") + @auto_record_function("TestAuto", index_on_thread_name=False, method=False) def test_func(a: int, b: str, c: Optional[str] = None) -> str: return str(a) + b + (c if c else "") @@ -216,3 +227,60 @@ def test_func(a: int, b: str, c: Optional[str] = None) -> str: assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 123 assert recorder._records_by_type["TestAutoRecord"][-1].params.b == "abc" assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == "123abc" + + +def test_recorded_function_with_override() -> None: + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + @supports_replay + class Recordable: + @auto_record_function("TestAuto") + def test_func(self, a: int) -> int: + return 2 * a + + class RecordableSubclass(Recordable): + def test_func(self, a: int) -> int: + return 3 * a + + rs = RecordableSubclass() + + rs.test_func(1) + + assert recorder._records_by_type["TestAutoRecord"][-1].params.a == 1 + assert recorder._records_by_type["TestAutoRecord"][-1].result.return_val == 3 + + +class CustomType: + def __init__(self, n: int): + self.value = n + + +class CustomSerializationStrategy(SerializationStrategy): + def serialize(self, obj: CustomType) -> int: + return obj.value + + def deserialize(self, value: int) -> CustomType: + return CustomType(value) + + +def test_recorded_with_custom_serializer() -> None: + os.environ["DBT_RECORDER_MODE"] = "Record" + recorder = Recorder(RecorderMode.RECORD, None) + set_invocation_context({}) + get_invocation_context().recorder = recorder + + recorder.register_serialization_strategy(CustomType, CustomSerializationStrategy()) + + @auto_record_function("TestAuto") + def test_func(a: CustomType) -> CustomType: + return CustomType(a.value * 2) + + test_func(CustomType(21)) + + buffer = StringIO("") + recorder.write_json(buffer) + buffer.seek(0) + recorder.load_json(buffer)