Skip to content

Commit

Permalink
fix: % replace in values_for_column (#28271)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored and sadpandajoe committed May 20, 2024
1 parent 63e6721 commit 6de2b96
Showing 1 changed file with 98 additions and 13 deletions.
111 changes: 98 additions & 13 deletions tests/unit_tests/models/helpers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,24 @@

# pylint: disable=import-outside-toplevel

from __future__ import annotations

from contextlib import contextmanager
from typing import TYPE_CHECKING

import pytest
from pytest_mock import MockerFixture
from sqlalchemy import create_engine
from sqlalchemy.orm.session import Session
from sqlalchemy.pool import StaticPool

if TYPE_CHECKING:
from superset.models.core import Database

def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
"""
Test the `values_for_column` method.

NULL values should be returned as `None`, not `np.nan`, since NaN cannot be
serialized to JSON.
"""
from superset.connectors.sqla.models import SqlaTable, TableColumn
@pytest.fixture()
def database(mocker: MockerFixture, session: Session) -> Database:
from superset.connectors.sqla.models import SqlaTable
from superset.models.core import Database

SqlaTable.metadata.create_all(session.get_bind())
Expand All @@ -42,13 +44,12 @@ def test_values_for_column(mocker: MockerFixture, session: Session) -> None:
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)

database = Database(database_name="db", sqlalchemy_uri="sqlite://")

connection = engine.raw_connection()
connection.execute("CREATE TABLE t (c INTEGER)")
connection.execute("INSERT INTO t VALUES (1)")
connection.execute("INSERT INTO t VALUES (NULL)")
connection.execute("CREATE TABLE t (a INTEGER, b TEXT)")
connection.execute("INSERT INTO t VALUES (1, 'Alice')")
connection.execute("INSERT INTO t VALUES (NULL, 'Bob')")
connection.commit()

# since we're using an in-memory SQLite database, make sure we always
Expand All @@ -63,10 +64,94 @@ def mock_get_sqla_engine_with_context():
new=mock_get_sqla_engine_with_context,
)

return database


def test_values_for_column(database: Database) -> None:
    """
    Exercise `values_for_column` on a plain (physical) column.

    A NULL stored in the column is expected back as `None` rather than
    `np.nan`, because NaN cannot be serialized to JSON.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # column "a" of table "t" holds the rows 1 and NULL inserted by the fixture
    column_a = TableColumn(column_name="a")
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[column_a],
    )

    values = dataset.values_for_column("a")
    assert values == [1, None]


def test_values_for_column_calculated(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Exercise `values_for_column` on a calculated column.

    The column has no physical counterpart in table "t"; it is defined only by
    a SQL expression over column "b".
    """
    # NOTE: `mocker` is not used in the body; it is kept so the test signature
    # stays unchanged.
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    calculated_column = TableColumn(
        column_name="starts_with_A",
        expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
    )
    dataset = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[calculated_column],
    )

    values = dataset.values_for_column("starts_with_A")
    assert values == ["yes", "nope"]


def test_values_for_column_double_percents(
    mocker: MockerFixture,
    database: Database,
) -> None:
    """
    Test the behavior of `double_percents`.

    With the identifier preparer's `_double_percents` set to "pyformat", the
    SQL handed to `mutate_sql_based_on_config` is expected to contain `%%`,
    while the SQL finally handed to `pd.read_sql_query` is expected to contain
    a single `%` again.
    """
    from superset.connectors.sqla.models import SqlaTable, TableColumn

    # put the dialect into the mode under test
    # NOTE(review): presumably this makes the compiler escape `%` as `%%`;
    # confirm against the SQLAlchemy version in use
    with database.get_sqla_engine() as engine:
        engine.dialect.identifier_preparer._double_percents = "pyformat"

    table = SqlaTable(
        database=database,
        schema=None,
        table_name="t",
        columns=[
            TableColumn(
                column_name="starts_with_A",
                expression="CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END",
            )
        ],
    )

    # pass-through patch: records the SQL it receives and returns it unchanged
    mutate_sql_based_on_config = mocker.patch.object(
        database,
        "mutate_sql_based_on_config",
        side_effect=lambda sql: sql,
    )
    # replace pandas in the helpers module so the final query can be inspected
    # instead of being executed
    pd = mocker.patch("superset.models.helpers.pd")

    table.values_for_column("starts_with_A")

    # make sure the SQL originally had double percents
    mutate_sql_based_on_config.assert_called_with(
        "SELECT DISTINCT CASE WHEN b LIKE 'A%%' THEN 'yes' ELSE 'nope' END "
        "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
    )
    # make sure final query has single percents
    with database.get_sqla_engine() as engine:
        pd.read_sql_query.assert_called_with(
            sql=(
                "SELECT DISTINCT CASE WHEN b LIKE 'A%' THEN 'yes' ELSE 'nope' END "
                "AS column_values \nFROM t\n LIMIT 10000 OFFSET 0"
            ),
            con=engine,
        )

0 comments on commit 6de2b96

Please sign in to comment.