Skip to content

Commit

Permalink
implement source patching
Browse files Browse the repository at this point in the history
  • Loading branch information
Jacob Beck committed Apr 28, 2020
1 parent c69f28e commit 38443cf
Show file tree
Hide file tree
Showing 9 changed files with 560 additions and 320 deletions.
3 changes: 3 additions & 0 deletions core/dbt/config/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,6 +645,9 @@ def as_v1(self):
# stuff any 'vars' entries into the old-style
# models/seeds/snapshots dicts
for project_name, items in dct['vars'].items():
if not isinstance(items, dict):
# can't translate top-level vars
continue
for cfgkey in ['models', 'seeds', 'snapshots']:
if project_name not in mutated[cfgkey]:
mutated[cfgkey][project_name] = {}
Expand Down
5 changes: 3 additions & 2 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@

from hologram import JsonSchemaMixin

from dbt.contracts.graph.compiled import CompileResultNode, NonSourceNode
from dbt.contracts.graph.parsed import (
ParsedMacro, ParsedDocumentation, ParsedNodePatch, ParsedMacroPatch,
ParsedSourceDefinition
)
from dbt.contracts.graph.compiled import CompileResultNode, NonSourceNode
from dbt.contracts.util import Writable, Replaceable
from dbt.exceptions import (
raise_duplicate_resource_name, InternalException, raise_compiler_error,
Expand All @@ -33,6 +33,7 @@

NodeEdgeMap = Dict[str, List[str]]
MacroKey = Tuple[str, str]
SourceKey = Tuple[str, str]


@dataclass
Expand Down Expand Up @@ -142,7 +143,7 @@ class SourceFile(JsonSchemaMixin):
# any macro patches in this file. The entries are package, name pairs.
macro_patches: List[MacroKey] = field(default_factory=list)
# any source patches in this file. The entries are package, name pairs
source_patches: List[Tuple[str, str]] = field(default_factory=list)
source_patches: List[SourceKey] = field(default_factory=list)

@property
def search_key(self) -> Optional[str]:
Expand Down
58 changes: 57 additions & 1 deletion core/dbt/contracts/graph/parsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
List,
Dict,
Any,
Sequence,
Tuple,
Iterator,
)

from hologram import JsonSchemaMixin
Expand All @@ -15,7 +18,8 @@
from dbt.contracts.graph.unparsed import (
UnparsedNode, UnparsedDocumentation, Quoting, Docs,
UnparsedBaseNode, FreshnessThreshold, ExternalTable,
HasYamlMetadata, MacroArgument
HasYamlMetadata, MacroArgument, UnparsedSourceDefinition,
UnparsedSourceTableDefinition, UnparsedColumn, TestDef
)
from dbt.contracts.util import Replaceable
from dbt.logger import GLOBAL_LOGGER as logger # noqa
Expand Down Expand Up @@ -301,6 +305,58 @@ def search_name(self):
return self.name


def normalize_test(testdef: TestDef) -> Dict[str, Any]:
    """Return *testdef* in dictionary form.

    A bare string is wrapped as ``{testdef: {}}`` (the string becomes the
    only key, mapped to an empty dict). Any non-string value is returned
    unchanged, as the very same object.
    """
    return {testdef: {}} if isinstance(testdef, str) else testdef


