diff --git a/superset-frontend/src/explore/exploreUtils.js b/superset-frontend/src/explore/exploreUtils.js index c53b91ba26698..efd67acf31286 100644 --- a/superset-frontend/src/explore/exploreUtils.js +++ b/superset-frontend/src/explore/exploreUtils.js @@ -22,6 +22,7 @@ import { useCallback, useEffect } from 'react'; import URI from 'urijs'; import { buildQueryContext, + ensureIsArray, getChartBuildQueryRegistry, getChartMetadataRegistry, } from '@superset-ui/core'; @@ -319,20 +320,14 @@ export const getSimpleSQLExpression = (subject, operator, comparator) => { expression += ` ${operator}`; const firstValue = isMulti && Array.isArray(comparator) ? comparator[0] : comparator; - let comparatorArray; - if (comparator === undefined || comparator === null) { - comparatorArray = []; - } else if (Array.isArray(comparator)) { - comparatorArray = comparator; - } else { - comparatorArray = [comparator]; - } + const comparatorArray = ensureIsArray(comparator); const isString = firstValue !== undefined && Number.isNaN(Number(firstValue)); const quote = isString ? "'" : ''; const [prefix, suffix] = isMulti ? ['(', ')'] : ['', '']; const formattedComparators = comparatorArray.map( - val => `${quote}${isString ? val.replace("'", "''") : val}${quote}`, + val => + `${quote}${isString ? String(val).replace("'", "''") : val}${quote}`, ); if (comparatorArray.length > 0) { expression += ` ${prefix}${formattedComparators.join(', ')}${suffix}`; diff --git a/superset/connectors/base/models.py b/superset/connectors/base/models.py index c4cc6f40639b2..d4d2f878ae827 100644 --- a/superset/connectors/base/models.py +++ b/superset/connectors/base/models.py @@ -359,10 +359,7 @@ def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]: if is_list_target and not isinstance(values, (tuple, list)): values = [values] # type: ignore elif not is_list_target and isinstance(values, (tuple, list)): - if values: - values = values[0] - else: - values = None + values = values[0] if values else None return values def external_metadata(self) -> List[Dict[str, str]]: diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 254f52a3045b8..9f745f96a1b48 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -53,7 +53,6 @@ from superset import app, db, is_feature_enabled, security_manager from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric -from superset.constants import NULL_STRING from superset.db_engine_specs.base import TimestampExpression from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import QueryObjectValidationError, SupersetSecurityException @@ -1065,23 +1064,36 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma target_column_is_numeric=col_obj.is_numeric, is_list_target=is_list_target, ) - if op in ( - utils.FilterOperator.IN.value, - utils.FilterOperator.NOT_IN.value, - ): - cond = col_obj.get_sqla_col().in_(eq) - if isinstance(eq, str) and NULL_STRING in eq: - cond = or_( - cond, - col_obj.get_sqla_col() # pylint: disable=singleton-comparison - == None, + if is_list_target: + assert isinstance(eq, (tuple, list)) + if len(eq) == 0: + raise QueryObjectValidationError( + _("Filter value list cannot be empty") ) + if None in eq: + eq = [x for x in eq if x is not None] + is_null_cond = col_obj.get_sqla_col().is_(None) + if eq: + cond = or_(is_null_cond, col_obj.get_sqla_col().in_(eq)) + else: + cond = is_null_cond + else: + cond = col_obj.get_sqla_col().in_(eq) if op == utils.FilterOperator.NOT_IN.value: cond = ~cond where_clause_and.append(cond) + elif op == utils.FilterOperator.IS_NULL.value: + where_clause_and.append(col_obj.get_sqla_col().is_(None)) + elif op == utils.FilterOperator.IS_NOT_NULL.value: + where_clause_and.append(col_obj.get_sqla_col().isnot(None)) else: - if col_obj.is_numeric: - eq = utils.cast_to_num(flt["val"]) + if eq is None: + raise QueryObjectValidationError( + _( + "Must specify a value for filters " + "with comparison operators" + ) + ) if op == utils.FilterOperator.EQUALS.value: where_clause_and.append(col_obj.get_sqla_col() == eq) elif op == utils.FilterOperator.NOT_EQUALS.value: @@ -1096,16 +1108,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma where_clause_and.append(col_obj.get_sqla_col() <= eq) elif op == utils.FilterOperator.LIKE.value: where_clause_and.append(col_obj.get_sqla_col().like(eq)) - elif op == utils.FilterOperator.IS_NULL.value: - where_clause_and.append( - col_obj.get_sqla_col() # pylint: disable=singleton-comparison - == None - ) - elif op == utils.FilterOperator.IS_NOT_NULL.value: - where_clause_and.append( - col_obj.get_sqla_col() # pylint: disable=singleton-comparison - != None - ) else: raise QueryObjectValidationError( _("Invalid filter operation type: %(op)s", op=op) diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index ff402cc3ccc72..46f706aa9dcc5 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -1046,7 +1046,12 @@ def test_chart_data_applied_time_extras(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual( data["result"][0]["applied_filters"], - [{"column": "gender"}, {"column": "__time_range"},], + [ + {"column": "gender"}, + {"column": "num"}, + {"column": "name"}, + {"column": "__time_range"}, + ], ) self.assertEqual( data["result"][0]["rejected_filters"], diff --git a/tests/fixtures/query_context.py b/tests/fixtures/query_context.py index 0bacb671e87df..38e156aae4823 100644 --- a/tests/fixtures/query_context.py +++ b/tests/fixtures/query_context.py @@ -38,7 +38,11 @@ "time_range": "100 years ago : now", "timeseries_limit": 0, "timeseries_limit_metric": None, - "filters": [{"col": "gender", "op": "==", "val": "boy"}], + "filters": [ + {"col": "gender", "op": "==", "val": "boy"}, + {"col": "num", "op": "IS NOT NULL"}, + {"col": "name", "op": "NOT IN", "val": ["", '"abc"']}, + ], "having": "", "having_filters": [], "where": "", diff --git a/tests/query_context_tests.py b/tests/query_context_tests.py index c7ec3271dda89..377f717be9394 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -14,6 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +import re + import pytest from superset import db @@ -305,8 +307,15 @@ def test_query_response_type(self): assert len(responses) == 1 response = responses["queries"][0] assert len(response) == 2 + sql_text = response["query"] assert response["language"] == "sql" - assert "SELECT" in response["query"] + assert "SELECT" in sql_text + assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text) + assert re.search( + r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """ + r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)""", + sql_text, + ) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_fetch_values_predicate_in_query(self):