Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SIP-15] Fix time range endpoints decoding #8481

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions superset/views/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,35 +207,33 @@ def get_time_range_endpoints(
form_data: Dict[str, Any], slc: Optional[models.Slice]
) -> Optional[Tuple[TimeRangeEndpoint, TimeRangeEndpoint]]:
"""
Get the slice aware time range endpoints falling back to the SQL database specific
definition or default if not defined.
Get the slice aware time range endpoints from the form-data falling back to the SQL
database specific definition or default if not defined.

For SIP-15 all new slices use the [start, end) interval which is consistent with the
Druid REST API.
native Druid connector.

:param form_data: The form-data
:param slc: The chart
:returns: The time range endpoints tuple
"""

time_range_endpoints = form_data.get("time_range_endpoints")
endpoints = form_data.get("time_range_endpoints")

if time_range_endpoints:
return time_range_endpoints
if slc and not endpoints:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is really just a refactor of the old logic to ensure that we're not repeating ourselves when defining the interval.

try:
_, datasource_type = get_datasource_info(None, None, form_data)
except SupersetException:
return None

try:
_, datasource_type = get_datasource_info(None, None, form_data)
except SupersetException:
return None

if datasource_type == "table":
if slc:
if datasource_type == "table":
endpoints = slc.datasource.database.get_extra().get("time_range_endpoints")

if not endpoints:
endpoints = app.config["SIP_15_DEFAULT_TIME_RANGE_ENDPOINTS"]

start, end = endpoints
return (TimeRangeEndpoint(start), TimeRangeEndpoint(end))
if endpoints:
start, end = endpoints
return (TimeRangeEndpoint(start), TimeRangeEndpoint(end))

return (TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE)
41 changes: 40 additions & 1 deletion tests/utils_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import uuid
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from unittest.mock import patch
from unittest.mock import Mock, patch

import numpy
from flask import Flask
Expand Down Expand Up @@ -47,10 +47,12 @@
parse_past_timedelta,
setup_cache,
split,
TimeRangeEndpoint,
validate_json,
zlib_compress,
zlib_decompress,
)
from superset.views.utils import get_time_range_endpoints


def mock_parse_human_datetime(s):
Expand Down Expand Up @@ -881,3 +883,40 @@ def test_get_or_create_db(self):
def test_get_or_create_db_invalid_uri(self):
with self.assertRaises(ArgumentError):
get_or_create_db("test_db", "yoursql:superset.db/()")

def test_get_time_range_endpoints(self):
self.assertEqual(
get_time_range_endpoints(form_data={}, slc=None),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)

self.assertEqual(
get_time_range_endpoints(
form_data={"time_range_endpoints": ["inclusive", "inclusive"]}, slc=None
),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE),
)

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1_druid"}, slc=None),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.EXCLUSIVE),
)

slc = Mock()
slc.datasource.database.get_extra.return_value = {}

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc),
(TimeRangeEndpoint.UNKNOWN, TimeRangeEndpoint.INCLUSIVE),
)

slc.datasource.database.get_extra.return_value = {
"time_range_endpoints": ["inclusive", "inclusive"]
}

self.assertEqual(
get_time_range_endpoints(form_data={"datasource": "1__table"}, slc=slc),
(TimeRangeEndpoint.INCLUSIVE, TimeRangeEndpoint.INCLUSIVE),
)

self.assertIsNone(get_time_range_endpoints(form_data={}, slc=slc))