diff --git a/.changes/unreleased/Features-20230604-080052.yaml b/.changes/unreleased/Features-20230604-080052.yaml new file mode 100644 index 00000000000..50e8dda5bdb --- /dev/null +++ b/.changes/unreleased/Features-20230604-080052.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix null-safe equals comparison via `equals` +time: 2023-06-04T08:00:52.537967-06:00 +custom: + Author: dbeatty10 + Issue: "7778" diff --git a/tests/adapter/dbt/tests/adapter/utils/base_utils.py b/tests/adapter/dbt/tests/adapter/utils/base_utils.py index 75d621c2bf9..622b4ab4224 100644 --- a/tests/adapter/dbt/tests/adapter/utils/base_utils.py +++ b/tests/adapter/dbt/tests/adapter/utils/base_utils.py @@ -2,9 +2,11 @@ from dbt.tests.util import run_dbt macros__equals_sql = """ -{% macro equals(actual, expected) %} -{# -- actual is not distinct from expected #} -(({{ actual }} = {{ expected }}) or ({{ actual }} is null and {{ expected }} is null)) +{% macro equals(expr1, expr2) -%} +case when (({{ expr1 }} = {{ expr2 }}) or ({{ expr1 }} is null and {{ expr2 }} is null)) + then 0 + else 1 +end = 0 {% endmacro %} """ @@ -15,6 +17,15 @@ {% endtest %} """ +macros__replace_empty_sql = """ +{% macro replace_empty(expr) -%} +case + when {{ expr }} = 'EMPTY' then '' + else {{ expr }} +end +{% endmacro %} +""" + class BaseUtils: # setup @@ -23,6 +34,7 @@ def macros(self): return { "equals.sql": macros__equals_sql, "test_assert_equal.sql": macros__test_assert_equal_sql, + "replace_empty.sql": macros__replace_empty_sql, } # make it possible to dynamically update the macro call with a namespace diff --git a/tests/adapter/dbt/tests/adapter/utils/fixture_concat.py b/tests/adapter/dbt/tests/adapter/utils/fixture_concat.py index d1630af3eea..8421d53eb66 100644 --- a/tests/adapter/dbt/tests/adapter/utils/fixture_concat.py +++ b/tests/adapter/dbt/tests/adapter/utils/fixture_concat.py @@ -1,18 +1,29 @@ # concat +# https://github.com/dbt-labs/dbt-core/issues/4725 seeds__data_concat_csv = """input_1,input_2,output a,b,ab -a,,a -,b,b -,, +a,EMPTY,a +EMPTY,b,b +EMPTY,EMPTY,EMPTY """ models__test_concat_sql = """ -with data as ( +with seed_data as ( select * from {{ ref('data_concat') }} +), + +data as ( + + select + {{ replace_empty('input_1') }} as input_1, + {{ replace_empty('input_2') }} as input_2, + {{ replace_empty('output') }} as output + from seed_data + ) select diff --git a/tests/adapter/dbt/tests/adapter/utils/fixture_equals.py b/tests/adapter/dbt/tests/adapter/utils/fixture_equals.py new file mode 100644 index 00000000000..67a6699c6ef --- /dev/null +++ b/tests/adapter/dbt/tests/adapter/utils/fixture_equals.py @@ -0,0 +1,41 @@ +# equals + +SEEDS__DATA_EQUALS_CSV = """key_name,x,y,expected +1,1,1,same +2,1,2,different +3,1,null,different +4,2,1,different +5,2,2,same +6,2,null,different +7,null,1,different +8,null,2,different +9,null,null,same +""" + + +MODELS__EQUAL_VALUES_SQL = """ +with data as ( + + select * from {{ ref('data_equals') }} + +) + +select * +from data +where + {{ equals('x', 'y') }} +""" + + +MODELS__NOT_EQUAL_VALUES_SQL = """ +with data as ( + + select * from {{ ref('data_equals') }} + +) + +select * +from data +where + not {{ equals('x', 'y') }} +""" diff --git a/tests/adapter/dbt/tests/adapter/utils/fixture_hash.py b/tests/adapter/dbt/tests/adapter/utils/fixture_hash.py index 39d8f02704c..91f366fc504 100644 --- a/tests/adapter/dbt/tests/adapter/utils/fixture_hash.py +++ b/tests/adapter/dbt/tests/adapter/utils/fixture_hash.py @@ -1,18 +1,28 @@ # hash +# https://github.com/dbt-labs/dbt-core/issues/4725 seeds__data_hash_csv = """input_1,output ab,187ef4436122d1cc2f40dc2b92f0eba0 a,0cc175b9c0f1b6a831c399e269772661 1,c4ca4238a0b923820dcc509a6f75849b -,d41d8cd98f00b204e9800998ecf8427e +EMPTY,d41d8cd98f00b204e9800998ecf8427e """ models__test_hash_sql = """ -with data as ( +with seed_data as ( select * from {{ ref('data_hash') }} +), + +data as ( + + select + {{ replace_empty('input_1') }} as input_1, + {{ replace_empty('output') }} as output + from seed_data + ) select diff --git a/tests/adapter/dbt/tests/adapter/utils/fixture_null_compare.py b/tests/adapter/dbt/tests/adapter/utils/fixture_null_compare.py index 8a524010082..9af2f9a2e32 100644 --- a/tests/adapter/dbt/tests/adapter/utils/fixture_null_compare.py +++ b/tests/adapter/dbt/tests/adapter/utils/fixture_null_compare.py @@ -8,7 +8,7 @@ MODELS__TEST_MIXED_NULL_COMPARE_YML = """ version: 2 models: - - name: test_null_compare + - name: test_mixed_null_compare tests: - assert_equal: actual: actual diff --git a/tests/adapter/dbt/tests/adapter/utils/test_equals.py b/tests/adapter/dbt/tests/adapter/utils/test_equals.py new file mode 100644 index 00000000000..51e7fe84bd3 --- /dev/null +++ b/tests/adapter/dbt/tests/adapter/utils/test_equals.py @@ -0,0 +1,54 @@ +import pytest +from dbt.tests.adapter.utils.base_utils import macros__equals_sql +from dbt.tests.adapter.utils.fixture_equals import ( + SEEDS__DATA_EQUALS_CSV, + MODELS__EQUAL_VALUES_SQL, + MODELS__NOT_EQUAL_VALUES_SQL, +) +from dbt.tests.util import run_dbt, relation_from_name + + +class BaseEquals: + @pytest.fixture(scope="class") + def macros(self): + return { + "equals.sql": macros__equals_sql, + } + + @pytest.fixture(scope="class") + def seeds(self): + return { + "data_equals.csv": SEEDS__DATA_EQUALS_CSV, + } + + @pytest.fixture(scope="class") + def models(self): + return { + "equal_values.sql": MODELS__EQUAL_VALUES_SQL, + "not_equal_values.sql": MODELS__NOT_EQUAL_VALUES_SQL, + } + + def test_equal_values(self, project): + run_dbt(["seed"]) + run_dbt(["run"]) + + # There are 9 cases total; 3 are equal and 6 are not equal + + # 3 are equal + relation = relation_from_name(project.adapter, "equal_values") + result = project.run_sql( + f"select count(*) as num_rows from {relation} where expected = 'same'", fetch="one" + ) + assert result[0] == 3 + + # 6 are not equal + relation = relation_from_name(project.adapter, "not_equal_values") + result = project.run_sql( + f"select count(*) as num_rows from {relation} where expected = 'different'", + fetch="one", + ) + assert result[0] == 6 + + +class TestEquals(BaseEquals): + pass diff --git a/tests/adapter/dbt/tests/adapter/utils/test_null_compare.py b/tests/adapter/dbt/tests/adapter/utils/test_null_compare.py index 58a1c9daaf9..eac901f3972 100644 --- a/tests/adapter/dbt/tests/adapter/utils/test_null_compare.py +++ b/tests/adapter/dbt/tests/adapter/utils/test_null_compare.py @@ -14,8 +14,8 @@ class BaseMixedNullCompare(BaseUtils): @pytest.fixture(scope="class") def models(self): return { - "test_mixed_null_compare.yml": MODELS__TEST_MIXED_NULL_COMPARE_SQL, - "test_mixed_null_compare.sql": MODELS__TEST_MIXED_NULL_COMPARE_YML, + "test_mixed_null_compare.yml": MODELS__TEST_MIXED_NULL_COMPARE_YML, + "test_mixed_null_compare.sql": MODELS__TEST_MIXED_NULL_COMPARE_SQL, } def test_build_assert_equal(self, project): @@ -32,7 +32,7 @@ def models(self): } -class TestMixedNullCompare(BaseNullCompare): +class TestMixedNullCompare(BaseMixedNullCompare): pass