@dataclass
class UnpatchedSourceDefinition(UnparsedBaseNode, HasUniqueID, HasFqn):
    """A source definition paired with one of its table definitions.

    NOTE(review): judging by the name this is the form a source table takes
    before any source patch is applied - confirm against the parser.
    """
    source: UnparsedSourceDefinition
    table: UnparsedSourceTableDefinition
    resource_type: NodeType = field(metadata={'restrict': [NodeType.Source]})

    @property
    def name(self) -> str:
        """The source name and the table name joined by an underscore."""
        return '{0.name}_{1.name}'.format(self.source, self.table)

    @property
    def quote_columns(self) -> Optional[bool]:
        """The column quoting setting, with the table overriding the source.

        Returns the table-level value when it is set, otherwise the
        source-level value, which may itself be None.
        """
        source_setting = self.source.quoting.column
        table_setting = self.table.quoting.column
        return source_setting if table_setting is None else table_setting

    @property
    def columns(self) -> Sequence[UnparsedColumn]:
        """The table's columns, or an empty list when none are declared."""
        return [] if self.table.columns is None else self.table.columns

    @property
    def tests(self) -> List[TestDef]:
        """The table-level tests, or an empty list when none are declared."""
        return [] if self.table.tests is None else self.table.tests

    def get_tests(
        self
    ) -> Iterator[Tuple[Dict[str, Any], Optional[UnparsedColumn]]]:
        """Yield every test as a ``(normalized test, column)`` pair.

        Table-level tests come first and are paired with None; they are
        followed by each column's tests, paired with that column. Columns
        whose ``tests`` is None contribute nothing.
        """
        for table_test in self.tests:
            yield normalize_test(table_test), None

        for col in self.columns:
            if col.tests is None:
                continue
            for col_test in col.tests:
                yield normalize_test(col_test), col


@dataclass
class ParsedSourceDefinition(
UnparsedBaseNode,
Expand Down
53 changes: 49 additions & 4 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,12 @@ class UnparsedSourceTableDefinition(HasColumnTests, HasTests):
external: Optional[ExternalTable] = None
tags: List[str] = field(default_factory=list)

def to_dict(self, omit_none=True, validate=False):
    """Serialize via the parent ``to_dict``, keeping an unset freshness.

    When ``omit_none`` is truthy and ``self.freshness`` is None, the
    'freshness' key is written into the result with the value None.
    NOTE(review): presumably this keeps an explicit null freshness from
    being lost with the other omitted None values - confirm.
    """
    serialized = super().to_dict(omit_none=omit_none, validate=validate)
    restore_freshness = omit_none and self.freshness is None
    if restore_freshness:
        serialized['freshness'] = None
    return serialized


@dataclass
class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
Expand All @@ -267,9 +273,20 @@ class UnparsedSourceDefinition(JsonSchemaMixin, Replaceable):
def yaml_key(self) -> 'str':
return 'sources'

def to_dict(self, omit_none=True, validate=False):
    """Serialize via the parent ``to_dict``, keeping an unset freshness.

    When ``omit_none`` is truthy and ``self.freshness`` is None, the
    'freshness' key is written into the result with the value None.
    NOTE(review): presumably this keeps an explicit null freshness from
    being lost with the other omitted None values - confirm.
    """
    serialized = super().to_dict(omit_none=omit_none, validate=validate)
    restore_freshness = omit_none and self.freshness is None
    if restore_freshness:
        serialized['freshness'] = None
    return serialized


@dataclass
class UnparsedSourceTablePatch(HasColumnDocs, HasTests):
class SourceTablePatch(JsonSchemaMixin):
name: str
description: Optional[str] = None
meta: Optional[Dict[str, Any]] = None
data_type: Optional[str] = None
docs: Optional[Docs] = None
loaded_at_field: Optional[str] = None
identifier: Optional[str] = None
quoting: Quoting = field(default_factory=Quoting)
Expand All @@ -278,11 +295,20 @@ class UnparsedSourceTablePatch(HasColumnDocs, HasTests):
)
external: Optional[ExternalTable] = None
tags: Optional[List[str]] = None
tests: Optional[List[TestDef]] = None
columns: Optional[Sequence[UnparsedColumn]] = None

