From 0a00153375fb69891c2a9f0115a33cdf5551b2d6 Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Wed, 24 Feb 2021 07:43:47 +0200 Subject: [PATCH] feat(chart-data): add rowcount, timegrain and column result types (#13271) * feat(chart-data): add rowcount, timegrain and column result types * break out actions from query_context * rename module --- superset/charts/schemas.py | 27 +++-- superset/common/query_actions.py | 182 ++++++++++++++++++++++++++++ superset/common/query_context.py | 117 ++++++------------ superset/common/query_object.py | 26 +++- superset/connectors/druid/models.py | 5 + superset/connectors/sqla/models.py | 14 ++- superset/utils/cache.py | 1 + superset/utils/core.py | 19 ++- tests/charts/api_tests.py | 43 ++++++- tests/charts/schema_tests.py | 2 +- tests/db_engine_specs/hive_tests.py | 1 - tests/query_context_tests.py | 1 - 12 files changed, 339 insertions(+), 99 deletions(-) create mode 100644 superset/common/query_actions.py diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py index 6105d1b3ed4b9..185e2aad8984c 100644 --- a/superset/charts/schemas.py +++ b/superset/charts/schemas.py @@ -887,10 +887,22 @@ class AnnotationLayerSchema(Schema): ) +class ChartDataDatasourceSchema(Schema): + description = "Chart datasource" + id = fields.Integer(description="Datasource id", required=True,) + type = fields.String( + description="Datasource type", + validate=validate.OneOf(choices=("druid", "table")), + ) + + class ChartDataQueryObjectSchema(Schema): class Meta: # pylint: disable=too-few-public-methods unknown = EXCLUDE + datasource = fields.Nested(ChartDataDatasourceSchema, allow_none=True) + result_type = EnumField(ChartDataResultType, by_value=True, allow_none=True) + annotation_layers = fields.List( fields.Nested(AnnotationLayerSchema), description="Annotation layers to apply to chart", @@ -971,10 +983,10 @@ class Meta: # pylint: disable=too-few-public-methods description="Metric used to 
limit timeseries queries by.", allow_none=True, ) row_limit = fields.Integer( - description='Maximum row count. Default: `config["ROW_LIMIT"]`', + description='Maximum row count (0=disabled). Default: `config["ROW_LIMIT"]`', allow_none=True, validate=[ - Range(min=1, error=_("`row_limit` must be greater than or equal to 1")) + Range(min=0, error=_("`row_limit` must be greater than or equal to 0")) ], ) row_offset = fields.Integer( @@ -1038,14 +1050,9 @@ class Meta: # pylint: disable=too-few-public-methods values=fields.String(description="The value of the query parameter"), allow_none=True, ) - - -class ChartDataDatasourceSchema(Schema): - description = "Chart datasource" - id = fields.Integer(description="Datasource id", required=True,) - type = fields.String( - description="Datasource type", - validate=validate.OneOf(choices=("druid", "table")), + is_rowcount = fields.Boolean( + description="Should the rowcount of the actual query be returned", + allow_none=True, ) diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py new file mode 100644 index 0000000000000..dd4121e49e52c --- /dev/null +++ b/superset/common/query_actions.py @@ -0,0 +1,182 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import copy +import math +from typing import Any, Callable, cast, Dict, List, Optional, TYPE_CHECKING + +from flask_babel import _ + +from superset import app +from superset.connectors.base.models import BaseDatasource +from superset.exceptions import QueryObjectValidationError +from superset.utils.core import ( + ChartDataResultType, + extract_column_dtype, + extract_dataframe_dtypes, + get_time_filter_status, + QueryStatus, +) + +if TYPE_CHECKING: + from superset.common.query_context import QueryContext + from superset.common.query_object import QueryObject + +config = app.config + + +def _get_datasource( + query_context: "QueryContext", query_obj: "QueryObject" +) -> BaseDatasource: + return query_obj.datasource or query_context.datasource + + +def _get_columns( + query_context: "QueryContext", query_obj: "QueryObject", _: bool +) -> Dict[str, Any]: + datasource = _get_datasource(query_context, query_obj) + return { + "data": [ + { + "column_name": col.column_name, + "verbose_name": col.verbose_name, + "dtype": extract_column_dtype(col), + } + for col in datasource.columns + ] + } + + +def _get_timegrains( + query_context: "QueryContext", query_obj: "QueryObject", _: bool +) -> Dict[str, Any]: + datasource = _get_datasource(query_context, query_obj) + return { + "data": [ + { + "name": grain.name, + "function": grain.function, + "duration": grain.duration, + } + for grain in datasource.database.grains() + ] + } + + +def _get_query( + query_context: "QueryContext", query_obj: "QueryObject", _: bool, +) -> Dict[str, Any]: + datasource = _get_datasource(query_context, query_obj) + return { + "query": datasource.get_query_str(query_obj.to_dict()), + "language": datasource.query_language, + } + + +def _get_full( + query_context: "QueryContext", + query_obj: "QueryObject", + force_cached: Optional[bool] = False, +) -> Dict[str, Any]: + datasource = _get_datasource(query_context, query_obj) + result_type = query_obj.result_type or query_context.result_type + payload = 
query_context.get_df_payload(query_obj, force_cached=force_cached) + df = payload["df"] + status = payload["status"] + if status != QueryStatus.FAILED: + payload["colnames"] = list(df.columns) + payload["coltypes"] = extract_dataframe_dtypes(df) + payload["data"] = query_context.get_data(df) + del payload["df"] + + filters = query_obj.filter + filter_columns = cast(List[str], [flt.get("col") for flt in filters]) + columns = set(datasource.column_names) + applied_time_columns, rejected_time_columns = get_time_filter_status( + datasource, query_obj.applied_time_extras + ) + payload["applied_filters"] = [ + {"column": col} for col in filter_columns if col in columns + ] + applied_time_columns + payload["rejected_filters"] = [ + {"reason": "not_in_datasource", "column": col} + for col in filter_columns + if col not in columns + ] + rejected_time_columns + + if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED: + return {"data": payload["data"]} + return payload + + +def _get_samples( + query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False +) -> Dict[str, Any]: + datasource = _get_datasource(query_context, query_obj) + row_limit = query_obj.row_limit or math.inf + query_obj = copy.copy(query_obj) + query_obj.is_timeseries = False + query_obj.orderby = [] + query_obj.groupby = [] + query_obj.metrics = [] + query_obj.post_processing = [] + query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"]) + query_obj.row_offset = 0 + query_obj.columns = [o.column_name for o in datasource.columns] + return _get_full(query_context, query_obj, force_cached) + + +def _get_results( + query_context: "QueryContext", query_obj: "QueryObject", force_cached: bool = False +) -> Dict[str, Any]: + payload = _get_full(query_context, query_obj, force_cached) + return {"data": payload["data"]} + + +_result_type_functions: Dict[ + ChartDataResultType, Callable[["QueryContext", "QueryObject", bool], Dict[str, Any]] +] = { + 
ChartDataResultType.COLUMNS: _get_columns, + ChartDataResultType.TIMEGRAINS: _get_timegrains, + ChartDataResultType.QUERY: _get_query, + ChartDataResultType.SAMPLES: _get_samples, + ChartDataResultType.FULL: _get_full, + ChartDataResultType.RESULTS: _get_results, +} + + +def get_query_results( + result_type: ChartDataResultType, + query_context: "QueryContext", + query_obj: "QueryObject", + force_cached: bool, +) -> Dict[str, Any]: + """ + Return result payload for a chart data request. + + :param result_type: the type of result to return + :param query_context: query context to which the query object belongs + :param query_obj: query object for which to retrieve the results + :param force_cached: should results be forcefully retrieved from cache + :raises QueryObjectValidationError: if an unsupported result type is requested + :return: JSON serializable result payload + """ + result_func = _result_type_functions.get(result_type) + if result_func: + return result_func(query_context, query_obj, force_cached) + raise QueryObjectValidationError( + _("Invalid result type: %(result_type)s", result_type=result_type) + ) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index be1e7b7ca07da..58b5238cbc6fb 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -14,10 +14,8 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-import copy import logging -import math -from typing import Any, cast, ClassVar, Dict, List, Optional, Union +from typing import Any, ClassVar, Dict, List, Optional, Union import numpy as np import pandas as pd @@ -26,6 +24,7 @@ from superset import app, db, is_feature_enabled from superset.annotation_layers.dao import AnnotationLayerDAO from superset.charts.dao import ChartDAO +from superset.common.query_actions import get_query_results from superset.common.query_object import QueryObject from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry @@ -36,9 +35,18 @@ ) from superset.extensions import cache_manager, security_manager from superset.stats_logger import BaseStatsLogger -from superset.utils import core as utils from superset.utils.cache import generate_cache_key, set_and_log_cache -from superset.utils.core import DTTM_ALIAS +from superset.utils.core import ( + ChartDataResultFormat, + ChartDataResultType, + DatasourceDict, + DTTM_ALIAS, + error_msg_from_exception, + get_column_names_from_metrics, + get_stacktrace, + normalize_dttm_col, + QueryStatus, +) from superset.views.utils import get_viz config = app.config @@ -59,19 +67,19 @@ class QueryContext: queries: List[QueryObject] force: bool custom_cache_timeout: Optional[int] - result_type: utils.ChartDataResultType - result_format: utils.ChartDataResultFormat + result_type: ChartDataResultType + result_format: ChartDataResultFormat # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 def __init__( # pylint: disable=too-many-arguments self, - datasource: Dict[str, Any], + datasource: DatasourceDict, queries: List[Dict[str, Any]], force: bool = False, custom_cache_timeout: Optional[int] = None, - result_type: Optional[utils.ChartDataResultType] = None, - result_format: Optional[utils.ChartDataResultFormat] = None, + result_type: 
Optional[ChartDataResultType] = None, + result_format: Optional[ChartDataResultFormat] = None, ) -> None: self.datasource = ConnectorRegistry.get_datasource( str(datasource["type"]), int(datasource["id"]), db.session @@ -79,8 +87,8 @@ def __init__( # pylint: disable=too-many-arguments self.queries = [QueryObject(**query_obj) for query_obj in queries] self.force = force self.custom_cache_timeout = custom_cache_timeout - self.result_type = result_type or utils.ChartDataResultType.FULL - self.result_format = result_format or utils.ChartDataResultFormat.JSON + self.result_type = result_type or ChartDataResultType.FULL + self.result_format = result_format or ChartDataResultFormat.JSON self.cache_values = { "datasource": datasource, "queries": queries, @@ -111,7 +119,7 @@ def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]: # If the datetime format is unix, the parse will use the corresponding # parsing logic if not df.empty: - df = utils.normalize_dttm_col( + df = normalize_dttm_col( df=df, timestamp_format=timestamp_format, offset=self.datasource.offset, @@ -141,77 +149,24 @@ def df_metrics_to_num(df: pd.DataFrame, query_object: QueryObject) -> None: df[col] = df[col].infer_objects() def get_data(self, df: pd.DataFrame,) -> Union[str, List[Dict[str, Any]]]: - if self.result_format == utils.ChartDataResultFormat.CSV: + if self.result_format == ChartDataResultFormat.CSV: include_index = not isinstance(df.index, pd.RangeIndex) result = df.to_csv(index=include_index, **config["CSV_EXPORT"]) return result or "" return df.to_dict(orient="records") - def get_single_payload( - self, query_obj: QueryObject, force_cached: Optional[bool] = False, - ) -> Dict[str, Any]: - """Return results payload for a single quey""" - if self.result_type == utils.ChartDataResultType.QUERY: - return { - "query": self.datasource.get_query_str(query_obj.to_dict()), - "language": self.datasource.query_language, - } - - if self.result_type == utils.ChartDataResultType.SAMPLES: - 
row_limit = query_obj.row_limit or math.inf - query_obj = copy.copy(query_obj) - query_obj.is_timeseries = False - query_obj.orderby = [] - query_obj.groupby = [] - query_obj.metrics = [] - query_obj.post_processing = [] - query_obj.row_limit = min(row_limit, config["SAMPLES_ROW_LIMIT"]) - query_obj.row_offset = 0 - query_obj.columns = [o.column_name for o in self.datasource.columns] - - payload = self.get_df_payload(query_obj, force_cached=force_cached) - df = payload["df"] - status = payload["status"] - if status != utils.QueryStatus.FAILED: - payload["colnames"] = list(df.columns) - payload["coltypes"] = utils.extract_dataframe_dtypes(df) - payload["data"] = self.get_data(df) - del payload["df"] - - filters = query_obj.filter - filter_columns = cast(List[str], [flt.get("col") for flt in filters]) - columns = set(self.datasource.column_names) - applied_time_columns, rejected_time_columns = utils.get_time_filter_status( - self.datasource, query_obj.applied_time_extras - ) - payload["applied_filters"] = [ - {"column": col} for col in filter_columns if col in columns - ] + applied_time_columns - payload["rejected_filters"] = [ - {"reason": "not_in_datasource", "column": col} - for col in filter_columns - if col not in columns - ] + rejected_time_columns - - if ( - self.result_type == utils.ChartDataResultType.RESULTS - and status != utils.QueryStatus.FAILED - ): - return {"data": payload["data"]} - return payload - def get_payload( - self, - cache_query_context: Optional[bool] = False, - force_cached: Optional[bool] = False, + self, cache_query_context: Optional[bool] = False, force_cached: bool = False, ) -> Dict[str, Any]: """Returns the query results with both metadata and data""" # Get all the payloads from the QueryObjects query_results = [ - self.get_single_payload(query_object, force_cached=force_cached) - for query_object in self.queries + get_query_results( + query_obj.result_type or self.result_type, self, query_obj, force_cached + ) + for query_obj in 
self.queries ] return_value = {"queries": query_results} @@ -326,7 +281,7 @@ def get_viz_annotation_data( payload = viz_obj.get_payload() return payload["data"] except SupersetException as ex: - raise QueryObjectValidationError(utils.error_msg_from_exception(ex)) + raise QueryObjectValidationError(error_msg_from_exception(ex)) def get_annotation_data(self, query_obj: QueryObject) -> Dict[str, Any]: """ @@ -368,13 +323,13 @@ def get_df_payload( # pylint: disable=too-many-statements,too-many-locals df = cache_value["df"] query = cache_value["query"] annotation_data = cache_value.get("annotation_data", {}) - status = utils.QueryStatus.SUCCESS + status = QueryStatus.SUCCESS is_loaded = True stats_logger.incr("loaded_from_cache") except KeyError as ex: logger.exception(ex) logger.error( - "Error reading cache: %s", utils.error_msg_from_exception(ex) + "Error reading cache: %s", error_msg_from_exception(ex) ) logger.info("Serving from cache") @@ -390,7 +345,7 @@ def get_df_payload( # pylint: disable=too-many-statements,too-many-locals col for col in query_obj.columns + query_obj.groupby - + utils.get_column_names_from_metrics(query_obj.metrics) + + get_column_names_from_metrics(query_obj.metrics) if col not in self.datasource.column_names and col != DTTM_ALIAS ] if invalid_columns: @@ -407,22 +362,22 @@ def get_df_payload( # pylint: disable=too-many-statements,too-many-locals df = query_result["df"] annotation_data = self.get_annotation_data(query_obj) - if status != utils.QueryStatus.FAILED: + if status != QueryStatus.FAILED: stats_logger.incr("loaded_from_source") if not self.force: stats_logger.incr("loaded_from_source_without_force") is_loaded = True except QueryObjectValidationError as ex: error_message = str(ex) - status = utils.QueryStatus.FAILED + status = QueryStatus.FAILED except Exception as ex: # pylint: disable=broad-except logger.exception(ex) if not error_message: error_message = str(ex) - status = utils.QueryStatus.FAILED - stacktrace = 
utils.get_stacktrace() + status = QueryStatus.FAILED + stacktrace = get_stacktrace() - if is_loaded and cache_key and status != utils.QueryStatus.FAILED: + if is_loaded and cache_key and status != QueryStatus.FAILED: set_and_log_cache( cache_manager.data_cache, cache_key, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 2d4772cf5749c..69baacaa270dc 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -24,11 +24,15 @@ from flask_babel import gettext as _ from pandas import DataFrame -from superset import app +from superset import app, db +from superset.connectors.base.models import BaseDatasource +from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import QueryObjectValidationError from superset.typing import Metric from superset.utils import pandas_postprocessing from superset.utils.core import ( + ChartDataResultType, + DatasourceDict, DTTM_ALIAS, find_duplicates, get_metric_names, @@ -86,9 +90,14 @@ class QueryObject: columns: List[str] orderby: List[List[str]] post_processing: List[Dict[str, Any]] + datasource: Optional[BaseDatasource] + result_type: Optional[ChartDataResultType] + is_rowcount: bool def __init__( self, + datasource: Optional[DatasourceDict] = None, + result_type: Optional[ChartDataResultType] = None, annotation_layers: Optional[List[Dict[str, Any]]] = None, applied_time_extras: Optional[Dict[str, str]] = None, granularity: Optional[str] = None, @@ -107,8 +116,16 @@ def __init__( columns: Optional[List[str]] = None, orderby: Optional[List[List[str]]] = None, post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, + is_rowcount: bool = False, **kwargs: Any, ): + self.is_rowcount = is_rowcount + self.datasource = None + if datasource: + self.datasource = ConnectorRegistry.get_datasource( + str(datasource["type"]), int(datasource["id"]), db.session + ) + self.result_type = result_type annotation_layers = annotation_layers or [] 
metrics = metrics or [] columns = columns or [] @@ -156,7 +173,7 @@ def __init__( for metric in metrics ] - self.row_limit = row_limit or config["ROW_LIMIT"] + self.row_limit = config["ROW_LIMIT"] if row_limit is None else row_limit self.row_offset = row_offset or 0 self.filter = filters or [] self.timeseries_limit = timeseries_limit @@ -247,6 +264,7 @@ def to_dict(self) -> Dict[str, Any]: "groupby": self.groupby, "from_dttm": self.from_dttm, "to_dttm": self.to_dttm, + "is_rowcount": self.is_rowcount, "is_timeseries": self.is_timeseries, "metrics": self.metrics, "row_limit": self.row_limit, @@ -271,6 +289,10 @@ def cache_key(self, **extra: Any) -> str: """ cache_dict = self.to_dict() cache_dict.update(extra) + if self.datasource: + cache_dict["datasource"] = self.datasource.uid + if self.result_type: + cache_dict["result_type"] = self.result_type for k in ["from_dttm", "to_dttm"]: del cache_dict[k] diff --git a/superset/connectors/druid/models.py b/superset/connectors/druid/models.py index ee0e9f9117d46..078369a7f3744 100644 --- a/superset/connectors/druid/models.py +++ b/superset/connectors/druid/models.py @@ -1138,8 +1138,13 @@ def run_query( # druid phase: int = 2, client: Optional["PyDruid"] = None, order_desc: bool = True, + is_rowcount: bool = False, ) -> str: """Runs a query against Druid and returns a dataframe.""" + # is_rowcount is only supported on SQL connector + if is_rowcount: + raise SupersetException("is_rowcount is not supported on Druid connector") + # TODO refactor into using a TBD Query object client = client or self.cluster.get_pydruid_client() row_limit = row_limit or conf.get("ROW_LIMIT") diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a64a62869f701..e1e1b9bad8ad9 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -898,6 +898,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma orderby: Optional[List[Tuple[ColumnElement, bool]]] 
= None, extras: Optional[Dict[str, Any]] = None, order_desc: bool = True, + is_rowcount: bool = False, ) -> SqlaQuery: """Querying any sqla table from this common interface""" template_kwargs = { @@ -1253,10 +1254,21 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma result.df, dimensions, groupby_exprs_sans_timestamp ) qry = qry.where(top_groups) + if is_rowcount: + if not db_engine_spec.allows_subqueries: + raise QueryObjectValidationError( + _("Database does not support subqueries") + ) + label = "rowcount" + col = self.make_sqla_column_compatible(literal_column("COUNT(*)"), label) + qry = select([col]).select_from(qry.select_from(tbl).alias("rowcount_qry")) + labels_expected = [label] + else: + qry = qry.select_from(tbl) return SqlaQuery( extra_cache_keys=extra_cache_keys, labels_expected=labels_expected, - sqla_query=qry.select_from(tbl), + sqla_query=qry, prequeries=prequeries, ) diff --git a/superset/utils/cache.py b/superset/utils/cache.py index 66da4688b50ef..0abd76a796790 100644 --- a/superset/utils/cache.py +++ b/superset/utils/cache.py @@ -35,6 +35,7 @@ stats_logger: BaseStatsLogger = config["STATS_LOGGER"] logger = logging.getLogger(__name__) + # TODO: DRY up cache key code def json_dumps(obj: Any, sort_keys: bool = False) -> str: return json.dumps(obj, default=json_int_dttm_ser, sort_keys=sort_keys) diff --git a/superset/utils/core.py b/superset/utils/core.py index 4ff3146cdcd4d..1b5ef70ee6086 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -83,6 +83,7 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.sql.type_api import Variant from sqlalchemy.types import TEXT, TypeDecorator +from typing_extensions import TypedDict import _thread # pylint: disable=C0411 from superset.errors import ErrorLevel, SupersetErrorType @@ -100,7 +101,7 @@ pass if TYPE_CHECKING: - from superset.connectors.base.models import BaseDatasource + from superset.connectors.base.models import BaseColumn, 
BaseDatasource from superset.models.core import Database @@ -163,10 +164,17 @@ class ChartDataResultType(str, Enum): Chart data response type """ + COLUMNS = "columns" FULL = "full" QUERY = "query" RESULTS = "results" SAMPLES = "samples" + TIMEGRAINS = "timegrains" + + +class DatasourceDict(TypedDict): + type: str + id: int class ExtraFiltersTimeColumnType(str, Enum): @@ -1490,6 +1498,15 @@ def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]: return generic_types +def extract_column_dtype(col: "BaseColumn") -> GenericDataType: + if col.is_temporal: + return GenericDataType.TEMPORAL + if col.is_numeric: + return GenericDataType.NUMERIC + # TODO: add check for boolean data type when proper support is added + return GenericDataType.STRING + + def indexed( items: List[Any], key: Union[str, Callable[[Any], Any]] ) -> Dict[Any, List[Any]]: diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index a6030b675e723..ff402cc3ccc72 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -1102,7 +1102,7 @@ def test_chart_data_default_row_limit(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @mock.patch( - "superset.common.query_context.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, + "superset.common.query_actions.config", {**app.config, "SAMPLES_ROW_LIMIT": 5}, ) def test_chart_data_default_sample_limit(self): """ @@ -1698,3 +1698,44 @@ def quote_name(self, name: str): name ) return name + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_rowcount(self): + """ + Chart data API: Query total rows + """ + self.login(username="admin") + request_payload = get_query_context("birth_names") + request_payload["queries"][0]["is_rowcount"] = True + request_payload["queries"][0]["groupby"] = ["name"] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + result = response_payload["result"][0] + 
expected_row_count = self.get_expected_row_count("client_id_4") + self.assertEqual(result["data"][0]["rowcount"], expected_row_count) + + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_chart_data_timegrains(self): + """ + Chart data API: Query timegrains and columns + """ + self.login(username="admin") + request_payload = get_query_context("birth_names") + request_payload["queries"] = [ + {"result_type": utils.ChartDataResultType.TIMEGRAINS}, + {"result_type": utils.ChartDataResultType.COLUMNS}, + ] + rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") + response_payload = json.loads(rv.data.decode("utf-8")) + timegrain_result = response_payload["result"][0] + column_result = response_payload["result"][1] + assert list(timegrain_result["data"][0].keys()) == [ + "name", + "function", + "duration", + ] + assert list(column_result["data"][0].keys()) == [ + "column_name", + "verbose_name", + "dtype", + ] diff --git a/tests/charts/schema_tests.py b/tests/charts/schema_tests.py index b8d436384c371..dc191904e85d2 100644 --- a/tests/charts/schema_tests.py +++ b/tests/charts/schema_tests.py @@ -48,7 +48,7 @@ def test_query_context_limit_and_offset(self): self.assertEqual(query_object.row_offset, 200) # too low limit and offset - payload["queries"][0]["row_limit"] = 0 + payload["queries"][0]["row_limit"] = -1 payload["queries"][0]["row_offset"] = -1 with self.assertRaises(ValidationError) as context: _ = ChartDataQueryContextSchema().load(payload) diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py index ac57f13f94b3a..fd7acef43bca9 100644 --- a/tests/db_engine_specs/hive_tests.py +++ b/tests/db_engine_specs/hive_tests.py @@ -271,7 +271,6 @@ def test_get_create_table_stmt() -> None: location = "s3a://directory/table" from unittest import TestCase - TestCase.maxDiff = None assert HiveEngineSpec.get_create_table_stmt( table, schema_def, location, ",", 0, [""] ) == ( diff --git 
a/tests/query_context_tests.py b/tests/query_context_tests.py index 6045e7170b7a6..67fb8d59ed411 100644 --- a/tests/query_context_tests.py +++ b/tests/query_context_tests.py @@ -313,7 +313,6 @@ def test_query_object_unknown_fields(self): Ensure that query objects with unknown fields don't raise an Exception and have an identical cache key as one without the unknown field """ - self.maxDiff = None self.login(username="admin") payload = get_query_context("birth_names") query_context = ChartDataQueryContextSchema().load(payload)