Finish Constraint Support for Spark #747

Merged
merged 9 commits on May 24, 2023
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20230427-123135.yaml
@@ -0,0 +1,6 @@
kind: Features
body: All constraint types are supported, but not enforced.
time: 2023-04-27T12:31:35.011284-04:00
custom:
Author: peterallenwebb
Issue: 656 657
17 changes: 14 additions & 3 deletions dbt/adapters/spark/impl.py
@@ -2,17 +2,16 @@
from concurrent.futures import Future
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable

from typing_extensions import TypeAlias

import agate
from dbt.contracts.relation import RelationType

import dbt
import dbt.exceptions

from dbt.adapters.base import AdapterConfig, PythonJobHelper
from dbt.adapters.base.impl import catch_as_completed
from dbt.contracts.connection import AdapterResponse
from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport
from dbt.adapters.sql import SQLAdapter
from dbt.adapters.spark import SparkConnectionManager
from dbt.adapters.spark import SparkRelation
@@ -23,6 +22,9 @@
)
from dbt.adapters.base import BaseRelation
from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER
from dbt.contracts.connection import AdapterResponse
from dbt.contracts.graph.nodes import ConstraintType
from dbt.contracts.relation import RelationType
from dbt.events import AdapterLogger
from dbt.utils import executor, AttrDict

@@ -79,6 +81,7 @@ class SparkAdapter(SQLAdapter):
INFORMATION_COLUMNS_REGEX = re.compile(r"^ \|-- (.*): (.*) \(nullable = (.*)\b", re.MULTILINE)
INFORMATION_OWNER_REGEX = re.compile(r"^Owner: (.*)$", re.MULTILINE)
INFORMATION_STATISTICS_REGEX = re.compile(r"^Statistics: (.*)$", re.MULTILINE)

HUDI_METADATA_COLUMNS = [
"_hoodie_commit_time",
"_hoodie_commit_seqno",
@@ -87,6 +90,14 @@ class SparkAdapter(SQLAdapter):
"_hoodie_file_name",
]

CONSTRAINT_SUPPORT = {
ConstraintType.check: ConstraintSupport.NOT_ENFORCED,
ConstraintType.not_null: ConstraintSupport.NOT_ENFORCED,
ConstraintType.unique: ConstraintSupport.NOT_ENFORCED,
ConstraintType.primary_key: ConstraintSupport.NOT_ENFORCED,
ConstraintType.foreign_key: ConstraintSupport.NOT_ENFORCED,
}

Relation: TypeAlias = SparkRelation
RelationInfo = Tuple[str, str, str]
Column: TypeAlias = SparkColumn
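The CONSTRAINT_SUPPORT map above marks every constraint type as NOT_ENFORCED: dbt renders the constraint DDL, but Spark/Delta will not enforce it. A minimal sketch of how such a support map can be consulted, purely illustrative rather than the actual dbt-core code path (the SUPPORT dict and render_or_warn helper are made up for this example):

```python
from dbt.adapters.base.impl import ConstraintSupport
from dbt.contracts.graph.nodes import ConstraintType

# Hypothetical support map, mirroring SparkAdapter.CONSTRAINT_SUPPORT above.
SUPPORT = {
    ConstraintType.check: ConstraintSupport.NOT_ENFORCED,
    ConstraintType.not_null: ConstraintSupport.NOT_ENFORCED,
}

def render_or_warn(constraint_type: ConstraintType, rendered_ddl: str) -> str:
    """Return the DDL to keep, warning when the constraint cannot be enforced."""
    support = SUPPORT.get(constraint_type, ConstraintSupport.NOT_SUPPORTED)
    if support == ConstraintSupport.NOT_SUPPORTED:
        print(f"warning: {constraint_type} not supported by this adapter; skipping")
        return ""
    if support == ConstraintSupport.NOT_ENFORCED:
        print(f"warning: {constraint_type} is rendered but not enforced on Spark")
    return rendered_ddl
```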
22 changes: 9 additions & 13 deletions dbt/include/spark/macros/adapters.sql
@@ -183,7 +183,7 @@
{% macro spark__persist_constraints(relation, model) %}
{%- set contract_config = config.get('contract') -%}
{% if contract_config.enforced and config.get('file_format', 'delta') == 'delta' %}
{% do alter_table_add_constraints(relation, model.columns) %}
{% do alter_table_add_constraints(relation, model.constraints) %}
{% do alter_column_set_constraints(relation, model.columns) %}
{% endif %}
{% endmacro %}
@@ -192,18 +192,14 @@
{{ return(adapter.dispatch('alter_table_add_constraints', 'dbt')(relation, constraints)) }}
{% endmacro %}

{% macro spark__alter_table_add_constraints(relation, column_dict) %}

{% for column_name in column_dict %}
{% set constraints = column_dict[column_name]['constraints'] %}
{% for constraint in constraints %}
{% if constraint.type == 'check' and not is_incremental() %}
{%- set constraint_hash = local_md5(column_name ~ ";" ~ constraint.expression ~ ";" ~ loop.index) -%}
{% call statement() %}
alter table {{ relation }} add constraint {{ constraint_hash }} check {{ constraint.expression }};
{% endcall %}
{% endif %}
{% endfor %}
{% macro spark__alter_table_add_constraints(relation, constraints) %}
{% for constraint in constraints %}
{% if constraint.type == 'check' and not is_incremental() %}
{%- set constraint_hash = local_md5(column_name ~ ";" ~ constraint.expression ~ ";" ~ loop.index) -%}
{% call statement() %}
alter table {{ relation }} add constraint {{ constraint.name if constraint.name else constraint_hash }} check {{ constraint.expression }};
{% endcall %}
{% endif %}
{% endfor %}
{% endmacro %}

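In the rewritten macro, an unnamed check constraint falls back to a local_md5 hash for its name. A small sketch of that fallback in Python, assuming local_md5 behaves like a plain MD5 hexdigest (fallback_constraint_name is an illustrative name, not a dbt helper):

```python
import hashlib

def fallback_constraint_name(column_name: str, expression: str, index: int) -> str:
    # mirrors local_md5(column_name ~ ";" ~ constraint.expression ~ ";" ~ loop.index)
    return hashlib.md5(f"{column_name};{expression};{index}".encode()).hexdigest()

# The macro would then emit DDL along the lines of:
#   alter table my_schema.my_model add constraint <hash> check (id > 0);
print(fallback_constraint_name("id", "(id > 0)", 1))
```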
46 changes: 46 additions & 0 deletions tests/functional/adapter/test_constraints.py
@@ -1,5 +1,6 @@
import pytest
from dbt.tests.adapter.constraints.test_constraints import (
BaseModelConstraintsRuntimeEnforcement,
BaseTableConstraintsColumnsEqual,
BaseViewConstraintsColumnsEqual,
BaseIncrementalConstraintsColumnsEqual,
@@ -9,6 +10,7 @@
BaseIncrementalConstraintsRollback,
)
from dbt.tests.adapter.constraints.fixtures import (
constrained_model_schema_yml,
my_model_sql,
my_model_wrong_order_sql,
my_model_wrong_name_sql,
@@ -37,9 +39,26 @@
'2019-01-01' as date_day ) as model_subq
"""

_expected_sql_spark_model_constraints = """
create or replace table <model_identifier>
using delta
as
select
id,
color,
date_day
from

( select
1 as id,
'blue' as color,
'2019-01-01' as date_day ) as model_subq
"""

# Different on Spark:
# - does not support a data type named 'text' (TODO handle this in the base test classes using string_type)
constraints_yml = model_schema_yml.replace("text", "string").replace("primary key", "")
model_constraints_yml = constrained_model_schema_yml.replace("text", "string")


class PyodbcSetup:
@@ -246,9 +265,11 @@ def expected_error_messages(self):
return [
"violate the new CHECK constraint",
"DELTA_NEW_CHECK_CONSTRAINT_VIOLATION",
"DELTA_NEW_NOT_NULL_VIOLATION",
"violate the new NOT NULL constraint",
"(id > 0) violated by row with values:", # incremental mats
"DELTA_VIOLATE_CONSTRAINT_WITH_VALUES", # incremental mats
"NOT NULL constraint violated for column",
]
Contributor Author comment:
The new error messages were needed here, as it appears the messages coming from Databricks have changed slightly, and some of the messages were also truncated differently.


def assert_expected_error_messages(self, error_message, expected_error_messages):
@@ -289,3 +310,28 @@ def models(self):
"my_model.sql": my_incremental_model_sql,
"constraints_schema.yml": constraints_yml,
}


# TODO: Like the tests above, this does test that model-level constraints don't
# result in errors, but it does not verify that they are actually present in
# Spark and that the ALTER TABLE statement actually ran.
@pytest.mark.skip_profile("spark_session", "apache_spark")
class TestSparkModelConstraintsRuntimeEnforcement(BaseModelConstraintsRuntimeEnforcement):
@pytest.fixture(scope="class")
def project_config_update(self):
return {
"models": {
"+file_format": "delta",
}
}

@pytest.fixture(scope="class")
def models(self):
return {
"my_model.sql": my_incremental_model_sql,
"constraints_schema.yml": model_constraints_yml,
}

@pytest.fixture(scope="class")
def expected_sql(self):
return _expected_sql_spark_model_constraints
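
The TODO above points out that these tests only prove constrained models build cleanly; they never confirm the constraint actually landed on the table. A hedged sketch of such a follow-up assertion, assuming the functional-test project fixture exposes run_sql() and that Delta surfaces check constraints as delta.constraints.<name> table properties (the helper name is illustrative):

```python
def assert_check_constraint_present(project, relation_name: str) -> None:
    # Delta stores check constraints as "delta.constraints.<name>" table properties.
    rows = project.run_sql(
        f"show tblproperties {project.test_schema}.{relation_name}", fetch="all"
    )
    constraint_props = [row for row in rows if str(row[0]).startswith("delta.constraints.")]
    assert constraint_props, f"no check constraint found on {relation_name}"
```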