diff --git a/superset/views/utils.py b/superset/views/utils.py index a5d5073c8951e..1bf4aff2d9fff 100644 --- a/superset/views/utils.py +++ b/superset/views/utils.py @@ -207,35 +207,33 @@ def get_time_range_endpoints( form_data: Dict[str, Any], slc: Optional[models.Slice] ) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]: """ - Get the slice aware time range endpoints falling back to the SQL database specific - definition or default if not defined. + Get the slice aware time range endpoints from the form-data falling back to the SQL + database specific definition or default if not defined. For SIP-15 all new slices use the [start, end) interval which is consistent with the - Druid REST API. + native Druid connector. :param form_data: The form-data :param slc: The chart :returns: The time range endpoints tuple """ - time_range_endpoints = form_data.get("time_range_endpoints") + endpoints = form_data.get("time_range_endpoints") - if time_range_endpoints: - return time_range_endpoints + if slc and not endpoints: + try: + _, datasource_type = get_datasource_info(None, None, form_data) + except SupersetException: + return None - try: - _, datasource_type = get_datasource_info(None, None, form_data) - except SupersetException: - return None - - if datasource_type == "table": - if slc: + if datasource_type == "table": endpoints = slc.datasource.database.get_extra().get("time_range_endpoints") if not endpoints: endpoints = app.config["SIP_15_DEFAULT_TIME_RANGE_ENDPOINTS"] - start, end = endpoints - return (TimeRangeEndpoint(start), TimeRangeEndpoint(end)) + if endpoints: + start, end = endpoints + return (TimeRangeEndpoint(start), TimeRangeEndpoint(end)) return (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE) diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 02703f0279805..8268f7fca5b90 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -18,7 +18,7 @@ import uuid from datetime import date, datetime, time, timedelta from decimal import Decimal -from unittest.mock import patch +from unittest.mock import Mock, patch import numpy from flask import Flask @@ -47,10 +47,12 @@ parse_past_timedelta, setup_cache, split, + TimeRangeEndpoint, validate_json, zlib_compress, zlib_decompress, ) +from superset.views.utils import get_time_range_endpoints def mock_parse_human_datetime(s): @@ -881,3 +883,40 @@ def test_get_or_create_db(self): def test_get_or_create_db_invalid_uri(self): with self.assertRaises(ArgumentError): get_or_create_db("test_db", "yoursql:superset.db/()") + + def test_get_time_range_endpoints(self): + self.assertEqual( + get_time_range_endpoints(form_data={}, slc=None), + (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE), + ) + + self.assertEqual( + get_time_range_endpoints( + form_data={"time_range_endpoints": ["inclusive", "inclusive"]}, slc=None + ), + (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE), + ) + + self.assertEqual( + get_time_range_endpoints(form_data={"datasource": "1_druid"}, slc=None), + (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE), + ) + + slc = Mock() + slc.datasource.database.get_extra.return_value = {} + + self.assertEqual( + get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc), + (TimeRangeEndpoint.UNKNOWN, TimeRangeEndpoint.INCLUSIVE), + ) + + slc.datasource.database.get_extra.return_value = { + "time_range_endpoints": ["inclusive", "inclusive"] + } + + self.assertEqual( + get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc), + (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE), + ) + + self.assertIsNone(get_time_range_endpoints(form_data={}, slc=slc))