Skip to content

Commit

Permalink
Add minimal validation of schema file yaml prior to partial parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Jun 15, 2021
1 parent 14507a2 commit ee905b1
Show file tree
Hide file tree
Showing 5 changed files with 81 additions and 25 deletions.
32 changes: 29 additions & 3 deletions core/dbt/parser/read_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
FilePath, ParseFileType, SourceFile, FileHash, AnySourceFile, SchemaSourceFile
)

from dbt.parser.schemas import yaml_from_file
from dbt.parser.schemas import yaml_from_file, schema_file_keys, check_format_version
from dbt.exceptions import CompilationException
from dbt.parser.search import FilesystemSearcher


Expand All @@ -17,11 +18,36 @@ def load_source_file(
source_file = sf_cls(path=path, checksum=checksum,
parse_file_type=parse_file_type, project_name=project_name)
source_file.contents = file_contents.strip()
if parse_file_type == ParseFileType.Schema:
source_file.dfy = yaml_from_file(source_file)
if parse_file_type == ParseFileType.Schema and source_file.contents:
dfy = yaml_from_file(source_file)
validate_yaml(source_file.path.original_file_path, dfy)
source_file.dfy = dfy
return source_file


# Do some minimal validation of the yaml in a schema file:
# check the version, that the value of each known top-level key is a list,
# and that each element of those lists is a dictionary with a 'name' key.
def validate_yaml(file_path, dct):
    """Run minimal structural checks on the parsed yaml of a schema file.

    The format version is checked first through ``check_format_version``.
    Then, for every key of ``schema_file_keys`` that is present in ``dct``,
    the value must be a list and each element of that list must be a
    dictionary containing a 'name' key.

    Args:
        file_path: path of the schema file; only used in error messages.
        dct: the parsed yaml contents of the schema file.

    Raises:
        CompilationException: on the first structural violation found.
    """
    check_format_version(file_path, dct)
    for section in schema_file_keys:
        # Keys that are absent from the file need no validation.
        if section not in dct:
            continue
        entries = dct[section]
        if not isinstance(entries, list):
            raise CompilationException(
                f"The schema file at {file_path} is "
                f"invalid because the value of '{section}' is not a list"
            )
        for entry in entries:
            if not isinstance(entry, dict):
                raise CompilationException(
                    f"The schema file at {file_path} is "
                    f"invalid because a list element for '{section}' is not a dictionary"
                )
            if 'name' not in entry:
                raise CompilationException(
                    f"The schema file at {file_path} is "
                    f"invalid because a list element for '{section}' does not have a "
                    "name attribute."
                )


# Special processing for big seed files
def load_seed_source_file(match: FilePath, project_name) -> SourceFile:
if match.seed_too_large():
Expand Down
46 changes: 24 additions & 22 deletions core/dbt/parser/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@

TestDef = Union[str, Dict[str, Any]]

# Names of the top-level keys of a schema file that are looked up,
# one by one, when the file's yaml is validated.
schema_file_keys = (
    'models',
    'seeds',
    'snapshots',
    'sources',
    'macros',
    'analyses',
    'exposures',
)


def error_context(
path: str,
Expand Down Expand Up @@ -200,25 +205,6 @@ def parse_from_dict(self, dct, validate=True) -> ParsedSchemaTestNode:
ParsedSchemaTestNode.validate(dct)
return ParsedSchemaTestNode.from_dict(dct)

def _check_format_version(self, yaml: YamlBlock) -> None:
    """Verify that the yaml block declares schema format version 2.

    The block's ``path.relative_path`` is handed to
    ``raise_invalid_schema_yml_version`` together with a short reason
    whenever one of the checks fails.
    """
    rel_path = yaml.path.relative_path
    if 'version' not in yaml.data:
        raise_invalid_schema_yml_version(rel_path, 'no version is specified')

    declared = yaml.data['version']
    # A non-integer version is malformed or not set; either way only
    # 'version: 2' is supported.
    if not isinstance(declared, int):
        raise_invalid_schema_yml_version(
            rel_path, 'the version is not an integer'
        )
    if declared != 2:
        raise_invalid_schema_yml_version(
            rel_path, f'version {declared} is not supported'
        )

def parse_column_tests(
self, block: TestBlock, column: UnparsedColumn
) -> None:
Expand Down Expand Up @@ -514,9 +500,6 @@ def parse_file(self, block: FileBlock, dct: Dict = None) -> None:
# contains the FileBlock and the data (dictionary)
yaml_block = YamlBlock.from_file_block(block, dct)

# checks version
self._check_format_version(yaml_block)

parser: YamlDocsReader

# There are 7 kinds of parsers:
Expand Down Expand Up @@ -565,6 +548,25 @@ def parse_file(self, block: FileBlock, dct: Dict = None) -> None:
self.manifest.add_exposure(yaml_block.file, node)


def check_format_version(file_path, yaml_dct) -> None:
    """Verify that a schema file's parsed yaml declares format version 2.

    Args:
        file_path: path of the schema file, passed through to the error
            helper.
        yaml_dct: the parsed yaml contents of the schema file.

    Every failed check is reported through
    ``raise_invalid_schema_yml_version`` with a short reason string.
    """
    if 'version' not in yaml_dct:
        raise_invalid_schema_yml_version(file_path, 'no version is specified')

    declared = yaml_dct['version']
    # A non-integer version is malformed or not set; either way only
    # 'version: 2' is supported.
    if not isinstance(declared, int):
        raise_invalid_schema_yml_version(
            file_path, 'the version is not an integer'
        )
    if declared != 2:
        raise_invalid_schema_yml_version(
            file_path, f'version {declared} is not supported'
        )


Parsed = TypeVar(
'Parsed',
UnpatchedSourceDefinition, ParsedNodePatch, ParsedMacroPatch
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
select 1 as "Id"
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
version: 2

models:
name: model
columns:
- name: Id
quote: true
tests:
- unique
- not_null
17 changes: 17 additions & 0 deletions test/integration/008_schema_tests_test/test_schema_v2_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,3 +523,20 @@ def test_postgres_collision_test_names_get_hash(self):
]
self.assertIn(test_results[0].node.unique_id, expected_unique_ids)
self.assertIn(test_results[1].node.unique_id, expected_unique_ids)


class TestInvalidSchema(DBTIntegrationTest):
    """Check that running dbt against the 'invalid-schema-models' directory
    fails with an error saying that 'models' is not a list."""

    @property
    def schema(self):
        return "schema_tests_008"

    @property
    def models(self):
        return "invalid-schema-models"

    @use_profile('postgres')
    def test_postgres_invalid_schema_file(self):
        # The run is expected to raise, so its return value is not kept.
        with self.assertRaises(CompilationException) as exc:
            self.run_dbt()
        # Inspect the captured exception only after the `with` block has
        # exited; a statement placed after the raising call inside the
        # block would never execute.
        self.assertRegex(str(exc.exception), r"'models' is not a list")

0 comments on commit ee905b1

Please sign in to comment.