Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Remove catalog from the DDL SQL generated by on_schema_change=sync_all_columns #684

Merged
merged 4 commits into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions dbt/adapters/athena/relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ class AthenaIncludePolicy(Policy):
identifier: bool = True


@dataclass
class AthenaHiveIncludePolicy(Policy):
database: bool = False
schema: bool = True
identifier: bool = True


@dataclass(frozen=True, eq=False, repr=False)
class AthenaRelation(BaseRelation):
quote_character: str = '"' # Presto quote character
Expand All @@ -42,10 +49,13 @@ def render_hive(self) -> str:
- https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL
"""

old_value = self.quote_character
old_quote_character = self.quote_character
object.__setattr__(self, "quote_character", "`") # Hive quote char
old_include_policy = self.include_policy
object.__setattr__(self, "include_policy", AthenaHiveIncludePolicy())
rendered = self.render()
object.__setattr__(self, "quote_character", old_value)
object.__setattr__(self, "quote_character", old_quote_character)
object.__setattr__(self, "include_policy", old_include_policy)
return str(rendered)

def render_pure(self) -> str:
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/adapter/test_ha_iceberg.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test__table_creation(self, project, capsys):
out, _ = capsys.readouterr()
# in case of 2nd run we expect that the target table is renamed to __bkp
alter_statement = (
f"alter table `awsdatacatalog`.`{project.test_schema}`.`{relation_name}` "
f"alter table `{project.test_schema}`.`{relation_name}` "
iconara marked this conversation as resolved.
Show resolved Hide resolved
f"rename to `{project.test_schema}`.`{relation_name}__bkp`"
)
delete_bkp_table_log = (
Expand Down
100 changes: 100 additions & 0 deletions tests/functional/adapter/test_on_schema_change.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import json

import pytest

from dbt.contracts.results import RunStatus
from dbt.tests.util import run_dbt, run_dbt_and_capture

models__table_base_model = """
{{
config(
materialized='incremental',
incremental_strategy='append',
on_schema_change=var("on_schema_change"),
table_type=var("table_type"),
)
}}

select
1 as id,
'test 1' as name
{%- if is_incremental() -%}
,current_date as updated_at
{%- endif -%}
"""


class TestOnSchemaChange:
@pytest.fixture(scope="class")
def models(self):
models = {}
for table_type in ["hive", "iceberg"]:
for on_schema_change in ["sync_all_columns", "append_new_columns", "ignore", "fail"]:
models[f"{table_type}_on_schema_change_{on_schema_change}.sql"] = models__table_base_model
return models

def _column_names(self, project, relation_name):
result = project.run_sql(f"show columns from {relation_name}", fetch="all")
column_names = [row[0].strip() for row in result]
return column_names

@pytest.mark.parametrize("table_type", ["hive", "iceberg"])
def test__sync_all_columns(self, project, table_type):
relation_name = f"{table_type}_on_schema_change_sync_all_columns"
vars = {"on_schema_change": "sync_all_columns", "table_type": table_type}
args = ["run", "--select", relation_name, "--vars", json.dumps(vars)]

model_run_initial = run_dbt(args)
assert model_run_initial.results[0].status == RunStatus.Success

model_run_incremental = run_dbt(args)
assert model_run_incremental.results[0].status == RunStatus.Success

new_column_names = self._column_names(project, relation_name)
assert new_column_names == ["id", "name", "updated_at"]

@pytest.mark.parametrize("table_type", ["hive", "iceberg"])
def test__append_new_columns(self, project, table_type):
relation_name = f"{table_type}_on_schema_change_append_new_columns"
vars = {"on_schema_change": "append_new_columns", "table_type": table_type}
args = ["run", "--select", relation_name, "--vars", json.dumps(vars)]

model_run_initial = run_dbt(args)
assert model_run_initial.results[0].status == RunStatus.Success

model_run_incremental = run_dbt(args)
assert model_run_incremental.results[0].status == RunStatus.Success

new_column_names = self._column_names(project, relation_name)
assert new_column_names == ["id", "name", "updated_at"]

@pytest.mark.parametrize("table_type", ["hive", "iceberg"])
def test__ignore(self, project, table_type):
relation_name = f"{table_type}_on_schema_change_ignore"
vars = {"on_schema_change": "ignore", "table_type": table_type}
args = ["run", "--select", relation_name, "--vars", json.dumps(vars)]

model_run_initial = run_dbt(args)
assert model_run_initial.results[0].status == RunStatus.Success

model_run_incremental = run_dbt(args)
assert model_run_incremental.results[0].status == RunStatus.Success

new_column_names = self._column_names(project, relation_name)
assert new_column_names == ["id", "name"]

@pytest.mark.parametrize("table_type", ["hive", "iceberg"])
def test__fail(self, project, table_type):
relation_name = f"{table_type}_on_schema_change_fail"
vars = {"on_schema_change": "fail", "table_type": table_type}
args = ["run", "--select", relation_name, "--vars", json.dumps(vars)]

model_run_initial = run_dbt(args)
assert model_run_initial.results[0].status == RunStatus.Success

model_run_incremental, log = run_dbt_and_capture(args, expect_pass=False)
assert model_run_incremental.results[0].status == RunStatus.Error
assert "The source and target schemas on this incremental model are out of sync!" in log

new_column_names = self._column_names(project, relation_name)
assert new_column_names == ["id", "name"]
6 changes: 3 additions & 3 deletions tests/unit/test_relation.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@ def test__get_relation_type_with_unknown_type(self):


class TestAthenaRelation:
def test_render_hive_uses_hive_style_quotation(self):
def test_render_hive_uses_hive_style_quotation_and_only_schema_and_table_names(self):
relation = AthenaRelation.create(
identifier=TABLE_NAME,
database=DATA_CATALOG_NAME,
schema=DATABASE_NAME,
)
assert relation.render_hive() == f"`{DATA_CATALOG_NAME}`.`{DATABASE_NAME}`.`{TABLE_NAME}`"
assert relation.render_hive() == f"`{DATABASE_NAME}`.`{TABLE_NAME}`"

def test_render_hive_resets_quote_character_after_call(self):
def test_render_hive_resets_quote_character_and_include_policy_after_call(self):
relation = AthenaRelation.create(
identifier=TABLE_NAME,
database=DATA_CATALOG_NAME,
Expand Down
Loading