diff --git a/.changes/unreleased/Features-20230223-180923.yaml b/.changes/unreleased/Features-20230223-180923.yaml
index 4faed7bc1..de98ef416 100644
--- a/.changes/unreleased/Features-20230223-180923.yaml
+++ b/.changes/unreleased/Features-20230223-180923.yaml
@@ -1,6 +1,6 @@
 kind: Features
-body: implement data_type_code_to_name on SparkConnectionManager
+body: Enforce contracts on models materialized as tables and views
 time: 2023-02-23T18:09:23.787675-05:00
 custom:
-  Author: michelleark
-  Issue: "639"
+  Author: michelleark emmyoop
+  Issue: 639 654
diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql
index 656e6b3a7..725277b3e 100644
--- a/dbt/include/spark/macros/adapters.sql
+++ b/dbt/include/spark/macros/adapters.sql
@@ -226,6 +226,9 @@
 {% macro spark__create_view_as(relation, sql) -%}
   create or replace view {{ relation }}
   {{ comment_clause() }}
+  {% if config.get('contract', False) -%}
+    {{ get_assert_columns_equivalent(sql) }}
+  {%- endif %}
   as
     {{ sql }}
 {% endmacro %}
diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py
index 14d7291d8..27cf59f1c 100644
--- a/tests/functional/adapter/test_constraints.py
+++ b/tests/functional/adapter/test_constraints.py
@@ -1,7 +1,8 @@
 import pytest
 from dbt.tests.util import relation_from_name
 from dbt.tests.adapter.constraints.test_constraints import (
-    BaseConstraintsColumnsEqual,
+    BaseTableConstraintsColumnsEqual,
+    BaseViewConstraintsColumnsEqual,
     BaseConstraintsRuntimeEnforcement
 )
 from dbt.tests.adapter.constraints.fixtures import (
@@ -28,8 +29,7 @@
 constraints_yml = model_schema_yml.replace("text", "string").replace("primary key", "")
 
 
-@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_http_cluster')
-class TestSparkConstraintsColumnsEqualPyodbc(BaseConstraintsColumnsEqual):
+class PyodbcSetup:
     @pytest.fixture(scope="class")
     def models(self):
         return {
@@ -68,8 +68,7 @@ def data_types(self, int_type, schema_int_type, string_type):
         ]
 
 
-@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_sql_endpoint', 'databricks_cluster')
-class TestSparkConstraintsColumnsEqualDatabricksHTTP(BaseConstraintsColumnsEqual):
+class DatabricksHTTPSetup:
     @pytest.fixture(scope="class")
     def models(self):
         return {
@@ -107,6 +106,26 @@ def data_types(self, int_type, schema_int_type, string_type):
         ]
 
 
+@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_http_cluster')
+class TestSparkTableConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseTableConstraintsColumnsEqual):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_http_cluster')
+class TestSparkViewConstraintsColumnsEqualPyodbc(PyodbcSetup, BaseViewConstraintsColumnsEqual):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_sql_endpoint', 'databricks_cluster')
+class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual):
+    pass
+
+
+@pytest.mark.skip_profile('spark_session', 'apache_spark', 'databricks_sql_endpoint', 'databricks_cluster')
+class TestSparkViewConstraintsColumnsEqualDatabricksHTTP(DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual):
+    pass
+
+
 @pytest.mark.skip_profile('spark_session', 'apache_spark')
 class TestSparkConstraintsRuntimeEnforcement(BaseConstraintsRuntimeEnforcement):
     @pytest.fixture(scope="class")