Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore parent tests added edges for build selection #7431

Merged
6 changes: 6 additions & 0 deletions .changes/unreleased/Fixes-20230421-172428.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Fixes
body: dbt build selection of tests' descendants
time: 2023-04-21T17:24:28.335866975+02:00
custom:
Author: b-luu
Issue: "7289"
2 changes: 1 addition & 1 deletion core/dbt/compilation.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,7 @@ def add_test_edges(self, linker: Linker, manifest: Manifest) -> None:
# is a subset of all upstream nodes of the current node,
# add an edge from the upstream test to the current node.
if test_depends_on.issubset(upstream_nodes):
linker.graph.add_edge(upstream_test, node_id)
linker.graph.add_edge(upstream_test, node_id, edge_type="parent_test")

def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph:
self.initialize()
Expand Down
17 changes: 15 additions & 2 deletions core/dbt/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,29 @@ def ancestors(self, node: UniqueId, max_depth: Optional[int] = None) -> Set[Uniq
"""Returns all nodes having a path to `node` in `graph`"""
if not self.graph.has_node(node):
raise DbtInternalError(f"Node {node} not found in the graph!")
filtered_graph = self.exclude_edge_type("parent_test")
return {
child
for _, child in nx.bfs_edges(self.graph, node, reverse=True, depth_limit=max_depth)
for _, child in nx.bfs_edges(filtered_graph, node, reverse=True, depth_limit=max_depth)
}

def descendants(self, node: UniqueId, max_depth: Optional[int] = None) -> Set[UniqueId]:
"""Returns all nodes reachable from `node` in `graph`"""
if not self.graph.has_node(node):
raise DbtInternalError(f"Node {node} not found in the graph!")
return {child for _, child in nx.bfs_edges(self.graph, node, depth_limit=max_depth)}
filtered_graph = self.exclude_edge_type("parent_test")
return {child for _, child in nx.bfs_edges(filtered_graph, node, depth_limit=max_depth)}

def exclude_edge_type(self, edge_type_to_exclude):
    """Return a restricted view of ``self.graph`` with one edge type hidden.

    An edge is hidden when its ``edge_type`` attribute equals
    ``edge_type_to_exclude``. An edge without that attribute reads as
    ``None`` for the comparison. No nodes are hidden.
    """
    graph = self.graph
    # Lazily enumerate the (source, target) pairs that carry the unwanted type.
    hidden_edges = (
        (source, target)
        for source, target in graph.edges
        if graph[source][target].get("edge_type") == edge_type_to_exclude
    )
    return nx.restricted_view(graph, nodes=[], edges=hidden_edges)

def select_childrens_parents(self, selected: Set[UniqueId]) -> Set[UniqueId]:
ancestors_for = self.select_children(selected) | selected
Expand Down
21 changes: 21 additions & 0 deletions tests/functional/build/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,27 @@
- not_null
"""

models_triple_blocking__test_yml = """
version: 2

models:
- name: model_a
columns:
- name: id
tests:
- not_null
- name: model_b
columns:
- name: id
tests:
- not_null
- name: model_c
columns:
- name: id
tests:
- not_null
"""

models_interdependent__model_a_sql = """
select 1 as id
"""
Expand Down
32 changes: 32 additions & 0 deletions tests/functional/build/test_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
models_simple_blocking__model_a_sql,
models_simple_blocking__model_b_sql,
models_simple_blocking__test_yml,
models_triple_blocking__test_yml,
models_interdependent__test_yml,
models_interdependent__model_a_sql,
models_interdependent__model_b_sql,
Expand Down Expand Up @@ -196,3 +197,34 @@ def test_interdependent_models_fail(self, project):
actual = [str(r.status) for r in results]
expected = ["error"] * 4 + ["skipped"] * 7 + ["pass"] * 2 + ["success"] * 3
assert sorted(actual) == sorted(expected)


class TestDownstreamSelection:
    """Functional check of `dbt build` selection downstream of a test node."""

    @pytest.fixture(scope="class")
    def models(self):
        """Project files: model_a, model_b and their schema tests."""
        project_files = {}
        project_files["model_a.sql"] = models_simple_blocking__model_a_sql
        project_files["model_b.sql"] = models_simple_blocking__model_b_sql
        project_files["test.yml"] = models_simple_blocking__test_yml
        return project_files

    def test_downstream_selection(self, project):
        """Ensure that selecting test+ does not select model_a's other children"""
        cli_args = ["build", "--select", "model_a not_null_model_a_id+"]
        build_results = run_dbt(cli_args, expect_pass=True)
        # Exactly two results are expected for this selection.
        assert len(build_results) == 2


class TestLimitedUpstreamSelection:
    """Functional check of `dbt build` selection with a depth-limited parent operator."""

    @pytest.fixture(scope="class")
    def models(self):
        """Project files: model_a, model_b, model_c and their schema tests."""
        project_files = {}
        project_files["model_a.sql"] = models_interdependent__model_a_sql
        project_files["model_b.sql"] = models_interdependent__model_b_sql
        project_files["model_c.sql"] = models_interdependent__model_c_sql
        project_files["test.yml"] = models_triple_blocking__test_yml
        return project_files

    def test_limited_upstream_selection(self, project):
        """Ensure that selecting 1+model_c only selects up to model_b (+ tests of both)"""
        cli_args = ["build", "--select", "1+model_c"]
        build_results = run_dbt(cli_args, expect_pass=True)
        # Exactly four results are expected for this selection.
        assert len(build_results) == 4
11 changes: 5 additions & 6 deletions tests/functional/defer_state/test_run_results_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,9 +216,9 @@ def test_build_run_results_state(self, project):
results = run_dbt(
["build", "--select", "result:fail+", "--state", "./state"], expect_pass=False
)
assert len(results) == 2
assert len(results) == 1
nodes = set([elem.node.name for elem in results])
assert nodes == {"table_model", "unique_view_model_id"}
assert nodes == {"unique_view_model_id"}

results = run_dbt(["ls", "--select", "result:fail+", "--state", "./state"])
assert len(results) == 1
Expand All @@ -240,9 +240,9 @@ def test_build_run_results_state(self, project):
results = run_dbt(
["build", "--select", "result:warn+", "--state", "./state"], expect_pass=True
)
assert len(results) == 2 # includes table_model to be run
assert len(results) == 1
nodes = set([elem.node.name for elem in results])
assert nodes == {"table_model", "unique_view_model_id"}
assert nodes == {"unique_view_model_id"}

results = run_dbt(["ls", "--select", "result:warn+", "--state", "./state"])
assert len(results) == 1
Expand Down Expand Up @@ -483,12 +483,11 @@ def test_concurrent_selectors_build_run_results_state(self, project):
],
expect_pass=False,
)
assert len(results) == 5
assert len(results) == 4
nodes = set([elem.node.name for elem in results])
assert nodes == {
"error_model",
"downstream_of_error_model",
"table_model_modified_example",
"table_model",
"unique_view_model_id",
}