def to_patch_dict(self) -> Dict[str, Any]:
    """Return this table patch as a dict, without the 'name' key.

    The result is ``self.to_dict(omit_none=True)`` with 'name' removed.
    When ``self.freshness`` is None, 'freshness' is set to None in the
    result explicitly.
    NOTE(review): presumably 'name' only identifies the table being patched
    and the explicit null freshness must reach whatever applies the patch -
    confirm against the caller.
    """
    dct = self.to_dict(omit_none=True)
    # The trailing comma is required: ('name') is just the string 'name',
    # and looping over it visits 'n', 'a', 'm', 'e' - leaving the 'name'
    # key in place and deleting any single-letter keys that happen to match.
    remove_keys = ('name',)
    for key in remove_keys:
        dct.pop(key, None)

    if self.freshness is None:
        dct['freshness'] = None

    return dct


@dataclass
Expand All @@ -306,6 +332,25 @@ class SourcePatch(JsonSchemaMixin, Replaceable):
tables: Optional[List[SourceTablePatch]] = None
tags: Optional[List[str]] = None

def to_patch_dict(self) -> Dict[str, Any]:
    """Return this source patch as a dict of source-level values.

    The result is ``self.to_dict(omit_none=True)`` with the 'name',
    'overrides' and 'tables' keys removed. When ``self.freshness`` is None,
    'freshness' is set to None in the result explicitly.
    """
    patch = self.to_dict(omit_none=True)
    for dropped in ('name', 'overrides', 'tables'):
        patch.pop(dropped, None)

    if self.freshness is None:
        patch['freshness'] = None

    return patch

def get_table_named(self, name: str) -> Optional[SourceTablePatch]:
    """Return the first entry of ``self.tables`` whose ``name`` matches.

    Returns None when ``self.tables`` is None or no entry has that name.
    """
    candidates = () if self.tables is None else self.tables
    return next((entry for entry in candidates if entry.name == name), None)


@dataclass
class SourcesContainer(ExtensibleJsonSchemaMixin):
Expand Down
12 changes: 9 additions & 3 deletions core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from dbt.contracts.graph.compiled import NonSourceNode
from dbt.contracts.graph.manifest import Manifest, FilePath, FileHash, Disabled
from dbt.contracts.graph.parsed import (
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo
ParsedSourceDefinition, ParsedNode, ParsedMacro, ColumnInfo,
)
from dbt.parser.base import BaseParser, Parser
from dbt.parser.analysis import AnalysisParser
Expand All @@ -34,6 +34,7 @@
from dbt.parser.search import FileBlock
from dbt.parser.seeds import SeedParser
from dbt.parser.snapshots import SnapshotParser
from dbt.parser.sources import patch_sources
from dbt.version import __version__


