From 3a20cdc98d9bdb2e8ba58e0eb0cfdd6d503109b0 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Fri, 13 Dec 2024 17:30:58 -0500 Subject: [PATCH 1/5] Move some code from dbt-core to dbt-common. --- dbt_common/clients/jinja.py | 2 + dbt_common/clients/jinja_macro_call.py | 101 +++++++++++++++++++++++++ tests/unit/test_jinja_macro_call.py | 57 ++++++++++++++ 3 files changed, 160 insertions(+) create mode 100644 dbt_common/clients/jinja_macro_call.py create mode 100644 tests/unit/test_jinja_macro_call.py diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 4bba253e..44c87dde 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -18,10 +18,12 @@ Optional, Union, Set, + Tuple, Type, NoReturn, ) +from hypothesis.errors import Frozen from typing_extensions import Protocol import jinja2 diff --git a/dbt_common/clients/jinja_macro_call.py b/dbt_common/clients/jinja_macro_call.py new file mode 100644 index 00000000..188e1c62 --- /dev/null +++ b/dbt_common/clients/jinja_macro_call.py @@ -0,0 +1,101 @@ +import dataclasses +from typing import Any, Dict, List, Optional + +import jinja2 + +from dbt_common.clients.jinja import get_environment, MacroType + +PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] + + +@dataclasses.dataclass +class TypeCheckFailure: + msg: str + + +@dataclasses.dataclass +class DbtMacroCall: + """An instance of this class represents a jinja macro call in a template + for the purposes of recording information for type checking.""" + + name: str + source: str + arg_types: List[Optional[MacroType]] = dataclasses.field(default_factory=list) + kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict) + + @classmethod + def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall": + dbt_call = cls(name, "") + for arg in call.args: # type: ignore + dbt_call.arg_types.append(cls.get_type(arg)) + for arg in call.kwargs: # type: ignore + dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value) + return dbt_call + + @classmethod + def get_type(cls, param: Any) -> Optional[MacroType]: + if isinstance(param, jinja2.nodes.Name): + return None # TODO: infer types from variable names + + if isinstance(param, jinja2.nodes.Call): + return None # TODO: infer types from function/macro calls + + if isinstance(param, jinja2.nodes.Getattr): + return None # TODO: infer types from . operator + + if isinstance(param, jinja2.nodes.Concat): + return None + + if isinstance(param, jinja2.nodes.Const): + if isinstance(param.value, str): # type: ignore + return MacroType("str") + elif isinstance(param.value, bool): # type: ignore + return MacroType("bool") + elif isinstance(param.value, int): # type: ignore + return MacroType("int") + elif isinstance(param.value, float): # type: ignore + return MacroType("float") + elif param.value is None: # type: ignore + return None + else: + return None + + if isinstance(param, jinja2.nodes.Dict): + return None + + return None + + def is_valid_type(self, t: MacroType) -> bool: + if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: + return True + elif ( + t.name == "Dict" + and len(t.type_params) == 2 + and t.type_params[0].name in PRIMITIVE_TYPES + and self.is_valid_type(t.type_params[1]) + ): + return True + elif ( + t.name in ["List", "Optional"] + and len(t.type_params) == 1 + and self.is_valid_type(t.type_params[0]) + ): + return True + + return False + + def check(self, macro_text: str) -> List[TypeCheckFailure]: + failures: List[TypeCheckFailure] = [] + template = get_environment(None, capture_macros=True).parse(macro_text) + jinja_macro = template.body[0] + + for arg_type in jinja_macro.arg_types: + if not self.is_valid_type(arg_type): + failures.append(TypeCheckFailure(msg="Invalid type.")) + + for i, arg_type in enumerate(self.arg_types): + expected_type = jinja_macro.arg_types[i] + if arg_type != expected_type: + failures.append(TypeCheckFailure(msg="Wrong type of parameter.")) + + return failures diff --git a/tests/unit/test_jinja_macro_call.py b/tests/unit/test_jinja_macro_call.py new file mode 100644 index 00000000..4b4301f1 --- /dev/null +++ b/tests/unit/test_jinja_macro_call.py @@ -0,0 +1,57 @@ +from dbt_common.clients.jinja import MacroType +from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall + + +single_param_macro_text = """{% macro call_me(param: TYPE) %} +{% endmacro %}""" + + +def test_primitive_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) + assert not any(call.check(macro_text)) + + +def test_primitive_type_checks_wrong() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) + call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) + assert any(call.check(macro_text)) + + +def test_list_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") + expected_type = MacroType("List", [MacroType(type_name)]) + call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + assert not any(call.check(macro_text)) + + +def test_dict_type_checks() -> None: + for type_name in PRIMITIVE_TYPES: + macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]") + expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)]) + call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + assert not any(call.check(macro_text)) + + +def test_too_few_args() -> None: + macro_text = "{% macro call_me(one: str, two: str, three: str) %}" + + +def test_too_many_args() -> None: + pass + + +kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %} +{% endmacro %}""" + + +# Better structured exceptions +# Test detection of macro called with too few positional args +# Test detection of macro called with too many positional args +# Test detection of macro called with keyword arg having wrong type +# Test detection of macro called with non-existent keyword arg +# Test detection of macro with invalid default value for param type From 130c2d95da48a3d7248497e979caee7794202221 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Fri, 13 Dec 2024 19:11:57 -0500 Subject: [PATCH 2/5] Refine type-checking logic and tests. --- dbt_common/clients/jinja.py | 2 - dbt_common/clients/jinja_macro_call.py | 57 ++++++++++++++++---------- tests/unit/test_jinja_macro_call.py | 20 ++++----- 3 files changed, 43 insertions(+), 36 deletions(-) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 44c87dde..4bba253e 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -18,12 +18,10 @@ Optional, Union, Set, - Tuple, Type, NoReturn, ) -from hypothesis.errors import Frozen from typing_extensions import Protocol import jinja2 diff --git a/dbt_common/clients/jinja_macro_call.py b/dbt_common/clients/jinja_macro_call.py index 188e1c62..2b6964b7 100644 --- a/dbt_common/clients/jinja_macro_call.py +++ b/dbt_common/clients/jinja_macro_call.py @@ -1,4 +1,5 @@ import dataclasses +from enum import Enum from typing import Any, Dict, List, Optional import jinja2 @@ -7,12 +8,16 @@ PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] +class FailureType(Enum): + TYPE_MISMATCH = "mismatch" + UNKNOWN_TYPE = "unknown" + PARAMETER_COUNT = "param_count" @dataclasses.dataclass class TypeCheckFailure: + type: FailureType msg: str - @dataclasses.dataclass class DbtMacroCall: """An instance of this class represents a jinja macro call in a template @@ -65,37 +70,47 @@ def get_type(cls, param: Any) -> Optional[MacroType]: return None - def is_valid_type(self, t: MacroType) -> bool: + def check_type(self, t: MacroType) -> List[TypeCheckFailure]: if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: - return True - elif ( - t.name == "Dict" - and len(t.type_params) == 2 - and t.type_params[0].name in PRIMITIVE_TYPES - and self.is_valid_type(t.type_params[1]) - ): - return True - elif ( - t.name in ["List", "Optional"] - and len(t.type_params) == 1 - and self.is_valid_type(t.type_params[0]) - ): - return True - - return False + return [] + + failures: List[TypeCheckFailure] = [] + if t.name == "Dict": + if len(t.type_params) != 2: + failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}.")) + else: + if t.type_params[0].name not in PRIMITIVE_TYPES: + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type.")) + + failures.extend(self.check_type(t.type_params[1])) + elif t.name in ("List", "Optional"): + if len(t.type_params) != 1: + failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, "Expected one type parameter for {t.name}[], found {len(t.type_params)}.")) + + failures.extend(self.check_type(t.type_params[0])) + + return failures def check(self, macro_text: str) -> List[TypeCheckFailure]: failures: List[TypeCheckFailure] = [] template = get_environment(None, capture_macros=True).parse(macro_text) jinja_macro = template.body[0] + # This could be arguably be done elsewhere, but check that every + # parameter passed to the macro has a valid type. for arg_type in jinja_macro.arg_types: - if not self.is_valid_type(arg_type): - failures.append(TypeCheckFailure(msg="Invalid type.")) + failures = self.check_type(arg_type) + if failures: + failures.extend(failures) + # Check that each positional argument matches the type of the for i, arg_type in enumerate(self.arg_types): expected_type = jinja_macro.arg_types[i] if arg_type != expected_type: - failures.append(TypeCheckFailure(msg="Wrong type of parameter.")) + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {i} but found {arg_type.name}/")) + + # Check whether there were more positional arguments than expected. + if len(self.arg_types) > len(jinja_macro.arg_types): + failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected {len(self.arg_types)} type arguments, got {len(jinja_macro.arg_types)}.")) return failures diff --git a/tests/unit/test_jinja_macro_call.py b/tests/unit/test_jinja_macro_call.py index 4b4301f1..24b7cfc4 100644 --- a/tests/unit/test_jinja_macro_call.py +++ b/tests/unit/test_jinja_macro_call.py @@ -1,6 +1,5 @@ from dbt_common.clients.jinja import MacroType -from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall - +from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall, FailureType single_param_macro_text = """{% macro call_me(param: TYPE) %} {% endmacro %}""" @@ -10,7 +9,8 @@ def test_primitive_type_checks() -> None: for type_name in PRIMITIVE_TYPES: macro_text = single_param_macro_text.replace("TYPE", type_name) call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) - assert not any(call.check(macro_text)) + failures = call.check(macro_text) + assert not failures def test_primitive_type_checks_wrong() -> None: @@ -18,7 +18,8 @@ def test_primitive_type_checks_wrong() -> None: macro_text = single_param_macro_text.replace("TYPE", type_name) wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) - assert any(call.check(macro_text)) + failures = call.check(macro_text) + assert len([f for f in failures if f.type == FailureType.TYPE_MISMATCH]) == 1 def test_list_type_checks() -> None: @@ -26,7 +27,8 @@ def test_list_type_checks() -> None: macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") expected_type = MacroType("List", [MacroType(type_name)]) call = DbtMacroCall("call_me", "call_me", [expected_type], {}) - assert not any(call.check(macro_text)) + failures = call.check(macro_text) + assert not failures def test_dict_type_checks() -> None: @@ -37,14 +39,6 @@ def test_dict_type_checks() -> None: assert not any(call.check(macro_text)) -def test_too_few_args() -> None: - macro_text = "{% macro call_me(one: str, two: str, three: str) %}" - - -def test_too_many_args() -> None: - pass - - kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %} {% endmacro %}""" From cf3617cbc10aa573de79ad415dcd189eb4200ac6 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Mon, 13 Jan 2025 15:51:26 -0500 Subject: [PATCH 3/5] Add built-in types for "dbt Classes" and refine type checking implementation --- dbt_common/clients/jinja.py | 4 +- dbt_common/clients/jinja_macro_call.py | 184 ++++++++++++++++++------- tests/unit/test_jinja_macro_call.py | 71 ++++++++-- 3 files changed, 190 insertions(+), 69 deletions(-) diff --git a/dbt_common/clients/jinja.py b/dbt_common/clients/jinja.py index 4bba253e..4e5f0dac 100644 --- a/dbt_common/clients/jinja.py +++ b/dbt_common/clients/jinja.py @@ -131,12 +131,12 @@ def parse_signature(self, node: Union[jinja2.nodes.Macro, jinja2.nodes.CallBlock arg = self.parse_assign_target(name_only=True) arg.set_ctx("param") - type_name: Optional[str] + type_name: Optional[MacroType] if self.stream.skip_if("colon"): node.has_type_annotations = True # type: ignore type_name = self.parse_type_name() else: - type_name = "" + type_name = None node.arg_types.append(type_name) # type: ignore diff --git a/dbt_common/clients/jinja_macro_call.py b/dbt_common/clients/jinja_macro_call.py index 2b6964b7..cd215418 100644 --- a/dbt_common/clients/jinja_macro_call.py +++ b/dbt_common/clients/jinja_macro_call.py @@ -1,17 +1,22 @@ import dataclasses from enum import Enum -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Iterable import jinja2 +import jinja2.nodes from dbt_common.clients.jinja import get_environment, MacroType PRIMITIVE_TYPES = ["Any", "bool", "float", "int", "str"] +DBT_CLASSES = ["Column", "Relation", "Result"] + class FailureType(Enum): - TYPE_MISMATCH = "mismatch" - UNKNOWN_TYPE = "unknown" + TYPE_MISMATCH = "type_mismatch" + UNKNOWN_TYPE = "unknown_type" PARAMETER_COUNT = "param_count" + EXTRA_ARGUMENT = "extra_arg" + MISSING_ARGUMENT = "missing_arg" @dataclasses.dataclass class TypeCheckFailure: @@ -19,7 +24,7 @@ class TypeCheckFailure: msg: str @dataclasses.dataclass -class DbtMacroCall: +class MacroCallChecker: """An instance of this class represents a jinja macro call in a template for the purposes of recording information for type checking.""" @@ -29,16 +34,134 @@ class DbtMacroCall: kwarg_types: Dict[str, Optional[MacroType]] = dataclasses.field(default_factory=dict) @classmethod - def from_call(cls, call: jinja2.nodes.Call, name: str) -> "DbtMacroCall": + def from_call(cls, call: jinja2.nodes.Call, name: str) -> "MacroCallChecker": dbt_call = cls(name, "") for arg in call.args: # type: ignore - dbt_call.arg_types.append(cls.get_type(arg)) + dbt_call.arg_types.append(TypeChecker.get_type(arg)) for arg in call.kwargs: # type: ignore - dbt_call.kwarg_types[arg.key] = cls.get_type(arg.value) + dbt_call.kwarg_types[arg.key] = TypeChecker.get_type(arg.value) return dbt_call + def check(self, macro_text: str) -> List[TypeCheckFailure]: + failures: List[TypeCheckFailure] = [] + + macro_checker = MacroChecker.from_jinja(macro_text) + + unassigned_args = list(macro_checker.args) + + # Each positional argument in this call should correspond to an expected + # positional argument with a compatible type. + for i, arg_type in enumerate(self.arg_types): + target_name = macro_checker.args[i] + target_type = macro_checker.arg_types[i] + unassigned_args.remove(target_name) + if arg_type is not None and target_type is not None and arg_type != target_type: + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/")) + + # Each keyword argument in this call should correspond to an expected + # argument that has not already been assigned, and have a compatible type. + for arg_name, arg_type in self.kwarg_types.items(): + if arg_name not in macro_checker.args: + failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}.")) + elif arg_name not in unassigned_args: + failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Argument {arg_name} was specified more than once.")) + else: + unassigned_args.remove(arg_name) + expected_type = macro_checker.get_arg_type(arg_name) + if arg_type is not None and expected_type is not None and arg_type != expected_type: + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/")) + + # Any remaining unassigned parameters must have a default. + for arg_name in unassigned_args: + if not macro_checker.has_default(arg_name): + failures.append(TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}.")) + + # Check that any arguments specified by keyword have the correct type + for arg_name, arg_type in self.kwarg_types.items(): + expected_type = macro_checker.get_arg_type(arg_name) + if arg_type is not None and expected_type is not None and arg_type != expected_type: + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/")) + + return failures + + +@dataclasses.dataclass +class MacroChecker: + _jinja_macro: jinja2.nodes.Macro + + @property + def args(self) -> List[str]: + return [a.name for a in self._jinja_macro.args] + + @property + def arg_types(self) -> List[Optional[MacroType]]: + return self._jinja_macro.arg_types # type: ignore + + @property + def defaults(self) -> List[str]: + return self._jinja_macro.defaults + + def get_arg_type(self, arg_name: str) -> Optional[MacroType]: + args = self.args + if arg_name not in args: + return None + else: + return self.arg_types[args.index(arg_name)] + + def has_default(self, arg_name: str) -> bool: + args = self.args + return args.index(arg_name) >= len(self.args) - len(self.defaults) + @classmethod - def get_type(cls, param: Any) -> Optional[MacroType]: + def from_jinja(cls, jinja_text: str) -> "MacroChecker": + template = get_environment(None, capture_macros=True).parse(jinja_text) + jinja_macro = template.body[0] + + if not isinstance(jinja_macro, jinja2.nodes.Macro): + raise Exception("Expected jinja macro.") + + return MacroChecker(jinja_macro) + + def type_check(self) -> List[TypeCheckFailure]: + # Every annotated parameter of the macro being called must have a valid + # type. + failures: List[TypeCheckFailure] = [] + for arg_type in self._jinja_macro.arg_types: # type: ignore + failures = TypeChecker.check(arg_type) + if failures: + failures.extend(failures) + + return failures + + +class TypeChecker: + @staticmethod + def check(t: Optional[MacroType]) -> List[TypeCheckFailure]: + if t is None or len(t.type_params) == 0 and t.name in (PRIMITIVE_TYPES + DBT_CLASSES): + return [] + + failures: List[TypeCheckFailure] = [] + if t.name == "Dict": + if len(t.type_params) != 2: + failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}.")) + else: + if t.type_params[0].name not in PRIMITIVE_TYPES: + failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type.")) + + failures.extend(TypeChecker.check(t.type_params[1])) + elif t.name in ("List", "Optional"): + if len(t.type_params) != 1: + failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected one type parameter for {t.name}[], found {len(t.type_params)}.")) + + failures.extend(TypeChecker.check(t.type_params[0])) + else: + failures.append(TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered.")) + + return failures + + + @staticmethod + def get_type(param: Any) -> Optional[MacroType]: if isinstance(param, jinja2.nodes.Name): return None # TODO: infer types from variable names @@ -69,48 +192,3 @@ def get_type(cls, param: Any) -> Optional[MacroType]: return None return None - - def check_type(self, t: MacroType) -> List[TypeCheckFailure]: - if len(t.type_params) == 0 and t.name in PRIMITIVE_TYPES: - return [] - - failures: List[TypeCheckFailure] = [] - if t.name == "Dict": - if len(t.type_params) != 2: - failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}.")) - else: - if t.type_params[0].name not in PRIMITIVE_TYPES: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type.")) - - failures.extend(self.check_type(t.type_params[1])) - elif t.name in ("List", "Optional"): - if len(t.type_params) != 1: - failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, "Expected one type parameter for {t.name}[], found {len(t.type_params)}.")) - - failures.extend(self.check_type(t.type_params[0])) - - return failures - - def check(self, macro_text: str) -> List[TypeCheckFailure]: - failures: List[TypeCheckFailure] = [] - template = get_environment(None, capture_macros=True).parse(macro_text) - jinja_macro = template.body[0] - - # This could be arguably be done elsewhere, but check that every - # parameter passed to the macro has a valid type. - for arg_type in jinja_macro.arg_types: - failures = self.check_type(arg_type) - if failures: - failures.extend(failures) - - # Check that each positional argument matches the type of the - for i, arg_type in enumerate(self.arg_types): - expected_type = jinja_macro.arg_types[i] - if arg_type != expected_type: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {i} but found {arg_type.name}/")) - - # Check whether there were more positional arguments than expected. - if len(self.arg_types) > len(jinja_macro.arg_types): - failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected {len(self.arg_types)} type arguments, got {len(jinja_macro.arg_types)}.")) - - return failures diff --git a/tests/unit/test_jinja_macro_call.py b/tests/unit/test_jinja_macro_call.py index 24b7cfc4..4bd12cf9 100644 --- a/tests/unit/test_jinja_macro_call.py +++ b/tests/unit/test_jinja_macro_call.py @@ -1,23 +1,34 @@ from dbt_common.clients.jinja import MacroType -from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DbtMacroCall, FailureType +from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DBT_CLASSES, FailureType, MacroCallChecker, MacroChecker single_param_macro_text = """{% macro call_me(param: TYPE) %} {% endmacro %}""" def test_primitive_type_checks() -> None: + """Test that primitive types can all be used to annotate macro parameters.""" for type_name in PRIMITIVE_TYPES: macro_text = single_param_macro_text.replace("TYPE", type_name) - call = DbtMacroCall("call_me", "call_me", [MacroType(type_name, [])], {}) + call = MacroCallChecker("call_me", "call_me", [MacroType(type_name, [])], {}) failures = call.check(macro_text) assert not failures -def test_primitive_type_checks_wrong() -> None: - for type_name in PRIMITIVE_TYPES: +def test_dbt_class_type_checks() -> None: + """Test that 'dbt Classes' like Relation, Column, and Result can all be used + to annotate macro parameters.""" + for type_name in DBT_CLASSES: + macro_text = single_param_macro_text.replace("TYPE", type_name) + call = MacroCallChecker("call_me", "call_me", [MacroType(type_name, [])], {}) + failures = call.check(macro_text) + assert not failures + +def test_type_checks_wrong() -> None: + """Test that calls to annotated macros with incorrect types fail type checks.""" + for type_name in PRIMITIVE_TYPES + DBT_CLASSES: macro_text = single_param_macro_text.replace("TYPE", type_name) wrong_type = next(t for t in PRIMITIVE_TYPES if t != type_name) - call = DbtMacroCall("call_me", "call_me", [MacroType(wrong_type, [])], {}) + call = MacroCallChecker("call_me", "call_me", [MacroType(wrong_type, [])], {}) failures = call.check(macro_text) assert len([f for f in failures if f.type == FailureType.TYPE_MISMATCH]) == 1 @@ -26,7 +37,7 @@ def test_list_type_checks() -> None: for type_name in PRIMITIVE_TYPES: macro_text = single_param_macro_text.replace("TYPE", f"List[{type_name}]") expected_type = MacroType("List", [MacroType(type_name)]) - call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + call = MacroCallChecker("call_me", "call_me", [expected_type], {}) failures = call.check(macro_text) assert not failures @@ -35,17 +46,49 @@ def test_dict_type_checks() -> None: for type_name in PRIMITIVE_TYPES: macro_text = single_param_macro_text.replace("TYPE", f"Dict[{type_name}, {type_name}]") expected_type = MacroType("Dict", [MacroType(type_name), MacroType(type_name)]) - call = DbtMacroCall("call_me", "call_me", [expected_type], {}) + call = MacroCallChecker("call_me", "call_me", [expected_type], {}) assert not any(call.check(macro_text)) -kwarg_param_macro_text = """{% macro call_me(param: int = 10, arg_one = "val1", arg_two: int = 2, arg_three: str = "val3" ) %} +kwarg_param_macro_text = """{% macro call_me(arg1: int, arg2: int, arg3: str = "val3", arg4: int = 4, arg5: str = "val5") %} {% endmacro %}""" -# Better structured exceptions -# Test detection of macro called with too few positional args -# Test detection of macro called with too many positional args -# Test detection of macro called with keyword arg having wrong type -# Test detection of macro called with non-existent keyword arg -# Test detection of macro with invalid default value for param type +def test_too_few_pos_args() -> None: + call = MacroCallChecker("call_me", "", [MacroType("int")]) + failures = call.check(kwarg_param_macro_text) + assert len(failures) == 1 + assert failures[0].type == FailureType.MISSING_ARGUMENT + + +def test_unknown_kwarg() -> None: + call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")}) + failures = call.check(kwarg_param_macro_text) + assert len(failures) == 1 + assert failures[0].type == FailureType.EXTRA_ARGUMENT + + +def test_kwarg_type() -> None: + """Test that annotated kwargs pass type checks when used by name.""" + call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")}) + failures = call.check(kwarg_param_macro_text) + assert not failures + + +def test_wrong_kwarg_type() -> None: + """Test that annotated kwargs pass type checks fail when the wrong type is used.""" + call = MacroCallChecker("call_me", "", [], {"arg3": MacroType("int")}) + failures = call.check(kwarg_param_macro_text) + assert failures[0].type == FailureType.TYPE_MISMATCH + +# TODO: Test detection of macro with invalid default value for param type +# TODO: Test detection of macro called with invalid variable parameter, as known from macro parameter annotation. + + +def test_unknown_type_check() -> None: + """Test that macro parameter annotations with unknown types fail type checks.""" + macro_text = single_param_macro_text.replace("TYPE", "Invalid") + checker = MacroChecker.from_jinja(macro_text) + failures = checker.type_check() + assert failures + assert any(f for f in failures if f.type == FailureType.UNKNOWN_TYPE) From de91f53372f80e49f933cdf8b1dd72d643fdf316 Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Mon, 13 Jan 2025 16:14:49 -0500 Subject: [PATCH 4/5] Add changelog comment. --- .changes/unreleased/Features-20250113-161439.yaml | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 .changes/unreleased/Features-20250113-161439.yaml diff --git a/.changes/unreleased/Features-20250113-161439.yaml b/.changes/unreleased/Features-20250113-161439.yaml new file mode 100644 index 00000000..e72b7907 --- /dev/null +++ b/.changes/unreleased/Features-20250113-161439.yaml @@ -0,0 +1,6 @@ +kind: Features +body: Move type checking code into dbt-common +time: 2025-01-13T16:14:39.573226-05:00 +custom: + Author: peterallenwebb + Issue: "230" From 7b081b36443842d6a3d39e0329e5c4ed677b03eb Mon Sep 17 00:00:00 2001 From: Peter Allen Webb Date: Mon, 13 Jan 2025 19:36:24 -0500 Subject: [PATCH 5/5] Formatting --- dbt_common/clients/jinja_macro_call.py | 74 +++++++++++++++++++++----- tests/unit/test_jinja_macro_call.py | 18 +++++-- 2 files changed, 76 insertions(+), 16 deletions(-) diff --git a/dbt_common/clients/jinja_macro_call.py b/dbt_common/clients/jinja_macro_call.py index cd215418..b75e24e2 100644 --- a/dbt_common/clients/jinja_macro_call.py +++ b/dbt_common/clients/jinja_macro_call.py @@ -1,6 +1,6 @@ import dataclasses from enum import Enum -from typing import Any, Dict, List, Optional, Iterable +from typing import Any, Dict, List, Optional import jinja2 import jinja2.nodes @@ -18,11 +18,13 @@ class FailureType(Enum): EXTRA_ARGUMENT = "extra_arg" MISSING_ARGUMENT = "missing_arg" + @dataclasses.dataclass class TypeCheckFailure: type: FailureType msg: str + @dataclasses.dataclass class MacroCallChecker: """An instance of this class represents a jinja macro call in a template @@ -56,31 +58,61 @@ def check(self, macro_text: str) -> List[TypeCheckFailure]: target_type = macro_checker.arg_types[i] unassigned_args.remove(target_name) if arg_type is not None and target_type is not None and arg_type != target_type: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/")) + failures.append( + TypeCheckFailure( + FailureType.TYPE_MISMATCH, + f"Expected type {target_type.name} for argument {target_name} but found {arg_type.name}/", + ) + ) # Each keyword argument in this call should correspond to an expected # argument that has not already been assigned, and have a compatible type. for arg_name, arg_type in self.kwarg_types.items(): if arg_name not in macro_checker.args: - failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}.")) + failures.append( + TypeCheckFailure( + FailureType.EXTRA_ARGUMENT, f"Unexpected keyword argument {arg_name}." + ) + ) elif arg_name not in unassigned_args: - failures.append(TypeCheckFailure(FailureType.EXTRA_ARGUMENT, f"Argument {arg_name} was specified more than once.")) + failures.append( + TypeCheckFailure( + FailureType.EXTRA_ARGUMENT, + f"Argument {arg_name} was specified more than once.", + ) + ) else: unassigned_args.remove(arg_name) expected_type = macro_checker.get_arg_type(arg_name) - if arg_type is not None and expected_type is not None and arg_type != expected_type: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/")) + if ( + arg_type is not None + and expected_type is not None + and arg_type != expected_type + ): + failures.append( + TypeCheckFailure( + FailureType.TYPE_MISMATCH, + f"Expected type {expected_type.name} for argument {arg_name} but found {arg_type.name}/", + ) + ) # Any remaining unassigned parameters must have a default. for arg_name in unassigned_args: if not macro_checker.has_default(arg_name): - failures.append(TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}.")) + failures.append( + TypeCheckFailure(FailureType.MISSING_ARGUMENT, f"Missing argument {arg_name}.") + ) # Check that any arguments specified by keyword have the correct type for arg_name, arg_type in self.kwarg_types.items(): expected_type = macro_checker.get_arg_type(arg_name) if arg_type is not None and expected_type is not None and arg_type != expected_type: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/")) + failures.append( + TypeCheckFailure( + FailureType.TYPE_MISMATCH, + f"Expected type {expected_type.name} as argument {arg_name} but found {arg_type.name}/", + ) + ) return failures @@ -143,23 +175,39 @@ def check(t: Optional[MacroType]) -> List[TypeCheckFailure]: failures: List[TypeCheckFailure] = [] if t.name == "Dict": if len(t.type_params) != 2: - failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected two type parameters for Dict[], found {len(t.type_params)}.")) + failures.append( + TypeCheckFailure( + FailureType.PARAMETER_COUNT, + f"Expected two type parameters for Dict[], found {len(t.type_params)}.", + ) + ) else: if t.type_params[0].name not in PRIMITIVE_TYPES: - failures.append(TypeCheckFailure(FailureType.TYPE_MISMATCH, "First type parameter of Dict[] must be a primitive type.")) + failures.append( + TypeCheckFailure( + FailureType.TYPE_MISMATCH, + "First type parameter of Dict[] must be a primitive type.", + ) + ) failures.extend(TypeChecker.check(t.type_params[1])) elif t.name in ("List", "Optional"): if len(t.type_params) != 1: - failures.append(TypeCheckFailure(FailureType.PARAMETER_COUNT, f"Expected one type parameter for {t.name}[], found {len(t.type_params)}.")) + failures.append( + TypeCheckFailure( + FailureType.PARAMETER_COUNT, + f"Expected one type parameter for {t.name}[], found {len(t.type_params)}.", + ) + ) failures.extend(TypeChecker.check(t.type_params[0])) else: - failures.append(TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered.")) + failures.append( + TypeCheckFailure(FailureType.UNKNOWN_TYPE, f"Unknown type {t.name} encountered.") + ) return failures - @staticmethod def get_type(param: Any) -> Optional[MacroType]: if isinstance(param, jinja2.nodes.Name): diff --git a/tests/unit/test_jinja_macro_call.py b/tests/unit/test_jinja_macro_call.py index 4bd12cf9..1b3c95e0 100644 --- a/tests/unit/test_jinja_macro_call.py +++ b/tests/unit/test_jinja_macro_call.py @@ -1,5 +1,11 @@ from dbt_common.clients.jinja import MacroType -from dbt_common.clients.jinja_macro_call import PRIMITIVE_TYPES, DBT_CLASSES, FailureType, MacroCallChecker, MacroChecker +from dbt_common.clients.jinja_macro_call import ( + PRIMITIVE_TYPES, + DBT_CLASSES, + FailureType, + MacroCallChecker, + MacroChecker, +) single_param_macro_text = """{% macro call_me(param: TYPE) %} {% endmacro %}""" @@ -23,6 +29,7 @@ def test_dbt_class_type_checks() -> None: failures = call.check(macro_text) assert not failures + def test_type_checks_wrong() -> None: """Test that calls to annotated macros with incorrect types fail type checks.""" for type_name in PRIMITIVE_TYPES + DBT_CLASSES: @@ -62,7 +69,9 @@ def test_too_few_pos_args() -> None: def test_unknown_kwarg() -> None: - call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")}) + call = MacroCallChecker( + "call_me", "", [MacroType("int"), MacroType("int")], {"unk": MacroType("str")} + ) failures = call.check(kwarg_param_macro_text) assert len(failures) == 1 assert failures[0].type == FailureType.EXTRA_ARGUMENT @@ -70,7 +79,9 @@ def test_unknown_kwarg() -> None: def test_kwarg_type() -> None: """Test that annotated kwargs pass type checks when used by name.""" - call = MacroCallChecker("call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")}) + call = MacroCallChecker( + "call_me", "", [MacroType("int"), MacroType("int")], {"arg3": MacroType("str")} + ) failures = call.check(kwarg_param_macro_text) assert not failures @@ -81,6 +92,7 @@ def test_wrong_kwarg_type() -> None: failures = call.check(kwarg_param_macro_text) assert failures[0].type == FailureType.TYPE_MISMATCH + # TODO: Test detection of macro with invalid default value for param type # TODO: Test detection of macro called with invalid variable parameter, as known from macro parameter annotation.