diff --git a/CHANGELOG.md b/CHANGELOG.md
index 37be1626256..252490afbf9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -6,6 +6,7 @@
 ### Features
 - Add a "docs" field to models, with a "show" subfield ([#1671](https://github.com/fishtown-analytics/dbt/issues/1671), [#2107](https://github.com/fishtown-analytics/dbt/pull/2107))
 - Add a dbt-{dbt_version} user agent field to the bigquery connector ([#2121](https://github.com/fishtown-analytics/dbt/issues/2121), [#2146](https://github.com/fishtown-analytics/dbt/pull/2146))
+- Add support for generating database name macro ([#1695](https://github.com/fishtown-analytics/dbt/issues/1695), [#2143](https://github.com/fishtown-analytics/dbt/pull/2143))
 
 ### Fixes
 - Fix issue where dbt did not give an error in the presence of duplicate doc names ([#2054](https://github.com/fishtown-analytics/dbt/issues/2054), [#2080](https://github.com/fishtown-analytics/dbt/pull/2080))
diff --git a/core/dbt/contracts/graph/manifest.py b/core/dbt/contracts/graph/manifest.py
index 00624b83772..13599366e1e 100644
--- a/core/dbt/contracts/graph/manifest.py
+++ b/core/dbt/contracts/graph/manifest.py
@@ -541,21 +541,21 @@ def filter(candidate: MacroCandidate) -> bool:
         return candidates.last()
 
     def find_generate_macro_by_name(
-        self, name: str, root_project_name: str
+        self, component: str, root_project_name: str
     ) -> Optional[ParsedMacro]:
         """
         The `generate_X_name` macros are similar to regular ones, but ignore
         imported packages.
-        - if there is a `name` macro in the root project, return it
-        - if that does not exist but there is a `name` macro in the 'dbt'
-          internal project (or a plugin), return that
-        - if neither of those exist (unit tests?), return None
+        - if there is a `generate_{component}_name` macro in the root
+          project, return it
+        - return the `generate_{component}_name` macro from the 'dbt'
+          internal project
         """
         def filter(candidate: MacroCandidate) -> bool:
             return candidate.locality != Locality.Imported
 
         candidates: CandidateList = self._find_macros_by_name(
-            name=name,
+            name=f'generate_{component}_name',
             root_project_name=root_project_name,
             # filter out imported packages
             filter=filter,
diff --git a/core/dbt/include/global_project/macros/etc/get_custom_database.sql b/core/dbt/include/global_project/macros/etc/get_custom_database.sql
new file mode 100644
index 00000000000..adbe152cb9b
--- /dev/null
+++ b/core/dbt/include/global_project/macros/etc/get_custom_database.sql
@@ -0,0 +1,28 @@
+{#
+    Renders a database name given a custom database name. If the custom
+    database name is none, then the resulting database is just the "database"
+    value in the specified target. If a database override is specified, then
+    the resulting database is the custom database name, overriding the
+    target's default database.
+
+    This macro can be overridden in projects to define different semantics
+    for rendering a database name.
+
+    Arguments:
+    custom_database_name: The custom database name specified for a model, or none
+    node: The node the database is being generated for
+
+#}
+{% macro generate_database_name(custom_database_name=none, node=none) -%}
+    {%- set default_database = target.database -%}
+    {%- if custom_database_name is none -%}
+
+        {{ default_database }}
+
+    {%- else -%}
+
+        {{ custom_database_name }}
+
+    {%- endif -%}
+
+{%- endmacro %}
diff --git a/core/dbt/parser/base.py b/core/dbt/parser/base.py
index c90eef150c6..48bc294d17b 100644
--- a/core/dbt/parser/base.py
+++ b/core/dbt/parser/base.py
@@ -2,7 +2,7 @@
 import itertools
 import os
 from typing import (
-    List, Dict, Any, Callable, Iterable, Optional, Generic, TypeVar
+    List, Dict, Any, Iterable, Generic, TypeVar
 )
 
 from hologram import ValidationError
@@ -11,7 +11,6 @@
 from dbt.clients.system import load_file_contents
 from dbt.context.providers import generate_parser_model, generate_parser_macro
 import dbt.flags
-from dbt import deprecations
 from dbt import hooks
 from dbt.adapters.factory import get_adapter
 from dbt.clients.jinja import get_rendered
@@ -19,10 +18,10 @@
 from dbt.contracts.graph.manifest import (
     Manifest, SourceFile, FilePath, FileHash
 )
-from dbt.contracts.graph.parsed import HasUniqueID, ParsedMacro
+from dbt.contracts.graph.parsed import HasUniqueID
 from dbt.contracts.graph.unparsed import UnparsedNode
 from dbt.exceptions import (
-    CompilationException, validator_error_message
+    CompilationException, validator_error_message, InternalException
 )
 from dbt.node_types import NodeType
 from dbt.source_config import SourceConfig
@@ -39,7 +38,6 @@
 FinalNode = TypeVar('FinalNode', bound=ManifestNodes)
 
 
-RelationUpdate = Callable[[Optional[str], IntermediateNode], str]
 ConfiguredBlockType = TypeVar('ConfiguredBlockType', bound=FileBlock)
 
 
@@ -94,6 +92,31 @@ def __init__(
         self.macro_manifest = macro_manifest
 
 
+class RelationUpdate:
+    def __init__(
+        self, config: RuntimeConfig, manifest: Manifest, component: str
+    ) -> None:
+        macro = manifest.find_generate_macro_by_name(
+            component=component,
+            root_project_name=config.project_name,
+        )
+        if macro is None:
+            raise InternalException(
+                f'No macro with name generate_{component}_name found'
+            )
+
+        root_context = generate_parser_macro(macro, config, manifest, None)
+        self.updater = MacroGenerator(macro, root_context)
+        self.component = component
+
+    def __call__(
+        self, parsed_node: Any, config_dict: Dict[str, Any]
+    ) -> None:
+        override = config_dict.get(self.component)
+        new_value = self.updater(override, parsed_node).strip()
+        setattr(parsed_node, self.component, new_value)
+
+
 class ConfiguredParser(
     Parser[FinalNode],
     Generic[ConfiguredBlockType, IntermediateNode, FinalNode],
@@ -106,8 +129,16 @@ def __init__(
         macro_manifest: Manifest,
     ) -> None:
         super().__init__(results, project, root_project, macro_manifest)
-        self._get_schema_func: Optional[RelationUpdate] = None
-        self._get_alias_func: Optional[RelationUpdate] = None
+
+        self._update_node_database = RelationUpdate(
+            manifest=macro_manifest, config=root_project, component='database'
+        )
+        self._update_node_schema = RelationUpdate(
+            manifest=macro_manifest, config=root_project, component='schema'
+        )
+        self._update_node_alias = RelationUpdate(
+            manifest=macro_manifest, config=root_project, component='alias'
+        )
 
     @abc.abstractclassmethod
     def get_compiled_path(cls, block: ConfiguredBlockType) -> str:
@@ -129,69 +160,6 @@ def default_schema(self):
     def default_database(self):
         return self.root_project.credentials.database
 
-    def _build_generate_macro_function(self, macro: ParsedMacro) -> Callable:
-        root_context = generate_parser_macro(
-            macro, self.root_project, self.macro_manifest, None
-        )
-        return MacroGenerator(macro, root_context)
-
-    def get_schema_func(self) -> RelationUpdate:
-        """The get_schema function is set by a few different things:
-            - if there is a 'generate_schema_name' macro in the root project,
-              it will be used.
-            - if that does not exist but there is a 'generate_schema_name'
-              macro in the 'dbt' internal project, that will be used
-            - if neither of those exist (unit tests?), a function that returns
-              the 'default schema' as set in the root project's 'credentials'
-              is used
-        """
-        if self._get_schema_func is not None:
-            return self._get_schema_func
-
-        get_schema_macro = self.macro_manifest.find_generate_macro_by_name(
-            name='generate_schema_name',
-            root_project_name=self.root_project.project_name,
-        )
-        # this is only true in tests!
-        if get_schema_macro is None:
-            def get_schema(custom_schema_name=None, node=None):
-                return self.default_schema
-        else:
-            get_schema = self._build_generate_macro_function(get_schema_macro)
-
-        self._get_schema_func = get_schema
-        return self._get_schema_func
-
-    def get_alias_func(self) -> RelationUpdate:
-        """The get_alias function is set by a few different things:
-            - if there is a 'generate_alias_name' macro in the root project,
-              it will be used.
-            - if that does not exist but there is a 'generate_alias_name'
-              macro in the 'dbt' internal project, that will be used
-            - if neither of those exist (unit tests?), a function that returns
-              the 'default alias' as set in the model's filename or alias
-              configuration.
-        """
-        if self._get_alias_func is not None:
-            return self._get_alias_func
-
-        get_alias_macro = self.macro_manifest.find_generate_macro_by_name(
-            name='generate_alias_name',
-            root_project_name=self.root_project.project_name,
-        )
-        # the generate_alias_name macro might not exist
-        if get_alias_macro is None:
-            def get_alias(custom_alias_name, node):
-                if custom_alias_name is None:
-                    return node.name
-                else:
-                    return custom_alias_name
-        else:
-            get_alias = self._build_generate_macro_function(get_alias_macro)
-
-        self._get_alias_func = get_alias
-        return self._get_alias_func
-
     def get_fqn(self, path: str, name: str) -> List[str]:
         """Get the FQN for the node. This impacts node selection and config
         application.
@@ -297,33 +265,6 @@ def render_with_context(
             parsed_node.raw_sql, context, parsed_node, capture_macros=True
         )
 
-    def update_parsed_node_schema(
-        self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
-    ) -> None:
-        # Special macro defined in the global project. Use the root project's
-        # definition, not the current package
-        schema_override = config_dict.get('schema')
-        get_schema = self.get_schema_func()
-        try:
-            schema = get_schema(schema_override, parsed_node)
-        except dbt.exceptions.CompilationException as exc:
-            too_many_args = (
-                "macro 'dbt_macro__generate_schema_name' takes not more than "
-                "1 argument(s)"
-            )
-            if too_many_args not in str(exc):
-                raise
-            deprecations.warn('generate-schema-name-single-arg')
-            schema = get_schema(schema_override)  # type: ignore
-        parsed_node.schema = schema.strip()
-
-    def update_parsed_node_alias(
-        self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
-    ) -> None:
-        alias_override = config_dict.get('alias')
-        get_alias = self.get_alias_func()
-        parsed_node.alias = get_alias(alias_override, parsed_node).strip()
-
     def update_parsed_node_config(
         self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
     ) -> None:
@@ -334,6 +275,13 @@ def update_parsed_node_config(
         self._mangle_hooks(final_config_dict)
         parsed_node.config = parsed_node.config.from_dict(final_config_dict)
 
+    def update_parsed_node_name(
+        self, parsed_node: IntermediateNode, config_dict: Dict[str, Any]
+    ) -> None:
+        self._update_node_database(parsed_node, config_dict)
+        self._update_node_schema(parsed_node, config_dict)
+        self._update_node_alias(parsed_node, config_dict)
+
     def update_parsed_node(
         self, parsed_node: IntermediateNode, config: SourceConfig
     ) -> None:
@@ -347,15 +295,10 @@ def update_parsed_node(
         model_tags = config_dict.get('tags', [])
         parsed_node.tags.extend(model_tags)
 
-        # do this once before we parse the node schema/alias, so
+        # do this once before we parse the node database/schema/alias, so
         # parsed_node.config is what it would be if they did nothing
         self.update_parsed_node_config(parsed_node, config_dict)
-
-        parsed_node.database = config_dict.get(
-            'database', self.default_database
-        ).strip()
-        self.update_parsed_node_schema(parsed_node, config_dict)
-        self.update_parsed_node_alias(parsed_node, config_dict)
+        self.update_parsed_node_name(parsed_node, config_dict)
 
         # at this point, we've collected our hooks. Use the node context to
         # render each hook and collect refs/sources
diff --git a/test/integration/006_simple_dependency_test/local_dependency/macros/generate_schema_name.sql b/test/integration/006_simple_dependency_test/local_dependency/macros/generate_schema_name.sql
index 321b88147d4..127ba8c5575 100644
--- a/test/integration/006_simple_dependency_test/local_dependency/macros/generate_schema_name.sql
+++ b/test/integration/006_simple_dependency_test/local_dependency/macros/generate_schema_name.sql
@@ -1,4 +1,15 @@
 {# This should be ignored as it's in a subpackage #}
-{% macro generate_schema_name(custom_schema_name=none) -%}
-    invalid_schema_name
+{% macro generate_schema_name(custom_schema_name=none, node=none) -%}
+    {{ exceptions.raise_compiler_error('invalid', node=node) }}
+{%- endmacro %}
+
+{# This should be ignored as it's in a subpackage #}
+{% macro generate_database_name(custom_database_name=none, node=none) -%}
+    {{ exceptions.raise_compiler_error('invalid', node=node) }}
+{%- endmacro %}
+
+
+{# This should be ignored as it's in a subpackage #}
+{% macro generate_alias_name(custom_alias_name=none, node=none) -%}
+    {{ exceptions.raise_compiler_error('invalid', node=node) }}
 {%- endmacro %}
diff --git a/test/integration/012_deprecation_tests/deprecated-macros/schema.sql b/test/integration/012_deprecation_tests/deprecated-macros/schema.sql
deleted file mode 100644
index d3884257ad6..00000000000
--- a/test/integration/012_deprecation_tests/deprecated-macros/schema.sql
+++ /dev/null
@@ -1,7 +0,0 @@
-{% macro generate_schema_name(schema_name) -%}
-    {%- if schema_name is none -%}
-        {{ target.schema }}
-    {%- else -%}
-        {{ schema_name }}
-    {%- endif -%}
-{%- endmacro %}
diff --git a/test/integration/012_deprecation_tests/test_deprecations.py b/test/integration/012_deprecation_tests/test_deprecations.py
index 70af3f2c5d2..8e0231461ad 100644
--- a/test/integration/012_deprecation_tests/test_deprecations.py
+++ b/test/integration/012_deprecation_tests/test_deprecations.py
@@ -35,30 +35,6 @@ def test_postgres_deprecations(self):
         self.assertEqual(expected, deprecations.active_deprecations)
 
 
-class TestMacroDeprecations(BaseTestDeprecations):
-    @property
-    def models(self):
-        return self.dir('boring-models')
-
-    @property
-    def project_config(self):
-        return {
-            'macro-paths': [self.dir('deprecated-macros')],
-        }
-
-    @use_profile('postgres')
-    def test_postgres_deprecations_fail(self):
-        with self.assertRaises(dbt.exceptions.CompilationException):
-            self.run_dbt(strict=True)
-
-    @use_profile('postgres')
-    def test_postgres_deprecations(self):
-        self.assertEqual(deprecations.active_deprecations, set())
-        self.run_dbt(strict=False)
-        expected = {'generate-schema-name-single-arg'}
-        self.assertEqual(expected, deprecations.active_deprecations)
-
-
 class TestMaterializationReturnDeprecation(BaseTestDeprecations):
     @property
     def models(self):
diff --git a/test/integration/024_custom_schema_test/custom-db-macros/custom_db.sql b/test/integration/024_custom_schema_test/custom-db-macros/custom_db.sql
new file mode 100644
index 00000000000..bb9717490de
--- /dev/null
+++ b/test/integration/024_custom_schema_test/custom-db-macros/custom_db.sql
@@ -0,0 +1,10 @@
+
+{% macro generate_database_name(database_name, node) %}
+    {% if database_name == 'alt' %}
+        {{ env_var('SNOWFLAKE_TEST_ALT_DATABASE') }}
+    {% elif database_name %}
+        {{ database_name }}
+    {% else %}
+        {{ target.database }}
+    {% endif %}
+{% endmacro %}
diff --git a/test/integration/024_custom_schema_test/db-models/view_1.sql b/test/integration/024_custom_schema_test/db-models/view_1.sql
new file mode 100644
index 00000000000..501c773e8f0
--- /dev/null
+++ b/test/integration/024_custom_schema_test/db-models/view_1.sql
@@ -0,0 +1,3 @@
+
+
+select * from {{ target.schema }}.seed
diff --git a/test/integration/024_custom_schema_test/db-models/view_2.sql b/test/integration/024_custom_schema_test/db-models/view_2.sql
new file mode 100644
index 00000000000..7bec9b2052f
--- /dev/null
+++ b/test/integration/024_custom_schema_test/db-models/view_2.sql
@@ -0,0 +1,2 @@
+{{ config(database='alt') }}
+select * from {{ ref('view_1') }}
diff --git a/test/integration/024_custom_schema_test/db-models/view_3.sql b/test/integration/024_custom_schema_test/db-models/view_3.sql
new file mode 100644
index 00000000000..825a672c59b
--- /dev/null
+++ b/test/integration/024_custom_schema_test/db-models/view_3.sql
@@ -0,0 +1,30 @@
+
+{{ config(database='alt', materialized='table') }}
+
+
+with v1 as (
+
+    select * from {{ ref('view_1') }}
+
+),
+
+v2 as (
+
+    select * from {{ ref('view_2') }}
+
+),
+
+combined as (
+
+    select last_name from v1
+    union all
+    select last_name from v2
+
+)
+
+select
+    last_name,
+    count(*) as count
+
+from combined
+group by 1
diff --git a/test/integration/024_custom_schema_test/seed.sql b/test/integration/024_custom_schema_test/seed.sql
index 607759b0781..4962d1ea592 100644
--- a/test/integration/024_custom_schema_test/seed.sql
+++ b/test/integration/024_custom_schema_test/seed.sql
@@ -1,6 +1,6 @@
-drop table if exists {schema}.seed cascade;
-create table {schema}.seed (
+drop table if exists {database}.{schema}.seed cascade;
+create table {database}.{schema}.seed (
     id BIGSERIAL PRIMARY KEY,
     first_name VARCHAR(50),
     last_name VARCHAR(50),
@@ -9,17 +9,17 @@ create table {schema}.seed (
     ip_address VARCHAR(20)
 );
 
-drop table if exists {schema}.agg cascade;
-create table {schema}.agg (
+drop table if exists {database}.{schema}.agg cascade;
+create table {database}.{schema}.agg (
     last_name VARCHAR(50),
     count BIGINT
 );
 
-insert into {schema}.seed (first_name, last_name, email, gender, ip_address) values
+insert into {database}.{schema}.seed (first_name, last_name, email, gender, ip_address) values
 ('Jack', 'Hunter', 'jhunter0@pbs.org', 'Male', '59.80.20.168'),
 ('Kathryn', 'Walker', 'kwalker1@ezinearticles.com', 'Female', '194.121.179.35'),
 ('Gerald', 'Ryan', 'gryan2@com.com', 'Male', '11.3.212.243');
 
-insert into {schema}.agg (last_name, count) values
+insert into {database}.{schema}.agg (last_name, count) values
 ('Hunter', 2), ('Walker', 2), ('Ryan', 2);
diff --git a/test/integration/024_custom_schema_test/seed/agg.csv b/test/integration/024_custom_schema_test/seed/agg.csv
new file mode 100644
index 00000000000..8288f827b09
--- /dev/null
+++ b/test/integration/024_custom_schema_test/seed/agg.csv
@@ -0,0 +1,4 @@
+last_name,count
+Hunter,2
+Walker,2
+Ryan,2
diff --git a/test/integration/024_custom_schema_test/seed/seed.csv b/test/integration/024_custom_schema_test/seed/seed.csv
new file mode 100644
index 00000000000..fe72ae2e76e
--- /dev/null
+++ b/test/integration/024_custom_schema_test/seed/seed.csv
@@ -0,0 +1,4 @@
+id,first_name,last_name,email,gender,ip_address
+1,Jack,Hunter,jhunter0@pbs.org,Male,59.80.20.168
+2,Kathryn,Walker,kwalker1@ezinearticles.com,Female,194.121.179.35
+3,Gerald,Ryan,gryan2@com.com,Male,11.3.212.243
diff --git a/test/integration/024_custom_schema_test/test_custom_database.py b/test/integration/024_custom_schema_test/test_custom_database.py
new file mode 100644
index 00000000000..23a80f81e5f
--- /dev/null
+++ b/test/integration/024_custom_schema_test/test_custom_database.py
@@ -0,0 +1,38 @@
+from test.integration.base import DBTIntegrationTest, use_profile
+
+
+class TestOverrideDatabase(DBTIntegrationTest):
+    setup_alternate_db = True
+
+    @property
+    def schema(self):
+        return "custom_schema_024"
+
+    @property
+    def models(self):
+        return "db-models"
+
+    @property
+    def project_config(self):
+        return {
+            'macro-paths': ['custom-db-macros'],
+        }
+
+    @use_profile('snowflake')
+    def test_snowflake_override_generate_db_name(self):
+        self.run_sql_file('seed.sql')
+        self.assertTableDoesExist('SEED', schema=self.unique_schema(), database=self.default_database)
+        self.assertTableDoesExist('AGG', schema=self.unique_schema(), database=self.default_database)
+
+        results = self.run_dbt()
+        self.assertEqual(len(results), 3)
+
+        self.assertTableDoesExist('VIEW_1', schema=self.unique_schema(), database=self.default_database)
+        self.assertTableDoesExist('VIEW_2', schema=self.unique_schema(), database=self.alternative_database)
+        self.assertTableDoesExist('VIEW_3', schema=self.unique_schema(), database=self.alternative_database)
+
+        # not overridden
+        self.assertTablesEqual('SEED', 'VIEW_1', table_b_db=self.default_database)
+        # overridden
+        self.assertTablesEqual('SEED', 'VIEW_2', table_b_db=self.alternative_database)
+        self.assertTablesEqual('AGG', 'VIEW_3', table_b_db=self.alternative_database)
diff --git a/test/integration/base.py b/test/integration/base.py
index 6b2cf545f10..36b7168170f 100644
--- a/test/integration/base.py
+++ b/test/integration/base.py
@@ -619,7 +619,7 @@ def run_sql_common(self, sql, fetch, conn):
             else:
                 return
         except BaseException as e:
-            if conn.handle and not conn.handle.closed:
+            if conn.handle and not getattr(conn.handle, 'closed', True):
                 conn.handle.rollback()
             print(sql)
             print(e)
@@ -1065,16 +1065,16 @@ def _assertTableRowCountsEqual(self, relation_a, relation_b):
             )
         )
 
-    def assertTableDoesNotExist(self, table, schema=None):
-        columns = self.get_table_columns(table, schema)
+    def assertTableDoesNotExist(self, table, schema=None, database=None):
+        columns = self.get_table_columns(table, schema, database)
 
         self.assertEqual(
             len(columns),
             0
         )
 
-    def assertTableDoesExist(self, table, schema=None):
-        columns = self.get_table_columns(table, schema)
+    def assertTableDoesExist(self, table, schema=None, database=None):
+        columns = self.get_table_columns(table, schema, database)
 
         self.assertGreater(
             len(columns),
diff --git a/test/unit/test_graph.py b/test/unit/test_graph.py
index 697f301b7bc..8424d3d6e94 100644
--- a/test/unit/test_graph.py
+++ b/test/unit/test_graph.py
@@ -11,9 +11,11 @@
 import dbt.config
 import dbt.utils
 import dbt.parser.manifest
-from dbt.contracts.graph.manifest import FilePath, SourceFile, FileHash
+from dbt.contracts.graph.manifest import FilePath, SourceFile, FileHash, Manifest
+from dbt.contracts.graph.parsed import ParsedMacro
 from dbt.parser.results import ParseResult
 from dbt.parser.base import BaseParser
+from dbt.node_types import NodeType
 
 try:
     from queue import Empty
@@ -22,7 +24,19 @@
 
 from dbt.logger import GLOBAL_LOGGER as logger  # noqa
 
-from .utils import config_from_parts_or_dicts
+from .utils import config_from_parts_or_dicts, generate_name_macros
+
+
+def MockMacro(package, name='my_macro', kwargs={}):
+    macro = MagicMock(
+        __class__=ParsedMacro,
+        resource_type=NodeType.Macro,
+        package_name=package,
+        unique_id=f'macro.{package}.{name}',
+        **kwargs
+    )
+    macro.name = name
+    return macro
 
 
 class GraphTest(unittest.TestCase):
 
@@ -37,6 +51,7 @@ def tearDown(self):
         self.mock_hook_constructor.stop()
         self.load_patch.stop()
         self.load_source_file_patcher.stop()
+        # self.relation_update_patcher.stop()
 
     def setUp(self):
         dbt.flags.STRICT_MODE = True
@@ -101,6 +116,12 @@ def _mock_parse_result(config, all_projects):
         self.mock_source_file = self.load_source_file_patcher.start()
         self.mock_source_file.side_effect = lambda path: [n for n in self.mock_models if n.path == path][0]
 
+        # self.relation_update_patcher = patch.object(RelationUpdate, '_relation_components', lambda: [])
+        # self.mock_relation_update = self.relation_update_patcher.start()
+        self.internal_manifest = Manifest.from_macros(macros={
+            n.unique_id: n for n in generate_name_macros('test_models_compile')
+        })
+
         def filesystem_iter(iter_self):
             if 'sql' not in iter_self.extension:
                 return []
@@ -153,7 +174,7 @@ def use_models(self, models):
 
     def load_manifest(self, config):
         loader = dbt.parser.manifest.ManifestLoader(config, {config.project_name: config})
-        loader.load()
+        loader.load(internal_manifest=self.internal_manifest)
         return loader.create_manifest()
 
     def test__single_model(self):
@@ -303,7 +324,7 @@ def test__partial_parse(self):
         config = self.get_config()
 
         loader = dbt.parser.manifest.ManifestLoader(config, {config.project_name: config})
-        loader.load()
+        loader.load(internal_manifest=self.internal_manifest)
         loader.create_manifest()
 
         results = loader.results
diff --git a/test/unit/test_manifest.py b/test/unit/test_manifest.py
index 29405f529ef..a0e0ddbd3a2 100644
--- a/test/unit/test_manifest.py
+++ b/test/unit/test_manifest.py
@@ -633,10 +633,8 @@ def test__build_flat_graph(self):
         self.assertEqual(compiled_count, 2)
 
 
-# Tests of the manifest search code (find_X_by_Y)
-
 
 def MockMacro(package, name='my_macro', kwargs={}):
     macro = mock.MagicMock(
         __class__=ParsedMacro,
@@ -656,6 +654,11 @@ def MockMaterialization(package, name='my_materialization', adapter_type=None, k
     return MockMacro(package, f'materialization_{name}_{adapter_type}', kwargs)
 
 
+def MockGenerateMacro(package, component='some_component', kwargs={}):
+    name = f'generate_{component}_name'
+    return MockMacro(package, name=name, kwargs=kwargs)
+
+
 def MockSource(package, source_name, name, kwargs={}):
     src = mock.MagicMock(
         __class__=ParsedSourceDefinition,
@@ -834,43 +837,43 @@ def test_find_macro_by_name(macros, expectations):
 
     # just root
     FindMacroSpec(
-        macros=[MockMacro('root')],
+        macros=[MockGenerateMacro('root')],
         expected='root',
     ),
 
     # just dep
     FindMacroSpec(
-        macros=[MockMacro('dep')],
+        macros=[MockGenerateMacro('dep')],
        expected=None,
     ),
 
     # just dbt
     FindMacroSpec(
-        macros=[MockMacro('dbt')],
+        macros=[MockGenerateMacro('dbt')],
         expected='dbt',
     ),
 
     # root overrides dep
     FindMacroSpec(
-        macros=[MockMacro('root'), MockMacro('dep')],
+        macros=[MockGenerateMacro('root'), MockGenerateMacro('dep')],
         expected='root',
     ),
 
     # root overrides core
     FindMacroSpec(
-        macros=[MockMacro('root'), MockMacro('dbt')],
+        macros=[MockGenerateMacro('root'), MockGenerateMacro('dbt')],
         expected='root',
     ),
 
     # dep overrides core
     FindMacroSpec(
-        macros=[MockMacro('dep'), MockMacro('dbt')],
+        macros=[MockGenerateMacro('dep'), MockGenerateMacro('dbt')],
         expected='dbt',
     ),
 
     # root overrides dep overrides core
     FindMacroSpec(
-        macros=[MockMacro('root'), MockMacro('dep'), MockMacro('dbt')],
+        macros=[MockGenerateMacro('root'), MockGenerateMacro('dep'), MockGenerateMacro('dbt')],
         expected='root',
     ),
 ]
 
@@ -879,7 +882,9 @@
 @pytest.mark.parametrize('macros,expected', generate_name_parameter_sets, ids=id_macro)
 def test_find_generate_macro_by_name(macros, expected):
     manifest = make_manifest(macros=macros)
-    result = manifest.find_generate_macro_by_name(name='my_macro', root_project_name='root')
+    result = manifest.find_generate_macro_by_name(
+        component='some_component', root_project_name='root'
+    )
     if expected is None:
         assert result is expected
     else:
diff --git a/test/unit/test_parser.py b/test/unit/test_parser.py
index 36fd5b19ccc..27dfc2dd9de 100644
--- a/test/unit/test_parser.py
+++ b/test/unit/test_parser.py
@@ -32,7 +32,7 @@
     FreshnessThreshold, ExternalTable, Docs
 )
 
-from .utils import config_from_parts_or_dicts, normalize
+from .utils import config_from_parts_or_dicts, normalize, generate_name_macros
 
 
 def get_abs_os_path(unix_path):
@@ -42,6 +42,32 @@ def get_abs_os_path(unix_path):
 class BaseParserTest(unittest.TestCase):
     maxDiff = None
 
+    def _generate_macros(self):
+        name_sql = {}
+        for component in ('database', 'schema', 'alias'):
+            if component == 'alias':
+                source = 'node.name'
+            else:
+                source = f'target.{component}'
+            name = f'generate_{component}_name'
+            sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}'
+            name_sql[name] = sql
+
+        all_sql = '\n'.join(name_sql.values())
+        for name, sql in name_sql.items():
+            pm = ParsedMacro(
+                name=name,
+                resource_type=NodeType.Macro,
+                unique_id=f'macro.root.{name}',
+                package_name='root',
+                original_file_path=normalize('macros/macro.sql'),
+                root_path=get_abs_os_path('./dbt_modules/root'),
+                path=normalize('macros/macro.sql'),
+                raw_sql=all_sql,
+                macro_sql=sql,
+            )
+            yield pm
+
     def setUp(self):
         dbt.flags.STRICT_MODE = True
         dbt.flags.WARN_ERROR = True
@@ -98,7 +124,9 @@ def setUp(self):
 
         self.parser_patcher = mock.patch('dbt.parser.base.get_adapter')
         self.factory_parser = self.parser_patcher.start()
-        self.macro_manifest = Manifest.from_macros()
+        self.macro_manifest = Manifest.from_macros(
+            macros={m.unique_id: m for m in generate_name_macros('root')}
+        )
 
     def tearDown(self):
         self.parser_patcher.stop()
diff --git a/test/unit/utils.py b/test/unit/utils.py
index dcd859544ca..e4ed58bdeee 100644
--- a/test/unit/utils.py
+++ b/test/unit/utils.py
@@ -133,3 +133,32 @@ def assert_fails_validation(self, dct, cls=None):
         with self.assertRaises(ValidationError):
             cls.from_dict(dct)
+
+
+def generate_name_macros(package):
+    from dbt.contracts.graph.parsed import ParsedMacro
+    from dbt.node_types import NodeType
+    name_sql = {}
+    for component in ('database', 'schema', 'alias'):
+        if component == 'alias':
+            source = 'node.name'
+        else:
+            source = f'target.{component}'
+        name = f'generate_{component}_name'
+        sql = f'{{% macro {name}(value, node) %}} {{% if value %}} {{{{ value }}}} {{% else %}} {{{{ {source} }}}} {{% endif %}} {{% endmacro %}}'
+        name_sql[name] = sql
+
+    all_sql = '\n'.join(name_sql.values())
+    for name, sql in name_sql.items():
+        pm = ParsedMacro(
+            name=name,
+            resource_type=NodeType.Macro,
+            unique_id=f'macro.{package}.{name}',
+            package_name=package,
+            original_file_path=normalize('macros/macro.sql'),
+            root_path='./dbt_modules/root',
+            path=normalize('macros/macro.sql'),
+            raw_sql=all_sql,
+            macro_sql=sql,
+        )
+        yield pm
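
With this change, generate_database_name becomes a root-project override point resolved by the same find_generate_macro_by_name lookup as generate_schema_name and generate_alias_name: a definition in the root project wins, otherwise the 'dbt' internal default shown in get_custom_database.sql is used, and definitions in imported packages are ignored. As a minimal sketch of such an override (the 'prod' target name and 'analytics_prod' database below are hypothetical placeholders, not part of this change set), a project could pin all models without an explicit override to a fixed database when running against a production target:

{% macro generate_database_name(custom_database_name=none, node=none) -%}
    {#- a per-model config(database=...) override still takes precedence -#}
    {%- if custom_database_name is not none -%}
        {{ custom_database_name }}
    {%- elif target.name == 'prod' -%}
        {#- 'prod' and 'analytics_prod' are hypothetical names for this sketch -#}
        analytics_prod
    {%- else -%}
        {{ target.database }}
    {%- endif -%}
{%- endmacro %}

Per-model overrides reach the macro through its first argument, e.g. the {{ config(database='alt') }} call in view_2.sql above, and the parser applies the result uniformly via RelationUpdate, exactly as it does for schema and alias.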