CT 1604 remove compiled classes #6384

Merged
merged 11 commits on Dec 7, 2022
7 changes: 7 additions & 0 deletions .changes/unreleased/Under the Hood-20221205-164948.yaml
@@ -0,0 +1,7 @@
kind: Under the Hood
body: Consolidate ParsedNode and CompiledNode classes
time: 2022-12-05T16:49:48.563583-05:00
custom:
Author: gshank
Issue: "6383"
PR: "6384"
15 changes: 4 additions & 11 deletions core/dbt/adapters/base/impl.py
@@ -15,7 +15,6 @@
List,
Mapping,
Iterator,
Union,
Set,
)

@@ -38,9 +37,8 @@
)
from dbt.clients.agate_helper import empty_table, merge_tables, table_from_rows
from dbt.clients.jinja import MacroGenerator
from dbt.contracts.graph.compiled import CompileResultNode, CompiledSeedNode
from dbt.contracts.graph.manifest import Manifest, MacroManifest
from dbt.contracts.graph.parsed import ParsedSeedNode
from dbt.contracts.graph.nodes import ResultNode
from dbt.events.functions import fire_event, warn_or_error
from dbt.events.types import (
CacheMiss,
@@ -64,9 +62,6 @@
from dbt.adapters.cache import RelationsCache, _make_ref_key_msg


SeedModel = Union[ParsedSeedNode, CompiledSeedNode]


GET_CATALOG_MACRO_NAME = "get_catalog"
FRESHNESS_MACRO_NAME = "collect_freshness"

@@ -243,9 +238,7 @@ def nice_connection_name(self) -> str:
return conn.name

@contextmanager
def connection_named(
self, name: str, node: Optional[CompileResultNode] = None
) -> Iterator[None]:
def connection_named(self, name: str, node: Optional[ResultNode] = None) -> Iterator[None]:
try:
if self.connections.query_header is not None:
self.connections.query_header.set(name, node)
@@ -257,7 +250,7 @@ def connection_named(
self.connections.query_header.reset()

@contextmanager
def connection_for(self, node: CompileResultNode) -> Iterator[None]:
def connection_for(self, node: ResultNode) -> Iterator[None]:
with self.connection_named(node.unique_id, node):
yield

@@ -372,7 +365,7 @@ def _get_catalog_schemas(self, manifest: Manifest) -> SchemaSearchMap:
lowercase strings.
"""
info_schema_name_map = SchemaSearchMap()
nodes: Iterator[CompileResultNode] = chain(
nodes: Iterator[ResultNode] = chain(
[
node
for node in manifest.nodes.values()
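The connection_named and connection_for context managers keep their behavior; only the node type in their signatures changes from CompileResultNode to the consolidated ResultNode. A toy sketch of the call pattern, where FakeAdapter and FakeNode are placeholders rather than anything from this diff:

```python
from contextlib import contextmanager
from typing import Iterator, Optional


class FakeAdapter:
    """Toy stand-in for BaseAdapter, only to show the call pattern."""

    @contextmanager
    def connection_named(self, name: str, node: Optional[object] = None) -> Iterator[None]:
        # In dbt this acquires a connection and sets the query header.
        print(f"open connection '{name}'")
        try:
            yield
        finally:
            print(f"release connection '{name}'")

    @contextmanager
    def connection_for(self, node) -> Iterator[None]:
        # Mirrors the real method: delegates to connection_named with the node's unique_id.
        with self.connection_named(node.unique_id, node):
            yield


class FakeNode:
    unique_id = "model.my_project.my_model"


adapter = FakeAdapter()
with adapter.connection_named("master"):
    pass
with adapter.connection_for(FakeNode()):
    pass
```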
4 changes: 2 additions & 2 deletions core/dbt/adapters/base/query_headers.py
@@ -5,7 +5,7 @@

from dbt.context.manifest import generate_query_header_context
from dbt.contracts.connection import AdapterRequiredConfig, QueryComment
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.nodes import ResultNode
from dbt.contracts.graph.manifest import Manifest
from dbt.exceptions import RuntimeException

@@ -90,7 +90,7 @@ def add(self, sql: str) -> str:
def reset(self):
self.set("master", None)

def set(self, name: str, node: Optional[CompileResultNode]):
def set(self, name: str, node: Optional[ResultNode]):
wrapped: Optional[NodeWrapper] = None
if node is not None:
wrapped = NodeWrapper(node)
22 changes: 9 additions & 13 deletions core/dbt/adapters/base/relation.py
@@ -2,8 +2,7 @@
from dataclasses import dataclass
from typing import Optional, TypeVar, Any, Type, Dict, Union, Iterator, Tuple, Set

from dbt.contracts.graph.compiled import CompiledNode
from dbt.contracts.graph.parsed import ParsedSourceDefinition, ParsedNode
from dbt.contracts.graph.nodes import SourceDefinition, ParsedNode
from dbt.contracts.relation import (
RelationType,
ComponentName,
@@ -184,7 +183,7 @@ def quoted(self, identifier):
)

@classmethod
def create_from_source(cls: Type[Self], source: ParsedSourceDefinition, **kwargs: Any) -> Self:
def create_from_source(cls: Type[Self], source: SourceDefinition, **kwargs: Any) -> Self:
source_quoting = source.quoting.to_dict(omit_none=True)
source_quoting.pop("column", None)
quote_policy = deep_merge(
@@ -209,7 +208,7 @@ def add_ephemeral_prefix(name: str):
def create_ephemeral_from_node(
cls: Type[Self],
config: HasQuoting,
node: Union[ParsedNode, CompiledNode],
node: ParsedNode,
) -> Self:
# Note that ephemeral models are based on the name.
identifier = cls.add_ephemeral_prefix(node.name)
@@ -222,7 +221,7 @@ def create_ephemeral_from_node(
def create_from_node(
cls: Type[Self],
config: HasQuoting,
node: Union[ParsedNode, CompiledNode],
node: ParsedNode,
quote_policy: Optional[Dict[str, bool]] = None,
**kwargs: Any,
) -> Self:
@@ -243,21 +242,18 @@ def create_from_node(
def create_from(
cls: Type[Self],
config: HasQuoting,
node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition],
node: Union[ParsedNode, SourceDefinition],
**kwargs: Any,
) -> Self:
if node.resource_type == NodeType.Source:
if not isinstance(node, ParsedSourceDefinition):
if not isinstance(node, SourceDefinition):
raise InternalException(
"type mismatch, expected ParsedSourceDefinition but got {}".format(type(node))
"type mismatch, expected SourceDefinition but got {}".format(type(node))
)
return cls.create_from_source(node, **kwargs)
else:
if not isinstance(node, (ParsedNode, CompiledNode)):
raise InternalException(
"type mismatch, expected ParsedNode or CompiledNode but "
"got {}".format(type(node))
)
if not isinstance(node, (ParsedNode)):
raise InternalException(f"type mismatch, expected ParsedNode but got {type(node)}")
return cls.create_from_node(config, node, **kwargs)

@classmethod
7 changes: 3 additions & 4 deletions core/dbt/adapters/protocol.py
@@ -17,8 +17,7 @@
import agate

from dbt.contracts.connection import Connection, AdapterRequiredConfig, AdapterResponse
from dbt.contracts.graph.compiled import CompiledNode, ManifestNode, NonSourceCompiledNode
from dbt.contracts.graph.parsed import ParsedNode, ParsedSourceDefinition
from dbt.contracts.graph.nodes import ParsedNode, SourceDefinition, ManifestNode
from dbt.contracts.graph.model_config import BaseConfig
from dbt.contracts.graph.manifest import Manifest
from dbt.contracts.relation import Policy, HasQuoting
@@ -51,7 +50,7 @@ def get_default_quote_policy(cls) -> Policy:
def create_from(
cls: Type[Self],
config: HasQuoting,
node: Union[CompiledNode, ParsedNode, ParsedSourceDefinition],
node: Union[ParsedNode, SourceDefinition],
) -> Self:
...

@@ -65,7 +64,7 @@ def compile_node(
node: ManifestNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
) -> NonSourceCompiledNode:
) -> ManifestNode:
...


5 changes: 2 additions & 3 deletions core/dbt/clients/jinja.py
@@ -25,8 +25,7 @@
)

from dbt.clients._jinja_blocks import BlockIterator, BlockData, BlockTag
from dbt.contracts.graph.compiled import CompiledGenericTestNode
from dbt.contracts.graph.parsed import ParsedGenericTestNode
from dbt.contracts.graph.nodes import GenericTestNode

from dbt.exceptions import (
InternalException,
@@ -620,7 +619,7 @@ def extract_toplevel_blocks(

def add_rendered_test_kwargs(
context: Dict[str, Any],
node: Union[ParsedGenericTestNode, CompiledGenericTestNode],
node: GenericTestNode,
capture_macros: bool = False,
) -> None:
"""Render each of the test kwargs in the given context using the native
55 changes: 21 additions & 34 deletions core/dbt/compilation.py
@@ -1,6 +1,6 @@
import os
from collections import defaultdict
from typing import List, Dict, Any, Tuple, cast, Optional
from typing import List, Dict, Any, Tuple, Optional

import networkx as nx # type: ignore
import pickle
@@ -12,15 +12,13 @@
from dbt.clients.system import make_directory
from dbt.context.providers import generate_runtime_model_context
from dbt.contracts.graph.manifest import Manifest, UniqueID
from dbt.contracts.graph.compiled import (
COMPILED_TYPES,
CompiledGenericTestNode,
from dbt.contracts.graph.nodes import (
ParsedNode,
ManifestNode,
GenericTestNode,
GraphMemberNode,
InjectedCTE,
ManifestNode,
NonSourceCompiledNode,
)
from dbt.contracts.graph.parsed import ParsedNode
from dbt.exceptions import (
dependency_not_found,
InternalException,
@@ -37,14 +35,6 @@
graph_file_name = "graph.gpickle"


def _compiled_type_for(model: ParsedNode):
if type(model) not in COMPILED_TYPES:
raise InternalException(
f"Asked to compile {type(model)} node, but it has no compiled form"
)
return COMPILED_TYPES[type(model)]


def print_compile_stats(stats):
names = {
NodeType.Model: "model",
@@ -177,15 +167,15 @@ def initialize(self):
# a dict for jinja rendering of SQL
def _create_node_context(
self,
node: NonSourceCompiledNode,
node: ManifestNode,
manifest: Manifest,
extra_context: Dict[str, Any],
) -> Dict[str, Any]:

context = generate_runtime_model_context(node, self.config, manifest)
context.update(extra_context)

if isinstance(node, CompiledGenericTestNode):
if isinstance(node, GenericTestNode):
# for test nodes, add a special keyword args value to the context
jinja.add_rendered_test_kwargs(context, node)

@@ -262,10 +252,10 @@ def _inject_ctes_into_sql(self, sql: str, ctes: List[InjectedCTE]) -> str:

def _recursively_prepend_ctes(
self,
model: NonSourceCompiledNode,
model: ManifestNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]],
) -> Tuple[NonSourceCompiledNode, List[InjectedCTE]]:
) -> Tuple[ManifestNode, List[InjectedCTE]]:
"""This method is called by the 'compile_node' method. Starting
from the node that it is passed in, it will recursively call
itself using the 'extra_ctes'. The 'ephemeral' models do
@@ -306,8 +296,6 @@ def _recursively_prepend_ctes(
# This model has already been compiled, so it's been
# through here before
if getattr(cte_model, "compiled", False):
assert isinstance(cte_model, tuple(COMPILED_TYPES.values()))
cte_model = cast(NonSourceCompiledNode, cte_model)
new_prepended_ctes = cte_model.extra_ctes

# if the cte_model isn't compiled, i.e. first time here
@@ -344,7 +332,7 @@ def _recursively_prepend_ctes(

return model, prepended_ctes

# creates a compiled_node from the ManifestNode passed in,
# Sets compiled fields in the ManifestNode passed in,
# creates a "context" dictionary for jinja rendering,
# and then renders the "compiled_code" using the node, the
# raw_code and the context.
@@ -353,7 +341,7 @@ def _compile_node(
node: ManifestNode,
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
) -> NonSourceCompiledNode:
) -> ManifestNode:
if extra_context is None:
extra_context = {}

@@ -366,41 +354,40 @@
"extra_ctes": [],
}
)
compiled_node = _compiled_type_for(node).from_dict(data)

if compiled_node.language == ModelLanguage.python:
if node.language == ModelLanguage.python:
# TODO could we also 'minify' this code at all? just aesthetic, not functional

# quoting seems like something very specific to sql so far
# for all python implementations we are seeing there's no quoting.
# TODO try to find better way to do this, given that
original_quoting = self.config.quoting
self.config.quoting = {key: False for key in original_quoting.keys()}
context = self._create_node_context(compiled_node, manifest, extra_context)
context = self._create_node_context(node, manifest, extra_context)

postfix = jinja.get_rendered(
"{{ py_script_postfix(model) }}",
context,
node,
)
# we should NOT jinja render the python model's 'raw code'
compiled_node.compiled_code = f"{node.raw_code}\n\n{postfix}"
node.compiled_code = f"{node.raw_code}\n\n{postfix}"
# restore quoting settings in the end since context is lazy evaluated
self.config.quoting = original_quoting

else:
context = self._create_node_context(compiled_node, manifest, extra_context)
compiled_node.compiled_code = jinja.get_rendered(
context = self._create_node_context(node, manifest, extra_context)
node.compiled_code = jinja.get_rendered(
node.raw_code,
context,
node,
)

compiled_node.relation_name = self._get_relation_name(node)
node.relation_name = self._get_relation_name(node)

compiled_node.compiled = True
node.compiled = True

return compiled_node
return node

def write_graph_file(self, linker: Linker, manifest: Manifest):
filename = graph_file_name
@@ -507,7 +494,7 @@ def compile(self, manifest: Manifest, write=True, add_test_edges=False) -> Graph
return Graph(linker.graph)

# writes the "compiled_code" into the target/compiled directory
def _write_node(self, node: NonSourceCompiledNode) -> ManifestNode:
def _write_node(self, node: ManifestNode) -> ManifestNode:
if not node.extra_ctes_injected or node.resource_type == NodeType.Snapshot:
return node
fire_event(WritingInjectedSQLForNode(node_info=get_node_info()))
@@ -524,7 +511,7 @@ def compile_node(
manifest: Manifest,
extra_context: Optional[Dict[str, Any]] = None,
write: bool = True,
) -> NonSourceCompiledNode:
) -> ManifestNode:
"""This is the main entry point into this code. It's called by
CompileRunner.compile, GenericRPCRunner.compile, and
RunTask.get_hook_sql. It calls '_compile_node' to convert
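Taken together, the compilation.py changes replace "build a Compiled* node from the parsed node's dict" with "set the compiled fields on the node itself", so _compile_node now returns the same object it was given. A simplified sketch of that in-place flow, where NodeSketch and render are stand-ins for the real node classes and jinja.get_rendered, not the actual dbt implementation:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional


@dataclass
class NodeSketch:
    # Hypothetical stand-in for a manifest node after the consolidation.
    raw_code: str
    compiled: bool = False
    compiled_code: Optional[str] = None
    relation_name: Optional[str] = None


def compile_in_place(
    node: NodeSketch,
    render: Callable[[str, Dict[str, Any]], str],
    context: Dict[str, Any],
    relation_name: Optional[str] = None,
) -> NodeSketch:
    # Previously the compiler built a separate Compiled* object via
    # _compiled_type_for(node).from_dict(...); now it mutates the node itself.
    node.compiled_code = render(node.raw_code, context)
    node.relation_name = relation_name
    node.compiled = True
    return node


# Example: a trivial "renderer" standing in for jinja rendering.
node = compile_in_place(
    NodeSketch(raw_code="select 1 as {{ col }}"),
    render=lambda code, ctx: code.replace("{{ col }}", ctx["col"]),
    context={"col": "id"},
    relation_name='"analytics"."my_model"',
)
assert node.compiled and node.compiled_code == "select 1 as id"
```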
6 changes: 3 additions & 3 deletions core/dbt/context/base.py
@@ -8,7 +8,7 @@
from dbt.clients.jinja import get_rendered
from dbt.clients.yaml_helper import yaml, safe_load, SafeLoader, Loader, Dumper # noqa: F401
from dbt.constants import SECRET_ENV_PREFIX, DEFAULT_ENV_PLACEHOLDER
from dbt.contracts.graph.compiled import CompiledResource
from dbt.contracts.graph.nodes import Resource
from dbt.exceptions import (
CompilationException,
MacroReturn,
@@ -135,11 +135,11 @@ def __init__(
self,
context: Mapping[str, Any],
cli_vars: Mapping[str, Any],
node: Optional[CompiledResource] = None,
node: Optional[Resource] = None,
) -> None:
self._context: Mapping[str, Any] = context
self._cli_vars: Mapping[str, Any] = cli_vars
self._node: Optional[CompiledResource] = node
self._node: Optional[Resource] = node
self._merged: Mapping[str, Any] = self._generate_merged()

def _generate_merged(self) -> Mapping[str, Any]: