Add record/replay support for overridden methods, type serializers (#241)

* Add record/replay support for overridden methods.

* Add support for mashumaro serializers to record/replay.

* Add changelog entry.
peterallenwebb authored Jan 30, 2025
1 parent 3441e08 commit 3956ae7
Showing 4 changed files with 237 additions and 43 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20250129-161242.yaml
@@ -0,0 +1,6 @@
kind: Features
body: Add override and serialization capabilities to record/replay
time: 2025-01-29T16:12:42.914913-05:00
custom:
Author: peterallenwebb
Issue: "241"
198 changes: 159 additions & 39 deletions dbt_common/record.py
@@ -1,20 +1,25 @@
"""The record module provides a mechanism for recording dbt's interaction with
external systems during a command invocation, so that the command can be re-run
later with the recording 'replayed' to dbt.
"""The record module provides a record/replay mechanism for recording dbt's
interactions with external systems during a command invocation, so that the
command can be re-run later with the recording 'replayed' to dbt.
The rationale for and architecture of this module are described in detail in the
docs/guides/record_replay.md document in this repository.
"""
import functools
import dataclasses
import inspect
import json
import os
import threading

from enum import Enum
from inspect import getfullargspec, signature, FullArgSpec
from typing import Any, Callable, Dict, List, Mapping, Optional, Type
from typing import Any, Callable, Dict, List, Mapping, Optional, TextIO, Tuple, Type
import contextvars

from mashumaro import field_options
from mashumaro.mixins.json import DataClassJSONMixin
from mashumaro.types import SerializationStrategy

RECORDED_BY_HIGHER_FUNCTION = contextvars.ContextVar("RECORDED_BY_HIGHER_FUNCTION", default=False)


@@ -133,6 +138,7 @@ class RecorderMode(Enum):
class Recorder:
_record_cls_by_name: Dict[str, Type] = {}
_record_name_by_params_name: Dict[str, str] = {}
_auto_serialization_strategies: Dict[Type, SerializationStrategy] = {}

def __init__(
self,
@@ -193,9 +199,13 @@ def pop_matching_record(self, params: Any) -> Optional[Record]:

return match

def write_json(self, out_stream: TextIO):
d = self._to_dict()
json.dump(d, out_stream)

def write(self) -> None:
with open(self.current_recording_path, "w") as file:
json.dump(self._to_dict(), file)
self.write_json(file)

def _to_dict(self) -> Dict:
dct: Dict[str, Any] = {}
@@ -209,7 +219,11 @@ def _to_dict(self) -> Dict:
@classmethod
def load(cls, file_name: str) -> Dict[str, List[Dict[str, Any]]]:
with open(file_name) as file:
return json.load(file)
return cls.load_json(file)

@classmethod
def load_json(cls, in_stream: TextIO) -> Dict[str, List[Dict[str, Any]]]:
return json.load(in_stream)

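The new `write_json` and `load_json` methods let a recording be written to and read from an arbitrary text stream instead of only a file path. A minimal sketch of how they might be used, assuming the `Recorder` constructor still takes a mode and an optional list of record types (that signature is not shown in this diff):

```python
import io

from dbt_common.record import Recorder, RecorderMode

# Assumption: Recorder(mode, types), as in the existing constructor.
recorder = Recorder(RecorderMode.RECORD, None)

# Serialize whatever has been recorded so far to an in-memory stream.
buffer = io.StringIO()
recorder.write_json(buffer)

# Read the recording back as a plain dict keyed by record type name.
buffer.seek(0)
data = Recorder.load_json(buffer)  # Dict[str, List[Dict[str, Any]]]
```
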
def _ensure_records_processed(self, record_type_name: str) -> None:
if record_type_name in self._records_by_type:
@@ -243,6 +257,12 @@ def print_diffs(self) -> None:
assert self.diff is not None
print(repr(self.diff.calculate_diff()))

@classmethod
def register_serialization_strategy(
cls, t: Type, serialization_strategy: SerializationStrategy
) -> None:
cls._auto_serialization_strategies[t] = serialization_strategy


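The new `register_serialization_strategy` hook lets callers teach the record/replay machinery how to serialize parameter or return types that mashumaro cannot handle by default; registered strategies are picked up when the auto-generated dataclass fields are built (see `_get_field` below). A minimal sketch, using `pathlib.Path` as an illustrative type:

```python
from pathlib import Path

from mashumaro.types import SerializationStrategy

from dbt_common.record import Recorder


class PathStrategy(SerializationStrategy):
    """Store pathlib.Path values as plain strings in recordings."""

    def serialize(self, value: Path) -> str:
        return str(value)

    def deserialize(self, value: str) -> Path:
        return Path(value)


# Any auto-recorded argument or return value annotated as Path will now be
# (de)serialized with this strategy.
Recorder.register_serialization_strategy(Path, PathStrategy())
```
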
def get_record_mode_from_env() -> Optional[RecorderMode]:
"""
@@ -275,7 +295,7 @@ def get_record_types_from_env() -> Optional[List]:
If no types are provided, there will be no filtering.
Invalid types will be ignored.
Expected format: 'DBT_RECORDER_TYPES=QueryRecord,FileLoadRecord,OtherRecord'
Expected format: 'DBT_RECORDER_TYPES=Database,FileLoadRecord'
"""
record_types_str = os.environ.get("DBT_RECORDER_TYPES")

@@ -287,18 +307,26 @@ def get_record_types_from_env() -> Optional[List]:


def get_record_types_from_dict(fp: str) -> List:
"""
Get the record subset from the dict.
"""
"""Get the record subset from the dict."""
with open(fp) as file:
loaded_dct = json.load(file)
return list(loaded_dct.keys())


def auto_record_function(
record_name: str, method: bool = False, group_name: Optional[str] = None
record_name: str,
method: bool = True,
group: Optional[str] = None,
index_on_thread_name: bool = True,
) -> Callable:
return functools.partial(_record_function_inner, record_name, method, False, None)
"""This is the @auto_record_function decorator. It works in a similar way to
the @record_function decorator, except automatically generates boilerplate
classes for the Record, Params, and Result classes which would otherwise be
needed. That makes it suitable for quickly adding record support to simple
functions with simple parameters."""
return functools.partial(
_record_function_inner, record_name, method, False, None, group, index_on_thread_name
)


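As a usage illustration (the class and record names here are hypothetical, not part of this commit), the decorator can be applied to an annotated method; it is a no-op unless `DBT_RECORDER_MODE` is set:

```python
from dbt_common.record import auto_record_function


class FileLoader:
    # With the defaults above (method=True, index_on_thread_name=True), the
    # generated Params class gets a thread_id field plus the annotated
    # arguments; type annotations are required on recorded parameters.
    @auto_record_function("LoadFileExample")
    def load(self, path: str) -> str:
        with open(path) as f:
            return f.read()
```
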
def record_function(
Expand All @@ -307,53 +335,87 @@ def record_function(
tuple_result: bool = False,
id_field_name: Optional[str] = None,
) -> Callable:
"""This is the @record_function decorator, which marks functions which will
have their function calls recorded during record mode, and mocked out with
previously recorded replay data during replay."""
return functools.partial(
_record_function_inner, record_type, method, tuple_result, id_field_name
_record_function_inner, record_type, method, tuple_result, id_field_name, None, False
)


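For comparison, a hedged sketch of the manual `@record_function` path, with hand-written Params/Result dataclasses and a `Record` subclass (all names here are illustrative, not from this commit):

```python
import dataclasses

from dbt_common.record import Record, Recorder, record_function


@dataclasses.dataclass
class RunQueryParams:
    sql: str


@dataclasses.dataclass
class RunQueryResult:
    return_val: str


class RunQueryRecord(Record):
    params_cls = RunQueryParams
    result_cls = RunQueryResult


Recorder.register_record_type(RunQueryRecord)


class Connection:
    # In record mode the sql argument and return value are captured as a
    # RunQueryRecord; in replay mode the call is mocked from the recording.
    @record_function(RunQueryRecord, method=True)
    def run_query(self, sql: str) -> str:
        return "fake result"  # stand-in for a real database call
```
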
def get_arg_fields(spec: FullArgSpec):
def _get_arg_fields(
spec: inspect.FullArgSpec,
skip_first: bool = False,
) -> List[Tuple[str, Optional[Type], dataclasses.Field]]:
arg_fields = []
defaults = len(spec.defaults) if spec.defaults else 0
for i, arg in enumerate(spec.args):
annotation = spec.annotations.get(arg)
for i, arg_name in enumerate(spec.args):
if skip_first and i == 0:
continue
annotation = spec.annotations.get(arg_name)
if annotation is None:
raise Exception("Recorded functions must have type annotations.")
field = _get_field(arg_name, annotation)
if i >= len(spec.args) - defaults:
arg_fields.append(
(
arg,
annotation,
dataclasses.field(
default=spec.defaults[i - len(spec.args) + defaults]
if spec.defaults
else None
), # type: ignore
)
field[2].default = (
spec.defaults[i - len(spec.args) + defaults] if spec.defaults else None
)
else:
arg_fields.append((arg, annotation, None))

arg_fields.append(field)
return arg_fields


def _record_function_inner(record_type, method, tuple_result, id_field_name, func_to_record):
# To avoid runtime overhead and other unpleasantness, we only apply the
# record/replay decorator if a relevant env var is set.
def _get_field(field_name: str, t: Type) -> Tuple[str, Optional[Type], dataclasses.Field]:
dc_field: dataclasses.Field = dataclasses.field()
strat = Recorder._auto_serialization_strategies.get(t)
if strat is not None:
dc_field.metadata = field_options(serialization_strategy=Recorder._auto_serialization_strategies[t]) # type: ignore

return field_name, t, dc_field


@dataclasses.dataclass
class AutoValues(DataClassJSONMixin):
def _to_dict(self):
return self.to_dict()

def _from_dict(self, data):
return self.from_dict(data)


def _record_function_inner(
record_type, method, tuple_result, id_field_name, group, index_on_thread_id, func_to_record
):
# When record/replay is not active, do nothing.
if get_record_mode_from_env() is None:
return func_to_record

if isinstance(record_type, str):
return_type = signature(func_to_record).return_annotation
return_type = inspect.signature(func_to_record).return_annotation
fields = _get_arg_fields(inspect.getfullargspec(func_to_record), method)
if index_on_thread_id:
id_field_name = "thread_id"
fields.insert(0, _get_field("thread_id", str))
params_cls = dataclasses.make_dataclass(
f"{record_type}Params", get_arg_fields(getfullargspec(func_to_record))
f"{record_type}Params", fields, bases=(AutoValues,)
)
result_cls = dataclasses.make_dataclass(
f"{record_type}Result", [("return_val", return_type)]
result_cls = (
None
if return_type is None or return_type == inspect._empty
else dataclasses.make_dataclass(
f"{record_type}Result",
[_get_field("return_val", return_type)],
bases=(AutoValues,),
)
)

record_type = type(
f"{record_type}Record", (Record,), {"params_cls": params_cls, "result_cls": result_cls}
f"{record_type}Record",
(Record,),
{"params_cls": params_cls, "result_cls": result_cls, "group": group},
)

Recorder.register_record_type(record_type)

@functools.wraps(func_to_record)
def record_replay_wrapper(*args, **kwargs) -> Any:
recorder: Optional[Recorder] = None
@@ -377,7 +439,10 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
# params constructor.
param_args = args[1:] if method else args
if method and id_field_name is not None:
param_args = (getattr(args[0], id_field_name),) + param_args
if index_on_thread_id:
param_args = (threading.current_thread().name,) + param_args
else:
param_args = (getattr(args[0], id_field_name),) + param_args

params = record_type.params_cls(*param_args, **kwargs)

@@ -406,4 +471,59 @@ def record_replay_wrapper(*args, **kwargs) -> Any:
recorder.add_record(record_type(params=params, result=result))
return r

setattr(
record_replay_wrapper,
"_record_metadata",
{
"record_type": record_type,
"method": method,
"tuple_result": tuple_result,
"id_field_name": id_field_name,
"group": group,
"index_on_thread_id": index_on_thread_id,
},
)

return record_replay_wrapper


def supports_replay(cls):
"""Class decorator which adds record/replay support for a class. In particular,
this decorator ensures that calls to overridden functions are still recorded."""

# When record/replay is inactive, do nothing.
if get_record_mode_from_env() is None:
return cls

# Replace the __init_subclass__ method of this class so that when it
# is subclassed, methods on the new subclass which override recorded
# functions are modified to be recorded as well.
original_init_subclass = cls.__init_subclass__

@classmethod
def wrapping_init_subclass(sub_cls):
for method_name in dir(cls):
method = getattr(cls, method_name)
metadata = getattr(method, "_record_metadata", None)
if method and getattr(method, "_record_metadata", None):
sub_method = getattr(sub_cls, method_name, None)
if sub_method is not None:
setattr(
sub_cls,
method_name,
_record_function_inner(
metadata["record_type"],
metadata["method"],
metadata["tuple_result"],
metadata["id_field_name"],
metadata["group"],
metadata["index_on_thread_id"],
sub_method,
),
)

original_init_subclass()

cls.__init_subclass__ = wrapping_init_subclass

return cls
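To illustrate the override support added by this commit, a minimal sketch (hypothetical class names) of `@supports_replay` on a base class whose recorded method is overridden by a subclass:

```python
from dbt_common.record import auto_record_function, supports_replay


@supports_replay
class BaseConnection:
    @auto_record_function("FetchExample")
    def fetch(self, url: str) -> str:
        return "base response"


# Because BaseConnection is decorated with @supports_replay, this override is
# re-wrapped via __init_subclass__, so its calls are recorded and replayed
# too when DBT_RECORDER_MODE is set; otherwise both decorators are no-ops.
class CustomConnection(BaseConnection):
    def fetch(self, url: str) -> str:
        return "custom response"
```
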
4 changes: 2 additions & 2 deletions docs/guides/record_replay.md
@@ -36,11 +36,11 @@ Record/replay behavior is activated and configured via environment variables. Wh

The record/replay subsystem is activated by setting the `DBT_RECORDER_MODE` variable to `replay`, `record`, or `diff`, case insensitive. Invalid values are ignored and do not throw exceptions.

`DBT_RECODER_TYPES` is optional. It indicates which types to filter the results by and expects a list of strings values for the `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable.
`DBT_RECORDER_TYPES` is optional. It indicates which types to filter the results by and expects a list of string values naming `Record` subclasses or groups of such classes. For example, all records of database/DWH interaction performed by adapters belong to the `Database` group. Any invalid type or group name will be ignored. `all` is a valid value for this variable and has the same effect as not populating the variable.


```bash
DBT_RECORDER_MODE=record DBT_RECODER_TYPES=Database dbt run
DBT_RECORDER_MODE=record DBT_RECORDER_TYPES=Database dbt run
```

Replay needs the file to replay.
