Skip to content

Commit

Permalink
fix: IS NULL filter operator for numeric columns (apache#13496)
Browse files Browse the repository at this point in the history
  • Loading branch information
ktmud authored and Allan Caetano de Oliveira committed May 21, 2021
1 parent 9692f41 commit 94dcfa9
Show file tree
Hide file tree
Showing 6 changed files with 51 additions and 39 deletions.
13 changes: 4 additions & 9 deletions superset-frontend/src/explore/exploreUtils.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import { useCallback, useEffect } from 'react';
import URI from 'urijs';
import {
buildQueryContext,
ensureIsArray,
getChartBuildQueryRegistry,
getChartMetadataRegistry,
} from '@superset-ui/core';
Expand Down Expand Up @@ -319,20 +320,14 @@ export const getSimpleSQLExpression = (subject, operator, comparator) => {
expression += ` ${operator}`;
const firstValue =
isMulti && Array.isArray(comparator) ? comparator[0] : comparator;
let comparatorArray;
if (comparator === undefined || comparator === null) {
comparatorArray = [];
} else if (Array.isArray(comparator)) {
comparatorArray = comparator;
} else {
comparatorArray = [comparator];
}
const comparatorArray = ensureIsArray(comparator);
const isString =
firstValue !== undefined && Number.isNaN(Number(firstValue));
const quote = isString ? "'" : '';
const [prefix, suffix] = isMulti ? ['(', ')'] : ['', ''];
const formattedComparators = comparatorArray.map(
val => `${quote}${isString ? val.replace("'", "''") : val}${quote}`,
val =>
`${quote}${isString ? String(val).replace("'", "''") : val}${quote}`,
);
if (comparatorArray.length > 0) {
expression += ` ${prefix}${formattedComparators.join(', ')}${suffix}`;
Expand Down
5 changes: 1 addition & 4 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,10 +359,7 @@ def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
if is_list_target and not isinstance(values, (tuple, list)):
values = [values] # type: ignore
elif not is_list_target and isinstance(values, (tuple, list)):
if values:
values = values[0]
else:
values = None
values = values[0] if values else None
return values

def external_metadata(self) -> List[Dict[str, str]]:
Expand Down
48 changes: 25 additions & 23 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@

from superset import app, db, is_feature_enabled, security_manager
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.constants import NULL_STRING
from superset.db_engine_specs.base import TimestampExpression
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
Expand Down Expand Up @@ -1065,23 +1064,36 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
target_column_is_numeric=col_obj.is_numeric,
is_list_target=is_list_target,
)
if op in (
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
):
cond = col_obj.get_sqla_col().in_(eq)
if isinstance(eq, str) and NULL_STRING in eq:
cond = or_(
cond,
col_obj.get_sqla_col() # pylint: disable=singleton-comparison
== None,
if is_list_target:
assert isinstance(eq, (tuple, list))
if len(eq) == 0:
raise QueryObjectValidationError(
_("Filter value list cannot be empty")
)
if None in eq:
eq = [x for x in eq if x is not None]
is_null_cond = col_obj.get_sqla_col().is_(None)
if eq:
cond = or_(is_null_cond, col_obj.get_sqla_col().in_(eq))
else:
cond = is_null_cond
else:
cond = col_obj.get_sqla_col().in_(eq)
if op == utils.FilterOperator.NOT_IN.value:
cond = ~cond
where_clause_and.append(cond)
elif op == utils.FilterOperator.IS_NULL.value:
where_clause_and.append(col_obj.get_sqla_col().is_(None))
elif op == utils.FilterOperator.IS_NOT_NULL.value:
where_clause_and.append(col_obj.get_sqla_col().isnot(None))
else:
if col_obj.is_numeric:
eq = utils.cast_to_num(flt["val"])
if eq is None:
raise QueryObjectValidationError(
_(
"Must specify a value for filters "
"with comparison operators"
)
)
if op == utils.FilterOperator.EQUALS.value:
where_clause_and.append(col_obj.get_sqla_col() == eq)
elif op == utils.FilterOperator.NOT_EQUALS.value:
Expand All @@ -1096,16 +1108,6 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
where_clause_and.append(col_obj.get_sqla_col() <= eq)
elif op == utils.FilterOperator.LIKE.value:
where_clause_and.append(col_obj.get_sqla_col().like(eq))
elif op == utils.FilterOperator.IS_NULL.value:
where_clause_and.append(
col_obj.get_sqla_col() # pylint: disable=singleton-comparison
== None
)
elif op == utils.FilterOperator.IS_NOT_NULL.value:
where_clause_and.append(
col_obj.get_sqla_col() # pylint: disable=singleton-comparison
!= None
)
else:
raise QueryObjectValidationError(
_("Invalid filter operation type: %(op)s", op=op)
Expand Down
7 changes: 6 additions & 1 deletion tests/charts/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1046,7 +1046,12 @@ def test_chart_data_applied_time_extras(self):
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(
data["result"][0]["applied_filters"],
[{"column": "gender"}, {"column": "__time_range"},],
[
{"column": "gender"},
{"column": "num"},
{"column": "name"},
{"column": "__time_range"},
],
)
self.assertEqual(
data["result"][0]["rejected_filters"],
Expand Down
6 changes: 5 additions & 1 deletion tests/fixtures/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@
"time_range": "100 years ago : now",
"timeseries_limit": 0,
"timeseries_limit_metric": None,
"filters": [{"col": "gender", "op": "==", "val": "boy"}],
"filters": [
{"col": "gender", "op": "==", "val": "boy"},
{"col": "num", "op": "IS NOT NULL"},
{"col": "name", "op": "NOT IN", "val": ["<NULL>", '"abc"']},
],
"having": "",
"having_filters": [],
"where": "",
Expand Down
11 changes: 10 additions & 1 deletion tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re

import pytest

from superset import db
Expand Down Expand Up @@ -305,8 +307,15 @@ def test_query_response_type(self):
assert len(responses) == 1
response = responses["queries"][0]
assert len(response) == 2
sql_text = response["query"]
assert response["language"] == "sql"
assert "SELECT" in response["query"]
assert "SELECT" in sql_text
assert re.search(r'[`"\[]?num[`"\]]? IS NOT NULL', sql_text)
assert re.search(
r"""NOT \([`"\[]?name[`"\]]? IS NULL[\s\n]* """
r"""OR [`"\[]?name[`"\]]? IN \('abc'\)\)""",
sql_text,
)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_fetch_values_predicate_in_query(self):
Expand Down

0 comments on commit 94dcfa9

Please sign in to comment.