-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
Copy pathmodels.py
535 lines (452 loc) · 22.3 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
from copy import deepcopy
from dbt.context.context_config import ContextConfig
from dbt.contracts.graph.parsed import ParsedModelNode
import dbt.flags as flags
from dbt.events.functions import fire_event
from dbt.events.types import (
StaticParserCausedJinjaRendering,
UsingExperimentalParser,
SampleFullJinjaRendering,
StaticParserFallbackJinjaRendering,
StaticParsingMacroOverrideDetected,
StaticParserSuccess,
StaticParserFailure,
ExperimentalParserSuccess,
ExperimentalParserFailure,
)
from dbt.node_types import NodeType, ModelLanguage
from dbt.parser.base import SimpleSQLParser
from dbt.parser.search import FileBlock
from dbt.clients.jinja import get_rendered
import dbt.tracking as tracking
from dbt import utils
from dbt_extractor import ExtractionError, py_extract_from_source # type: ignore
from functools import reduce
from itertools import chain
import random
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
# New for Python models :p
import ast
from dbt.dataclass_schema import ValidationError
from dbt.exceptions import ParsingException, validator_error_message, UndefinedMacroException
class PythonValidationVisitor(ast.NodeVisitor):
def __init__(self):
super().__init__()
self.dbt_errors = []
self.num_model_def = 0
def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
if node.name == "model":
self.num_model_def += 1
if not node.args.args[0].arg == "dbt":
self.dbt_errors.append("'dbt' not provided for model as the first argument")
if len(node.args.args) != 2:
self.dbt_errors.append(
"model function should have two args, `dbt` and a session to current warehouse"
)
# check we have a return and only one
if not isinstance(node.body[-1], ast.Return) or isinstance(
node.body[-1].value, ast.Tuple
):
self.dbt_errors.append(
"In current version, model function should return only one dataframe object"
)
def check_error(self, node):
if self.num_model_def != 1:
raise ParsingException("dbt only allow one model defined per python file", node=node)
if len(self.dbt_errors) != 0:
raise ParsingException("\n".join(self.dbt_errors), node=node)
class PythonParseVisitor(ast.NodeVisitor):
def __init__(self, dbt_node):
super().__init__()
self.dbt_node = dbt_node
self.dbt_function_calls = []
self.packages = []
@classmethod
def _flatten_attr(cls, node):
if isinstance(node, ast.Attribute):
return str(cls._flatten_attr(node.value)) + "." + node.attr
elif isinstance(node, ast.Name):
return str(node.id)
else:
pass
def _safe_eval(self, node):
try:
return ast.literal_eval(node)
except (SyntaxError, ValueError, TypeError) as exc:
msg = validator_error_message(exc)
raise ParsingException(msg, node=self.dbt_node) from exc
except (MemoryError, RecursionError) as exc:
msg = validator_error_message(exc)
raise ParsingException(msg, node=self.dbt_node) from exc
def _get_call_literals(self, node):
# List of literals
arg_literals = []
kwarg_literals = {}
# TODO : Make sure this throws (and that we catch it)
# for non-literal inputs
for arg in node.args:
rendered = self._safe_eval(arg)
arg_literals.append(rendered)
for keyword in node.keywords:
key = keyword.arg
rendered = self._safe_eval(keyword.value)
kwarg_literals[key] = rendered
return arg_literals, kwarg_literals
def visit_Call(self, node: ast.Call) -> None:
func_name = self._flatten_attr(node.func)
if func_name in ["dbt.ref", "dbt.source", "dbt.config", "dbt.config.get"]:
# drop the dot-dbt prefix
func_name = func_name.split(".")[-1]
args, kwargs = self._get_call_literals(node)
self.dbt_function_calls.append((func_name, args, kwargs))
def visit_Import(self, node: ast.Import) -> None:
for n in node.names:
self.packages.append(n.name.split(".")[0])
def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
if node.module:
self.packages.append(node.module.split(".")[0])
def merge_packages(original_packages_with_version, new_packages):
original_packages = [package.split("==")[0] for package in original_packages_with_version]
additional_packages = [package for package in new_packages if package not in original_packages]
return original_packages_with_version + list(set(additional_packages))
def verify_python_model_code(node):
# TODO: add a test for this
try:
rendered_python = get_rendered(
node.raw_code,
{},
node,
)
if rendered_python != node.raw_code:
raise ParsingException("")
except (UndefinedMacroException, ParsingException):
raise ParsingException("No jinja in python model code is allowed", node=node)
class ModelParser(SimpleSQLParser[ParsedModelNode]):
def parse_from_dict(self, dct, validate=True) -> ParsedModelNode:
if validate:
ParsedModelNode.validate(dct)
return ParsedModelNode.from_dict(dct)
@property
def resource_type(self) -> NodeType:
return NodeType.Model
@classmethod
def get_compiled_path(cls, block: FileBlock):
return block.path.relative_path
def parse_python_model(self, node, config, context):
try:
tree = ast.parse(node.raw_code, filename=node.original_file_path)
except SyntaxError as exc:
msg = validator_error_message(exc)
raise ParsingException(f"{msg}\n{exc.text}", node=node) from exc
# We are doing a validator and a parser because visit_FunctionDef in parser
# would actually make the parser not doing the visit_Calls any more
dbtValidator = PythonValidationVisitor()
dbtValidator.visit(tree)
dbtValidator.check_error(node)
dbtParser = PythonParseVisitor(node)
dbtParser.visit(tree)
for (func, args, kwargs) in dbtParser.dbt_function_calls:
# TODO decide what we want to do with detected packages
# if func == "config":
# kwargs["detected_packages"] = dbtParser.packages
if func == "get":
context["config"](utilized=args)
continue
context[func](*args, **kwargs)
def render_update(self, node: ParsedModelNode, config: ContextConfig) -> None:
self.manifest._parsing_info.static_analysis_path_count += 1
if node.language == ModelLanguage.python:
try:
verify_python_model_code(node)
context = self._context_for(node, config)
self.parse_python_model(node, config, context)
self.update_parsed_node_config(node, config, context=context)
except ValidationError as exc:
# we got a ValidationError - probably bad types in config()
msg = validator_error_message(exc)
raise ParsingException(msg, node=node) from exc
return
elif not flags.STATIC_PARSER:
# jinja rendering
super().render_update(node, config)
fire_event(StaticParserCausedJinjaRendering(path=node.path))
return
# only sample for experimental parser correctness on normal runs,
# not when the experimental parser flag is on.
exp_sample: bool = False
# sampling the stable static parser against jinja is significantly
# more expensive and therefore done far less frequently.
stable_sample: bool = False
# there are two samples above, and it is perfectly fine if both happen
# at the same time. If that happens, the experimental parser, stable
# parser, and jinja rendering will run on the same model file and
# send back codes for experimental v stable, and stable v jinja.
if not flags.USE_EXPERIMENTAL_PARSER:
# `True` roughly 1/5000 times this function is called
# sample = random.randint(1, 5001) == 5000
stable_sample = random.randint(1, 5001) == 5000
# sampling the experimental parser is explicitly disabled here, but use the following
# commented code to sample a fraction of the time when new
# experimental features are added.
# `True` roughly 1/100 times this function is called
# exp_sample = random.randint(1, 101) == 100
# top-level declaration of variables
statically_parsed: Optional[Union[str, Dict[str, List[Any]]]] = None
experimental_sample: Optional[Union[str, Dict[str, List[Any]]]] = None
exp_sample_node: Optional[ParsedModelNode] = None
exp_sample_config: Optional[ContextConfig] = None
jinja_sample_node: Optional[ParsedModelNode] = None
jinja_sample_config: Optional[ContextConfig] = None
result: List[str] = []
# sample the experimental parser only during a normal run
if exp_sample and not flags.USE_EXPERIMENTAL_PARSER:
fire_event(UsingExperimentalParser(path=node.path))
experimental_sample = self.run_experimental_parser(node)
# if the experimental parser succeeded, make a full copy of model parser
# and populate _everything_ into it so it can be compared apples-to-apples
# with a fully jinja-rendered project. This is necessary because the experimental
# parser will likely add features that the existing static parser will fail on
# so comparing those directly would give us bad results. The comparison will be
# conducted after this model has been fully rendered either by the static parser
# or by full jinja rendering
if isinstance(experimental_sample, dict):
model_parser_copy = self.partial_deepcopy()
exp_sample_node = deepcopy(node)
exp_sample_config = deepcopy(config)
model_parser_copy.populate(exp_sample_node, exp_sample_config, experimental_sample)
# use the experimental parser exclusively if the flag is on
if flags.USE_EXPERIMENTAL_PARSER:
statically_parsed = self.run_experimental_parser(node)
# run the stable static parser unless it is explicitly turned off
else:
statically_parsed = self.run_static_parser(node)
# if the static parser succeeded, extract some data in easy-to-compare formats
if isinstance(statically_parsed, dict):
# only sample jinja for the purpose of comparing with the stable static parser
# if we know we don't need to fall back to jinja (i.e. - nothing to compare
# with jinja v jinja).
# This means we skip sampling for 40% of the 1/5000 samples. We could run the
# sampling rng here, but the effect would be the same since we would only roll
# it 40% of the time. So I've opted to keep all the rng code colocated above.
if stable_sample and not flags.USE_EXPERIMENTAL_PARSER:
fire_event(SampleFullJinjaRendering(path=node.path))
# if this will _never_ mutate anything `self` we could avoid these deep copies,
# but we can't really guarantee that going forward.
model_parser_copy = self.partial_deepcopy()
jinja_sample_node = deepcopy(node)
jinja_sample_config = deepcopy(config)
# rendering mutates the node and the config
super(ModelParser, model_parser_copy).render_update(
jinja_sample_node, jinja_sample_config
)
# update the unrendered config with values from the static parser.
# values from yaml files are in there already
self.populate(node, config, statically_parsed)
# if we took a jinja sample, compare now that the base node has been populated
if jinja_sample_node is not None and jinja_sample_config is not None:
result = _get_stable_sample_result(
jinja_sample_node, jinja_sample_config, node, config
)
# if we took an experimental sample, compare now that the base node has been populated
if exp_sample_node is not None and exp_sample_config is not None:
result = _get_exp_sample_result(
exp_sample_node,
exp_sample_config,
node,
config,
)
self.manifest._parsing_info.static_analysis_parsed_path_count += 1
# if the static parser didn't succeed, fall back to jinja
else:
# jinja rendering
super().render_update(node, config)
fire_event(StaticParserFallbackJinjaRendering(path=node.path))
# if sampling, add the correct messages for tracking
if exp_sample and isinstance(experimental_sample, str):
if experimental_sample == "cannot_parse":
result += ["01_experimental_parser_cannot_parse"]
elif experimental_sample == "has_banned_macro":
result += ["08_has_banned_macro"]
elif stable_sample and isinstance(statically_parsed, str):
if statically_parsed == "cannot_parse":
result += ["81_stable_parser_cannot_parse"]
elif statically_parsed == "has_banned_macro":
result += ["88_has_banned_macro"]
# only send the tracking event if there is at least one result code
if result:
# fire a tracking event. this fires one event for every sample
# so that we have data on a per file basis. Not only can we expect
# no false positives or misses, we can expect the number model
# files parseable by the experimental parser to match our internal
# testing.
if tracking.active_user is not None: # None in some tests
tracking.track_experimental_parser_sample(
{
"project_id": self.root_project.hashed_name(),
"file_id": utils.get_hash(node),
"status": result,
}
)
def run_static_parser(
self, node: ParsedModelNode
) -> Optional[Union[str, Dict[str, List[Any]]]]:
# if any banned macros have been overridden by the user, we cannot use the static parser.
if self._has_banned_macro(node):
# this log line is used for integration testing. If you change
# the code at the beginning of the line change the tests in
# test/integration/072_experimental_parser_tests/test_all_experimental_parser.py
fire_event(StaticParsingMacroOverrideDetected(path=node.path))
return "has_banned_macro"
# run the stable static parser and return the results
try:
statically_parsed = py_extract_from_source(node.raw_code)
fire_event(StaticParserSuccess(path=node.path))
return _shift_sources(statically_parsed)
# if we want information on what features are barring the static
# parser from reading model files, this is where we would add that
# since that information is stored in the `ExtractionError`.
except ExtractionError:
fire_event(StaticParserFailure(path=node.path))
return "cannot_parse"
def run_experimental_parser(
self, node: ParsedModelNode
) -> Optional[Union[str, Dict[str, List[Any]]]]:
# if any banned macros have been overridden by the user, we cannot use the static parser.
if self._has_banned_macro(node):
# this log line is used for integration testing. If you change
# the code at the beginning of the line change the tests in
# test/integration/072_experimental_parser_tests/test_all_experimental_parser.py
fire_event(StaticParsingMacroOverrideDetected(path=node.path))
return "has_banned_macro"
# run the experimental parser and return the results
try:
# for now, this line calls the stable static parser since there are no
# experimental features. Change `py_extract_from_source` to the new
# experimental call when we add additional features.
experimentally_parsed = py_extract_from_source(node.raw_code)
fire_event(ExperimentalParserSuccess(path=node.path))
return _shift_sources(experimentally_parsed)
# if we want information on what features are barring the experimental
# parser from reading model files, this is where we would add that
# since that information is stored in the `ExtractionError`.
except ExtractionError:
fire_event(ExperimentalParserFailure(path=node.path))
return "cannot_parse"
# checks for banned macros
def _has_banned_macro(self, node: ParsedModelNode) -> bool:
# first check if there is a banned macro defined in scope for this model file
root_project_name = self.root_project.project_name
project_name = node.package_name
banned_macros = ["ref", "source", "config"]
all_banned_macro_keys: Iterator[str] = chain.from_iterable(
map(
lambda name: [f"macro.{project_name}.{name}", f"macro.{root_project_name}.{name}"],
banned_macros,
)
)
return reduce(
lambda z, key: z or (key in self.manifest.macros), all_banned_macro_keys, False
)
# this method updates the model node rendered and unrendered config as well
# as the node object. Used to populate these values when circumventing jinja
# rendering like the static parser.
def populate(
self, node: ParsedModelNode, config: ContextConfig, statically_parsed: Dict[str, Any]
):
# manually fit configs in
config._config_call_dict = _get_config_call_dict(statically_parsed)
# if there are hooks present this, it WILL render jinja. Will need to change
# when the experimental parser supports hooks
self.update_parsed_node_config(node, config)
# update the unrendered config with values from the file.
# values from yaml files are in there already
node.unrendered_config.update(dict(statically_parsed["configs"]))
# set refs and sources on the node object
node.refs += statically_parsed["refs"]
node.sources += statically_parsed["sources"]
# configs don't need to be merged into the node because they
# are read from config._config_call_dict
# the manifest is often huge so this method avoids deepcopying it
def partial_deepcopy(self):
return ModelParser(deepcopy(self.project), self.manifest, deepcopy(self.root_project))
# pure function. safe to use elsewhere, but unlikely to be useful outside this file.
def _get_config_call_dict(static_parser_result: Dict[str, Any]) -> Dict[str, Any]:
config_call_dict: Dict[str, Any] = {}
for c in static_parser_result["configs"]:
ContextConfig._add_config_call(config_call_dict, {c[0]: c[1]})
return config_call_dict
# TODO if we format sources in the extractor to match this type, we won't need this function.
def _shift_sources(static_parser_result: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
shifted_result = deepcopy(static_parser_result)
source_calls = []
for s in static_parser_result["sources"]:
source_calls.append([s[0], s[1]])
shifted_result["sources"] = source_calls
return shifted_result
# returns a list of string codes to be sent as a tracking event
def _get_exp_sample_result(
sample_node: ParsedModelNode,
sample_config: ContextConfig,
node: ParsedModelNode,
config: ContextConfig,
) -> List[str]:
result: List[Tuple[int, str]] = _get_sample_result(sample_node, sample_config, node, config)
def process(codemsg):
code, msg = codemsg
return f"0{code}_experimental_{msg}"
return list(map(process, result))
# returns a list of string codes to be sent as a tracking event
def _get_stable_sample_result(
sample_node: ParsedModelNode,
sample_config: ContextConfig,
node: ParsedModelNode,
config: ContextConfig,
) -> List[str]:
result: List[Tuple[int, str]] = _get_sample_result(sample_node, sample_config, node, config)
def process(codemsg):
code, msg = codemsg
return f"8{code}_stable_{msg}"
return list(map(process, result))
# returns a list of string codes that need a single digit prefix to be prepended
# before being sent as a tracking event
def _get_sample_result(
sample_node: ParsedModelNode,
sample_config: ContextConfig,
node: ParsedModelNode,
config: ContextConfig,
) -> List[Tuple[int, str]]:
result: List[Tuple[int, str]] = []
# look for false positive configs
for k in sample_config._config_call_dict.keys():
if k not in config._config_call_dict.keys():
result += [(2, "false_positive_config_value")]
break
# look for missed configs
for k in config._config_call_dict.keys():
if k not in sample_config._config_call_dict.keys():
result += [(3, "missed_config_value")]
break
# look for false positive sources
for s in sample_node.sources:
if s not in node.sources:
result += [(4, "false_positive_source_value")]
break
# look for missed sources
for s in node.sources:
if s not in sample_node.sources:
result += [(5, "missed_source_value")]
break
# look for false positive refs
for r in sample_node.refs:
if r not in node.refs:
result += [(6, "false_positive_ref_value")]
break
# look for missed refs
for r in node.refs:
if r not in sample_node.refs:
result += [(7, "missed_ref_value")]
break
# if there are no errors, return a success value
if not result:
result = [(0, "exact_match")]
return result