diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ed41416..e8f7941 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -18,14 +18,14 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        python: ["3.10", "3.11"]
+        python: ["3.11", "3.12"]
     steps:
       - uses: actions/checkout@v4

       - name: Set up PDM
         uses: pdm-project/setup-pdm@v4
         with:
-          python-version: "3.10"
+          python-version: "3.11"

       - name: Install dependencies
         run: |
           pdm sync -d
diff --git a/README.md b/README.md
index d3c92a4..e7ecde0 100644
--- a/README.md
+++ b/README.md
@@ -8,7 +8,7 @@
 Linter for dbt model metadata.

 You'll need the following prerequisites:

-- Any Python version starting from 3.10
+- Any Python version starting from 3.11
 - [pre-commit](https://pre-commit.com/)
 - [PDM](https://pdm-project.org/2.12/)
diff --git a/docs/reference/config.md b/docs/reference/config.md
new file mode 100644
index 0000000..b3687e4
--- /dev/null
+++ b/docs/reference/config.md
@@ -0,0 +1,3 @@
+# Config
+
+::: dbt_score.config
diff --git a/mkdocs.yml b/mkdocs.yml
index b9fcd94..16c887d 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -8,6 +8,7 @@ nav:
   - Home: index.md
   - Reference:
       - reference/cli.md
+      - reference/config.md
       - reference/exceptions.md
       - reference/evaluation.md
       - reference/models.md
diff --git a/pdm.lock b/pdm.lock
index ecd8b52..f1de9f0 100644
--- a/pdm.lock
+++ b/pdm.lock
@@ -5,7 +5,7 @@
 groups = ["default", "dev", "docs", "lint", "test"]
 strategy = ["cross_platform"]
 lock_version = "4.4.1"
-content_hash = "sha256:41d1ad10106c411809e42ab06b1bed7dabdf42afba377f62c45df6189fe01986"
+content_hash = "sha256:c3a671a78ccbbea1806a039799632f5f7d668036bef35756fd42240dc9dbfbea"

 [[package]]
 name = "agate"
@@ -288,7 +288,6 @@ requires_python = ">=3.8"
 summary = "Code coverage measurement for Python"
 dependencies = [
     "coverage==7.4.3",
-    "tomli; python_full_version <= \"3.11.0a6\"",
 ]
 files = [
     {file = "coverage-7.4.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:8580b827d4746d47294c0e0b92854c85a92c2227927433998f0d3320ae8a71b6"},
@@ -414,16 +413,6 @@ files = [
     {file = "distlib-0.3.8.tar.gz", hash = "sha256:1530ea13e350031b6312d8580ddb6b27a104275a31106523b8f123787f494f64"},
 ]

-[[package]]
-name = "exceptiongroup"
-version = "1.2.0"
-requires_python = ">=3.7"
-summary = "Backport of PEP 654 (exception groups)"
-files = [
-    {file = "exceptiongroup-1.2.0-py3-none-any.whl", hash = "sha256:4bfd3996ac73b41e9b9628b04e079f193850720ea5945fc96a08633c66912f14"},
-    {file = "exceptiongroup-1.2.0.tar.gz", hash = "sha256:91f5c769735f051a4290d52edd0858999b57e5876e9f85937691bd4c9fa3ed68"},
-]
-
 [[package]]
 name = "filelock"
 version = "3.13.1"
@@ -874,7 +863,6 @@ requires_python = ">=3.8"
 summary = "Optional static typing for Python"
 dependencies = [
     "mypy-extensions>=1.0.0",
-    "tomli>=1.1.0; python_version < \"3.11\"",
     "typing-extensions>=4.1.0",
 ]
 files = [
@@ -1139,7 +1127,6 @@ requires_python = ">=3.8"
 summary = "API to interact with the python pyproject.toml based projects"
 dependencies = [
     "packaging>=23.1",
-    "tomli>=2.0.1; python_version < \"3.11\"",
 ]
 files = [
     {file = "pyproject_api-1.6.1-py3-none-any.whl", hash = "sha256:4c0116d60476b0786c88692cf4e325a9814965e2469c5998b830bba16b183675"},
@@ -1153,11 +1140,9 @@ requires_python = ">=3.8"
 summary = "pytest: simple powerful testing with Python"
 dependencies = [
     "colorama; sys_platform == \"win32\"",
-    "exceptiongroup>=1.0.0rc8; python_version < \"3.11\"",
     "iniconfig",
     "packaging",
     "pluggy<2.0,>=1.4",
-    "tomli>=1; python_version < \"3.11\"",
 ]
 files = [
"pytest-8.1.0-py3-none-any.whl", hash = "sha256:ee32db7af8de4629a455806befa90559f307424c07b8413ccfc30bf5b221dd7e"}, @@ -1523,16 +1508,6 @@ files = [ {file = "text_unidecode-1.3-py2.py3-none-any.whl", hash = "sha256:1311f10e8b895935241623731c2ba64f4c455287888b18189350b67134a822e8"}, ] -[[package]] -name = "tomli" -version = "2.0.1" -requires_python = ">=3.7" -summary = "A lil' TOML parser" -files = [ - {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, - {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, -] - [[package]] name = "tox" version = "4.13.0" @@ -1547,7 +1522,6 @@ dependencies = [ "platformdirs>=4.1", "pluggy>=1.3", "pyproject-api>=1.6.1", - "tomli>=2.0.1; python_version < \"3.11\"", "virtualenv>=20.25", ] files = [ @@ -1561,7 +1535,6 @@ version = "0.7.2" requires_python = ">=3.7" summary = "A plugin for tox that utilizes PDM as the package manager and installer" dependencies = [ - "tomli; python_version < \"3.11\"", "tox>=4.0", ] files = [ diff --git a/pyproject.toml b/pyproject.toml index f78b704..aa0e99b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ dependencies = [ "dbt-core>=1.5", "click>=7.1.1, <9.0.0", ] -requires-python = ">=3.10" +requires-python = ">=3.11" readme = "README.md" license = {text = "MIT"} @@ -92,7 +92,7 @@ max-args = 6 [tool.ruff.lint.per-file-ignores] "tests/**/*.py" = [ - "PLR2004", # magic value comparisons + "PLR2004", # Magic value comparisons ] ### Coverage ### diff --git a/src/dbt_score/cli.py b/src/dbt_score/cli.py index 4ace740..e629996 100644 --- a/src/dbt_score/cli.py +++ b/src/dbt_score/cli.py @@ -7,6 +7,7 @@ from click.core import ParameterSource from dbt.cli.options import MultiOption +from dbt_score.config import Config from dbt_score.lint import lint_dbt_project from dbt_score.parse import dbt_parse, get_default_manifest_path @@ -57,7 +58,10 @@ def lint(select: tuple[str], manifest: Path, run_dbt_parse: bool) -> None: if manifest_provided and run_dbt_parse: raise click.UsageError("--run-dbt-parse cannot be used with --manifest.") + config = Config() + config.load() + if run_dbt_parse: dbt_parse() - lint_dbt_project(manifest) + lint_dbt_project(manifest, config) diff --git a/src/dbt_score/config.py b/src/dbt_score/config.py new file mode 100644 index 0000000..aaa360c --- /dev/null +++ b/src/dbt_score/config.py @@ -0,0 +1,72 @@ +"""This module is responsible for loading configuration.""" + +import logging +import tomllib +from pathlib import Path +from typing import Any, Final + +from dbt_score.rule import RuleConfig + +logger = logging.getLogger(__name__) + +DEFAULT_CONFIG_FILE = "pyproject.toml" + + +class Config: + """Configuration for dbt-score.""" + + _main_section: Final[str] = "tool.dbt-score" + _options: Final[list[str]] = ["rule_namespaces", "disabled_rules"] + _rules_section: Final[str] = f"{_main_section}.rules" + + def __init__(self) -> None: + """Initialize the Config object.""" + self.rule_namespaces: list[str] = ["dbt_score_rules"] + self.disabled_rules: list[str] = [] + self.rules_config: dict[str, RuleConfig] = {} + self.config_file: Path | None = None + + def set_option(self, option: str, value: Any) -> None: + """Set an option in the config.""" + setattr(self, option, value) + + def _load_toml_file(self, file: str) -> None: + """Load the options from a TOML file.""" + with open(file, "rb") as f: + toml_data = tomllib.load(f) + + tools = toml_data.get("tool", {}) + 
diff --git a/src/dbt_score/evaluation.py b/src/dbt_score/evaluation.py
index 3630a4d..284f1db 100644
--- a/src/dbt_score/evaluation.py
+++ b/src/dbt_score/evaluation.py
@@ -51,15 +51,13 @@ def __init__(

     def evaluate(self) -> None:
         """Evaluate all rules."""
-        # Instantiate all rules. In case they keep state across calls, this must be
-        # done only once.
-        rules = [rule_class() for rule_class in self._rule_registry.rules.values()]
+        rules = self._rule_registry.rules.values()

         for model in self._manifest_loader.models:
             self.results[model] = {}
             for rule in rules:
                 try:
-                    result: RuleViolation | None = rule.evaluate(model)
+                    result: RuleViolation | None = rule.evaluate(model, **rule.config)
                 except Exception as e:
                     self.results[model][rule.__class__] = e
                 else:
diff --git a/src/dbt_score/lint.py b/src/dbt_score/lint.py
index 6003f26..5fe1a40 100644
--- a/src/dbt_score/lint.py
+++ b/src/dbt_score/lint.py
@@ -2,6 +2,7 @@

 from pathlib import Path

+from dbt_score.config import Config
 from dbt_score.evaluation import Evaluation
 from dbt_score.formatters.human_readable_formatter import HumanReadableFormatter
 from dbt_score.models import ManifestLoader
@@ -9,12 +10,12 @@
 from dbt_score.scoring import Scorer


-def lint_dbt_project(manifest_path: Path) -> None:
+def lint_dbt_project(manifest_path: Path, config: Config) -> None:
     """Lint dbt manifest."""
     if not manifest_path.exists():
         raise FileNotFoundError(f"Manifest not found at {manifest_path}.")

-    rule_registry = RuleRegistry()
+    rule_registry = RuleRegistry(config)
     rule_registry.load_all()

     manifest_loader = ManifestLoader(manifest_path)
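
In outline, the new evaluation loop amounts to the following (a paraphrase of the diff above, with error handling and result bookkeeping omitted; my_rule is a hypothetical rule):

for model in manifest_loader.models:
    for rule in rule_registry.rules.values():  # pre-initialized Rule instances
        # A rule defined as `def my_rule(model, model_name="model1")` is called
        # as my_rule(model, model_name=<configured value>).
        result = rule.evaluate(model, **rule.config)
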
diff --git a/src/dbt_score/rule.py b/src/dbt_score/rule.py
index 76cc5b1..541c06a 100644
--- a/src/dbt_score/rule.py
+++ b/src/dbt_score/rule.py
@@ -1,6 +1,8 @@
 """Rule definitions."""

-from dataclasses import dataclass
+import inspect
+import typing
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Type, TypeAlias, overload

@@ -16,6 +18,26 @@ class Severity(Enum):
     CRITICAL = 4


+@dataclass
+class RuleConfig:
+    """Configuration for a rule."""
+
+    severity: Severity | None = None
+    config: dict[str, Any] = field(default_factory=dict)
+
+    @staticmethod
+    def from_dict(rule_config: dict[str, Any]) -> "RuleConfig":
+        """Create a RuleConfig from a dictionary."""
+        config = rule_config.copy()
+        severity = (
+            Severity(config.pop("severity", None))
+            if "severity" in rule_config
+            else None
+        )
+
+        return RuleConfig(severity=severity, config=config)
+
+
 @dataclass
 class RuleViolation:
     """The violation of a rule."""
@@ -31,6 +53,13 @@ class Rule:
     description: str
     severity: Severity = Severity.MEDIUM

+    default_config: typing.ClassVar[dict[str, Any]] = {}
+
+    def __init__(self, rule_config: RuleConfig | None = None) -> None:
+        """Initialize the rule."""
+        self.config: dict[str, Any] = {}
+        if rule_config:
+            self.process_config(rule_config)

     def __init_subclass__(cls, **kwargs) -> None:  # type: ignore
         """Initializes the subclass."""
@@ -38,10 +67,33 @@ def __init_subclass__(cls, **kwargs) -> None:  # type: ignore
         if not hasattr(cls, "description"):
             raise AttributeError("Subclass must define class attribute `description`.")

+    def process_config(self, rule_config: RuleConfig) -> None:
+        """Process the rule config."""
+        config = self.default_config.copy()
+
+        # Overwrite default rule configuration
+        for k, v in rule_config.config.items():
+            if k in self.default_config:
+                config[k] = v
+            else:
+                raise AttributeError(
+                    f"Unknown rule parameter: {k} for rule {self.source()}."
+                )
+
+        if rule_config.severity:
+            self.set_severity(rule_config.severity)
+
+        self.config = config
+
     def evaluate(self, model: Model) -> RuleViolation | None:
         """Evaluates the rule."""
         raise NotImplementedError("Subclass must implement method `evaluate`.")

+    @classmethod
+    def set_severity(cls, severity: Severity) -> None:
+        """Set the severity of the rule."""
+        cls.severity = severity
+
     @classmethod
     def source(cls) -> str:
         """Return the source of the rule, i.e. a fully qualified name."""
@@ -106,6 +158,13 @@ def wrapped_func(self: Rule, *args: Any, **kwargs: Any) -> RuleViolation | None:
         """Wrap func to add `self`."""
         return func(*args, **kwargs)

+    # Get default parameters from the rule definition
+    default_config = {
+        key: val.default
+        for key, val in inspect.signature(func).parameters.items()
+        if val.default is not inspect.Parameter.empty
+    }
+
     # Create the rule class inheriting from Rule
     rule_class = type(
         func.__name__,
@@ -113,6 +172,7 @@
         {
             "description": rule_description,
             "severity": severity,
+            "default_config": default_config,
             "evaluate": wrapped_func,
             # Forward origin of the decorated function
             "__qualname__": func.__qualname__,  # https://peps.python.org/pep-3155/
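
To see the whole mechanism at work, here is a hypothetical configurable rule; the rule name, the max_length parameter, and the model.description attribute are invented for this sketch, not taken from the patch:

from dbt_score import Model, RuleViolation, rule
from dbt_score.rule import RuleConfig, Severity

@rule
def max_description_length(model: Model, max_length: int = 200) -> RuleViolation | None:
    """Model description should not exceed the configured length."""
    if model.description and len(model.description) > max_length:
        return RuleViolation(message=f"Description longer than {max_length} characters.")

# The decorator records {"max_length": 200} as default_config; instantiating
# with a RuleConfig overrides the default and (optionally) the severity.
configured = max_description_length(
    rule_config=RuleConfig(severity=Severity.HIGH, config={"max_length": 100})
)
assert configured.config == {"max_length": 100}
assert configured.severity == Severity.HIGH
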
packages and sub-packages recursively.""" @@ -50,14 +57,18 @@ def _load(self, namespace_name: str) -> None: for obj_name in dir(module): obj = module.__dict__[obj_name] if type(obj) is type and issubclass(obj, Rule) and obj is not Rule: - self._add_rule(obj_name, obj) + self._add_rule(obj) - def _add_rule(self, name: str, rule: Type[Rule]) -> None: - if name in self.rules: - raise DuplicatedRuleException(name) - self._rules[name] = rule + def _add_rule(self, rule: Type[Rule]) -> None: + """Add a rule.""" + if rule.source() in self._rules: + raise DuplicatedRuleException(rule.source()) + if rule.source() not in self.config.disabled_rules: + self._rules[rule.source()] = rule def load_all(self) -> None: """Load all rules, core and third-party.""" self._load("dbt_score.rules") - self._load(THIRD_PARTY_RULES_NAMESPACE) + for namespace in self.config.rule_namespaces: + self._load(namespace) + self.init_rules() diff --git a/tests/conftest.py b/tests/conftest.py index 4b49fdd..980cb14 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,8 +5,30 @@ from typing import Any, Type from dbt_score import Model, Rule, RuleViolation, Severity, rule +from dbt_score.config import Config from pytest import fixture +# Configuration + + +@fixture() +def default_config() -> Config: + """Return a default Config object.""" + return Config() + + +@fixture +def valid_config_path() -> Path: + """Return the path of the configuration.""" + return Path(__file__).parent / "resources" / "pyproject.toml" + + +@fixture +def invalid_config_path() -> Path: + """Return the path of the configuration.""" + return Path(__file__).parent / "resources" / "invalid_pyproject.toml" + + # Manifest @@ -156,6 +178,21 @@ def rule_severity_critical(model: Model) -> RuleViolation | None: return rule_severity_critical +@fixture +def rule_with_config() -> Type[Rule]: + """An example rule with additional configuration.""" + + @rule + def rule_with_config( + model: Model, model_name: str = "model1" + ) -> RuleViolation | None: + """Rule with additional configuration.""" + if model.name != model_name: + return RuleViolation(message=model_name) + + return rule_with_config + + @fixture def rule_error() -> Type[Rule]: """An example rule which fails to run.""" diff --git a/tests/resources/invalid_pyproject.toml b/tests/resources/invalid_pyproject.toml new file mode 100644 index 0000000..77b5842 --- /dev/null +++ b/tests/resources/invalid_pyproject.toml @@ -0,0 +1,2 @@ +[tool.dbt-score] +foo = "bar" diff --git a/tests/resources/pyproject.toml b/tests/resources/pyproject.toml new file mode 100644 index 0000000..ee8d495 --- /dev/null +++ b/tests/resources/pyproject.toml @@ -0,0 +1,13 @@ +[tool.dbt-score] +rule_namespaces = ["foo", "tests"] +disabled_rules = ["foo.foo", "tests.bar"] + + +[tool.dbt-score.rules."foo.bar"] +severity=4 + +[tool.dbt-score.rules."tests.conftest.rule_with_config"] +model_name="model2" + +[tool.dbt-score.rules."tests.rules.example.rule_test_example"] +severity=4 diff --git a/tests/test_cli.py b/tests/test_cli.py index 0146f04..f22f7c7 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,5 +1,7 @@ """Test the CLI.""" +from unittest.mock import patch + import pytest from click.testing import CliRunner from dbt_score.cli import lint @@ -8,17 +10,19 @@ def test_invalid_options(): """Test invalid cli options.""" runner = CliRunner() - result = runner.invoke( - lint, ["--manifest", "fake_manifest.json", "--run-dbt-parse"] - ) - assert result.exit_code == 2 # pylint: disable=PLR2004 + with 
patch("dbt_score.cli.Config._load_toml_file"): + result = runner.invoke( + lint, ["--manifest", "fake_manifest.json", "--run-dbt-parse"] + ) + assert result.exit_code == 2 # pylint: disable=PLR2004 def test_lint_existing_manifest(manifest_path): """Test lint with an existing manifest.""" - runner = CliRunner() - result = runner.invoke(lint, ["--manifest", manifest_path]) - assert result.exit_code == 0 + with patch("dbt_score.cli.Config._load_toml_file"): + runner = CliRunner() + result = runner.invoke(lint, ["--manifest", manifest_path]) + assert result.exit_code == 0 def test_lint_non_existing_manifest(): @@ -27,10 +31,12 @@ def test_lint_non_existing_manifest(): # Provide manifest in command line with pytest.raises(FileNotFoundError): - runner.invoke( - lint, ["--manifest", "fake_manifest.json"], catch_exceptions=False - ) + with patch("dbt_score.cli.Config._load_toml_file"): + runner.invoke( + lint, ["--manifest", "fake_manifest.json"], catch_exceptions=False + ) # Use default manifest path with pytest.raises(FileNotFoundError): - runner.invoke(lint, catch_exceptions=False) + with patch("dbt_score.cli.Config._load_toml_file"): + runner.invoke(lint, catch_exceptions=False) diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 0000000..a29a683 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,62 @@ +"""Tests for the module config_parser.""" +from pathlib import Path + +import pytest +from dbt_score.config import Config +from dbt_score.rule import RuleConfig, Severity + + +def test_load_valid_toml_file(valid_config_path): + """Test that a valid config file loads correctly.""" + config = Config() + config._load_toml_file(str(valid_config_path)) + assert config.rule_namespaces == ["foo", "tests"] + assert config.disabled_rules == ["foo.foo", "tests.bar"] + assert config.rules_config["foo.bar"].severity == Severity.CRITICAL + assert ( + config.rules_config["tests.rules.example.rule_test_example"].severity + == Severity.CRITICAL + ) + + +def test_load_invalid_toml_file(caplog, invalid_config_path): + """Test that an invalid config file logs a warning.""" + config = Config() + config._load_toml_file(str(invalid_config_path)) + assert "Option foo in tool.dbt-score not supported." 
diff --git a/tests/test_config.py b/tests/test_config.py
new file mode 100644
index 0000000..a29a683
--- /dev/null
+++ b/tests/test_config.py
@@ -0,0 +1,62 @@
+"""Tests for the config module."""
+from pathlib import Path
+
+import pytest
+from dbt_score.config import Config
+from dbt_score.rule import RuleConfig, Severity
+
+
+def test_load_valid_toml_file(valid_config_path):
+    """Test that a valid config file loads correctly."""
+    config = Config()
+    config._load_toml_file(str(valid_config_path))
+    assert config.rule_namespaces == ["foo", "tests"]
+    assert config.disabled_rules == ["foo.foo", "tests.bar"]
+    assert config.rules_config["foo.bar"].severity == Severity.CRITICAL
+    assert (
+        config.rules_config["tests.rules.example.rule_test_example"].severity
+        == Severity.CRITICAL
+    )
+
+
+def test_load_invalid_toml_file(caplog, invalid_config_path):
+    """Test that an invalid config file logs a warning."""
+    config = Config()
+    config._load_toml_file(str(invalid_config_path))
+    assert "Option foo in tool.dbt-score not supported." in caplog.text
+
+
+def test_invalid_rule_config(rule_severity_low):
+    """Test that an invalid rule config raises an exception."""
+    config = RuleConfig(config={"foo": "bar"})
+    with pytest.raises(
+        AttributeError,
+        match="Unknown rule parameter: foo for rule "
+        "tests.conftest.rule_severity_low.",
+    ):
+        rule_severity_low(config)
+
+
+def test_valid_rule_config(valid_config_path, rule_with_config):
+    """Test that a valid rule config can be loaded."""
+    config = RuleConfig(severity=Severity(4), config={"model_name": "baz"})
+    rule_with_config = rule_with_config(config)
+    assert rule_with_config.severity == Severity.CRITICAL
+    assert rule_with_config.default_config == {"model_name": "model1"}
+    assert rule_with_config.config == {"model_name": "baz"}
+
+
+def test_get_config_file():
+    """Test that the config file is found in the current directory."""
+    directory = Path(__file__).parent / "resources"
+    config = Config()
+    config_file = config.get_config_file(directory)
+    assert config_file == directory / "pyproject.toml"
+
+
+def test_get_parent_config_file():
+    """Test that the config file is found in the parent directory."""
+    directory = Path(__file__).parent / "resources" / "sub_dir"
+    config = Config()
+    config_file = config.get_config_file(directory)
+    assert config_file == directory.parent / "pyproject.toml"
diff --git a/tests/test_evaluation.py b/tests/test_evaluation.py
index c03981f..99dd136 100644
--- a/tests/test_evaluation.py
+++ b/tests/test_evaluation.py
@@ -2,6 +2,7 @@

 from unittest.mock import Mock

+from dbt_score.config import Config
 from dbt_score.evaluation import Evaluation
 from dbt_score.models import ManifestLoader
 from dbt_score.rule import RuleViolation
@@ -14,17 +15,19 @@ def test_evaluation_low_medium_high(
     rule_severity_medium,
     rule_severity_high,
     rule_error,
+    default_config,
 ):
     """Test rule evaluation with a combination of LOW, MEDIUM and HIGH severity."""
     manifest_loader = ManifestLoader(manifest_path)
     mock_formatter = Mock()
     mock_scorer = Mock()

-    rule_registry = RuleRegistry()
-    rule_registry._add_rule("rule_severity_low", rule_severity_low)
-    rule_registry._add_rule("rule_severity_medium", rule_severity_medium)
-    rule_registry._add_rule("rule_severity_high", rule_severity_high)
-    rule_registry._add_rule("rule_error", rule_error)
+    rule_registry = RuleRegistry(default_config)
+    rule_registry._add_rule(rule_severity_low)
+    rule_registry._add_rule(rule_severity_medium)
+    rule_registry._add_rule(rule_severity_high)
+    rule_registry._add_rule(rule_error)
+    rule_registry.init_rules()

     evaluation = Evaluation(
         rule_registry=rule_registry,
@@ -55,16 +58,15 @@


 def test_evaluation_critical(
-    manifest_path,
-    rule_severity_low,
-    rule_severity_critical,
+    manifest_path, rule_severity_low, rule_severity_critical, default_config
 ):
     """Test rule evaluation with a CRITICAL severity."""
     manifest_loader = ManifestLoader(manifest_path)

-    rule_registry = RuleRegistry()
-    rule_registry._add_rule("rule_severity_low", rule_severity_low)
-    rule_registry._add_rule("rule_severity_critical", rule_severity_critical)
+    rule_registry = RuleRegistry(default_config)
+    rule_registry._add_rule(rule_severity_low)
+    rule_registry._add_rule(rule_severity_critical)
+    rule_registry.init_rules()

     evaluation = Evaluation(
         rule_registry=rule_registry,
@@ -79,11 +81,11 @@
     assert isinstance(evaluation.results[model2][rule_severity_critical], RuleViolation)


-def test_evaluation_no_rule(manifest_path):
+def test_evaluation_no_rule(manifest_path, default_config):
     """Test rule evaluation when no rule exists."""
     manifest_loader = ManifestLoader(manifest_path)

-    rule_registry = RuleRegistry()
+    rule_registry = RuleRegistry(default_config)

     evaluation = Evaluation(
         rule_registry=rule_registry,
@@ -97,12 +99,12 @@
     assert len(results) == 0


-def test_evaluation_no_model(manifest_empty_path, rule_severity_low):
+def test_evaluation_no_model(manifest_empty_path, rule_severity_low, default_config):
     """Test rule evaluation when no model exists."""
     manifest_loader = ManifestLoader(manifest_empty_path)

-    rule_registry = RuleRegistry()
-    rule_registry._add_rule("rule_severity_low", rule_severity_low)
+    rule_registry = RuleRegistry(default_config)
+    rule_registry._add_rule(rule_severity_low)

     evaluation = Evaluation(
         rule_registry=rule_registry,
@@ -116,11 +118,11 @@
     assert list(evaluation.scores.values()) == []


-def test_evaluation_no_model_no_rule(manifest_empty_path):
+def test_evaluation_no_model_no_rule(manifest_empty_path, default_config):
     """Test rule evaluation when no rule and no model exists."""
     manifest_loader = ManifestLoader(manifest_empty_path)

-    rule_registry = RuleRegistry()
+    rule_registry = RuleRegistry(default_config)

     evaluation = Evaluation(
         rule_registry=rule_registry,
@@ -132,3 +134,34 @@

     assert len(evaluation.results) == 0
     assert list(evaluation.scores.values()) == []
+
+
+def test_evaluation_rule_with_config(
+    manifest_path, rule_with_config, valid_config_path
+):
+    """Test rule evaluation with parameters."""
+    manifest_loader = ManifestLoader(manifest_path)
+    model1 = manifest_loader.models[0]
+    model2 = manifest_loader.models[1]
+
+    config = Config()
+    config._load_toml_file(str(valid_config_path))
+
+    rule_registry = RuleRegistry(config)
+    rule_registry._add_rule(rule_with_config)
+    rule_registry.init_rules()
+
+    evaluation = Evaluation(
+        rule_registry=rule_registry,
+        manifest_loader=manifest_loader,
+        formatter=Mock(),
+        scorer=Mock(),
+    )
+    evaluation.evaluate()
+
+    assert (
+        rule_with_config.default_config
+        != rule_registry.rules["tests.conftest.rule_with_config"].config
+    )
+    assert evaluation.results[model1][rule_with_config] is not None
+    assert evaluation.results[model2][rule_with_config] is None
diff --git a/tests/test_lint.py b/tests/test_lint.py
index ed904b9..41243e2 100644
--- a/tests/test_lint.py
+++ b/tests/test_lint.py
@@ -3,6 +3,7 @@

 from unittest.mock import patch

+from dbt_score.config import Config
 from dbt_score.lint import lint_dbt_project


@@ -12,6 +13,6 @@ def test_lint_dbt_project(mock_evaluation, manifest_path):
     # Instantiating the patched class returns the same Mock
     mock_evaluation.return_value = mock_evaluation

-    lint_dbt_project(manifest_path)
+    lint_dbt_project(manifest_path, Config())

     mock_evaluation.evaluate.assert_called_once()
diff --git a/tests/test_rule_registry.py b/tests/test_rule_registry.py
index 0d90166..3049dad 100644
--- a/tests/test_rule_registry.py
+++ b/tests/test_rule_registry.py
@@ -1,27 +1,55 @@
 """Unit tests for the rule registry."""

 import pytest
+from dbt_score import Severity
+from dbt_score.config import Config
 from dbt_score.exceptions import DuplicatedRuleException
 from dbt_score.rule_registry import RuleRegistry


-def test_rule_registry_discovery():
+def test_rule_registry_discovery(default_config):
     """Ensure rules can be found in a given namespace recursively."""
-    r = RuleRegistry()
+    r = RuleRegistry(default_config)
     r._load("tests.rules")
-    assert sorted(r.rules.keys()) == ["rule_test_example", "rule_test_nested_example"]
+    assert sorted(r._rules.keys()) == [
+        "tests.rules.example.rule_test_example",
+        "tests.rules.nested.example.rule_test_nested_example",
+    ]


-def test_rule_registry_no_duplicates():
+def test_disabled_rule_registry_discovery():
+    """Ensure disabled rules are not discovered."""
+    config = Config()
+    config.disabled_rules = ["tests.rules.nested.example.rule_test_nested_example"]
+    r = RuleRegistry(config)
+    r._load("tests.rules")
+    assert sorted(r._rules.keys()) == [
+        "tests.rules.example.rule_test_example",
+    ]
+
+
+def test_configured_rule_registry_discovery(valid_config_path):
+    """Ensure rules are discovered and configured correctly."""
+    config = Config()
+    config._load_toml_file(str(valid_config_path))
+    r = RuleRegistry(config)
+    r._load("tests.rules")
+    r.init_rules()
+    assert (
+        r.rules["tests.rules.example.rule_test_example"].severity == Severity.CRITICAL
+    )
+
+
+def test_rule_registry_no_duplicates(default_config):
     """Ensure no duplicate rule names can coexist."""
-    r = RuleRegistry()
+    r = RuleRegistry(default_config)
     r._load("tests.rules")
     with pytest.raises(DuplicatedRuleException):
         r._load("tests.rules")


-def test_rule_registry_core_rules():
+def test_rule_registry_core_rules(default_config):
     """Ensure core rules are automatically discovered."""
-    r = RuleRegistry()
+    r = RuleRegistry(default_config)
     r.load_all()
     assert len(r.rules) > 0
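
A detail that is easy to miss in the fixtures above: severities in the TOML configuration are given as the integer values of the Severity enum, so severity=4 resolves to Severity.CRITICAL via Severity(4) in RuleConfig.from_dict:

from dbt_score.rule import Severity

assert Severity(4) is Severity.CRITICAL
assert Severity(4).name == "CRITICAL"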