Expand Down Expand Up @@ -66,7 +67,7 @@ def make_parse_result(
"""Make a ParseResult from the project configuration and the profile."""
# if any of these change, we need to reject the parser
vars_hash = FileHash.from_contents(
'\0'.join([
'\x00'.join([
getattr(config.args, 'vars', '{}') or '{}',
getattr(config.args, 'profile', '') or '',
getattr(config.args, 'target', '') or '',
Expand Down Expand Up @@ -305,16 +306,21 @@ def process_manifest(self, manifest: Manifest):
process_docs(manifest, self.root_project)

def create_manifest(self) -> Manifest:
# before we do anything else, patch the sources. This mutates
# results.disabled, so it needs to come before the final 'disabled'
# list is created
sources = patch_sources(self.results, self.root_project)
disabled = []
for value in self.results.disabled.values():
disabled.extend(value)

nodes: MutableMapping[str, NonSourceNode] = {
k: v for k, v in self.results.nodes.items()
}

manifest = Manifest(
nodes=nodes,
sources=self.results.sources,
sources=sources,
macros=self.results.macros,
docs=self.results.docs,
generated_at=datetime.utcnow(),
Expand Down
22 changes: 14 additions & 8 deletions core/dbt/parser/results.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from dataclasses import dataclass, field
from typing import TypeVar, MutableMapping, Mapping, Union, List, Tuple
from typing import TypeVar, MutableMapping, Mapping, Union, List

from hologram import JsonSchemaMixin

from dbt.contracts.graph.manifest import (
SourceFile, RemoteFile, FileHash, MacroKey
SourceFile, RemoteFile, FileHash, MacroKey, SourceKey
)
from dbt.contracts.graph.compiled import CompileResultNode
from dbt.contracts.graph.parsed import (
Expand All @@ -21,7 +21,7 @@
ParsedSeedNode,
ParsedSchemaTestNode,
ParsedSnapshotNode,
ParsedSourceDefinition,
UnpatchedSourceDefinition,
)
from dbt.contracts.graph.unparsed import SourcePatch
from dbt.contracts.util import Writable, Replaceable
Expand Down Expand Up @@ -68,12 +68,12 @@ class ParseResult(JsonSchemaMixin, Writable, Replaceable):
profile_hash: FileHash
project_hashes: MutableMapping[str, FileHash]
nodes: MutableMapping[str, ManifestNodes] = dict_field()
sources: MutableMapping[str, ParsedSourceDefinition] = dict_field()
sources: MutableMapping[str, UnpatchedSourceDefinition] = dict_field()
docs: MutableMapping[str, ParsedDocumentation] = dict_field()
macros: MutableMapping[str, ParsedMacro] = dict_field()
macro_patches: MutableMapping[MacroKey, ParsedMacroPatch] = dict_field()
patches: MutableMapping[str, ParsedNodePatch] = dict_field()
source_patches: MutableMapping[Tuple[str, str], SourcePatch] = dict_field()
source_patches: MutableMapping[SourceKey, SourcePatch] = dict_field()
files: MutableMapping[str, SourceFile] = dict_field()
disabled: MutableMapping[str, List[CompileResultNode]] = dict_field()
dbt_version: str = __version__
Expand All @@ -87,24 +87,30 @@ def get_file(self, source_file: SourceFile) -> SourceFile:
return self.files[key]

def add_source(
self, source_file: SourceFile, source: ParsedSourceDefinition
self, source_file: SourceFile, source: UnpatchedSourceDefinition
):
# sources can't be overwritten!
_check_duplicates(source, self.sources)
self.sources[source.unique_id] = source
self.get_file(source_file).sources.append(source.unique_id)

def add_node(self, source_file: SourceFile, node: ManifestNodes):
def add_node_nofile(self, node: ManifestNodes):
    """Store *node* in ``self.nodes`` under its unique id.

    ``_check_duplicates`` runs first, because nodes can't be overwritten.
    No source file entry is touched.
    """
    _check_duplicates(node, self.nodes)
    unique_id = node.unique_id
    self.nodes[unique_id] = node

def add_node(self, source_file: SourceFile, node: ManifestNodes):
    """Register *node*, then record its unique id on its source file.

    The id is appended to the ``nodes`` list of the file entry returned by
    ``self.get_file(source_file)``.
    """
    self.add_node_nofile(node)
    file_entry = self.get_file(source_file)
    file_entry.nodes.append(node.unique_id)

def add_disabled(self, source_file: SourceFile, node: CompileResultNode):
def add_disabled_nofile(self, node: CompileResultNode):
    """Append *node* to the ``self.disabled`` list kept for its unique id.

    A new single-element list is started when the id has no entry yet. No
    source file entry is touched.
    """
    self.disabled.setdefault(node.unique_id, []).append(node)

def add_disabled(self, source_file: SourceFile, node: CompileResultNode):
    """Record *node* as disabled, then note its unique id on its file.

    The id is appended to the ``nodes`` list of the file entry returned by
    ``self.get_file(source_file)``.
    """
    self.add_disabled_nofile(node)
    file_entry = self.get_file(source_file)
    file_entry.nodes.append(node.unique_id)

def add_macro(self, source_file: SourceFile, macro: ParsedMacro):
Expand Down
Loading

0 comments on commit 38443cf

Please sign in to comment.