perf: memoize db_engine_spec in database (#14638)
* perf: memoize db_engine_spec in sqla table classes

* remove extended cypress timeouts
villebro authored May 14, 2021
1 parent bf90885 commit 97c9e37
Showing 6 changed files with 36 additions and 43 deletions.
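
The cypress changes below simply drop per-command timeout overrides (presumably no longer needed once the engine-spec lookups are cached). The core of the change is on the Python side: BaseEngineSpec.get_column_spec and Database.get_db_engine_spec_for_backend pick up a @utils.memoized decorator, and the SQLA model classes gain short db_engine_spec properties so they stop re-resolving the engine spec through self.table.database on every call. A minimal sketch of the memoized-classmethod half, using functools.lru_cache as a stand-in for superset.utils.memoized (an assumption: the real decorator is Superset's own, but like lru_cache it caches results keyed by the call arguments):

# Sketch only: FakeEngineSpec is a hypothetical stand-in, not a Superset class.
from functools import lru_cache


class FakeEngineSpec:
    parse_count = 0  # counts how often the "expensive" parsing actually runs

    @classmethod
    @lru_cache(maxsize=None)  # stand-in for @utils.memoized, stacked in the same order as the diff
    def get_column_spec(cls, native_type: str) -> str:
        FakeEngineSpec.parse_count += 1
        return native_type.strip().upper()  # pretend this is costly type-mapping work


for _ in range(10_000):  # roughly: once per column per query before this change
    FakeEngineSpec.get_column_spec("VARCHAR(255)")

print(FakeEngineSpec.parse_count)  # 1 -- every call after the first hits the cache

One caveat that comes with argument-keyed caching of this kind: the cached values live for the rest of the process, so callers should treat the returned specs as read-only.
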
@@ -45,9 +45,7 @@ export interface ChartSpec {

 export function getChartGridComponent({ name, viz }: ChartSpec) {
   return cy
-    .get(`[data-test="chart-grid-component"][data-test-chart-name="${name}"]`, {
-      timeout: 30000,
-    })
+    .get(`[data-test="chart-grid-component"][data-test-chart-name="${name}"]`)
     .should('have.attr', 'data-test-viz-type', viz);
 }

@@ -44,9 +44,7 @@ describe('Dashboard load', () => {

   it('should load in edit/standalone mode', () => {
     cy.visit(`${WORLD_HEALTH_DASHBOARD}?edit=true&standalone=true`);
-    cy.get('[data-test="discard-changes-button"]', { timeout: 10000 }).should(
-      'be.visible',
-    );
+    cy.get('[data-test="discard-changes-button"]').should('be.visible');
     cy.get('#app-menu').should('not.exist');
   });

2 changes: 1 addition & 1 deletion superset-frontend/src/datasource/ChangeDatasourceModal.tsx
@@ -163,7 +163,7 @@ const ChangeDatasourceModal: FunctionComponent<ChangeDatasourceModalProps> = ({

   const handleChangeConfirm = () => {
     SupersetClient.get({
-      endpoint: `/datasource/get/${confirmedDataset?.type}/${confirmedDataset?.id}`,
+      endpoint: `/datasource/get/${confirmedDataset?.type}/${confirmedDataset?.id}/`,
     })
       .then(({ json }) => {
         onDatasourceSave(json);

64 changes: 30 additions & 34 deletions superset/connectors/sqla/models.py
@@ -22,7 +22,7 @@
 from contextlib import closing
 from dataclasses import dataclass, field # pylint: disable=wrong-import-order
 from datetime import datetime, timedelta
-from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Union
+from typing import Any, Dict, Hashable, List, NamedTuple, Optional, Tuple, Type, Union

 import pandas as pd
 import sqlalchemy as sa
@@ -57,7 +57,7 @@

 from superset import app, db, is_feature_enabled, security_manager
 from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
-from superset.db_engine_specs.base import TimestampExpression
+from superset.db_engine_specs.base import BaseEngineSpec, TimestampExpression
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import (
     QueryObjectValidationError,
@@ -196,30 +196,21 @@ def is_boolean(self) -> bool:
         """
         Check if the column has a boolean datatype.
         """
-        column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
-        if column_spec is None:
-            return False
-        return column_spec.generic_type == GenericDataType.BOOLEAN
+        return self.type_generic == GenericDataType.BOOLEAN

     @property
     def is_numeric(self) -> bool:
         """
         Check if the column has a numeric datatype.
         """
-        column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
-        if column_spec is None:
-            return False
-        return column_spec.generic_type == GenericDataType.NUMERIC
+        return self.type_generic == GenericDataType.NUMERIC

     @property
     def is_string(self) -> bool:
         """
         Check if the column has a string datatype.
         """
-        column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
-        if column_spec is None:
-            return False
-        return column_spec.generic_type == GenericDataType.STRING
+        return self.type_generic == GenericDataType.STRING

     @property
     def is_temporal(self) -> bool:
@@ -231,14 +222,20 @@ def is_temporal(self) -> bool:
         """
         if self.is_dttm is not None:
             return self.is_dttm
-        column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
-        if column_spec is None:
-            return False
-        return column_spec.is_dttm
+        return self.type_generic == GenericDataType.TEMPORAL
+
+    @property
+    def db_engine_spec(self) -> Type[BaseEngineSpec]:
+        return self.table.db_engine_spec
+
+    @property
+    def type_generic(self) -> Optional[utils.GenericDataType]:
+        column_spec = self.db_engine_spec.get_column_spec(self.type)
+        return column_spec.generic_type if column_spec else None

     def get_sqla_col(self, label: Optional[str] = None) -> Column:
         label = label or self.column_name
-        db_engine_spec = self.table.database.db_engine_spec
+        db_engine_spec = self.db_engine_spec
         column_spec = db_engine_spec.get_column_spec(self.type)
         type_ = column_spec.sqla_type if column_spec else None
         if self.expression:
@@ -290,7 +287,6 @@ def get_timestamp_expression(
         """
         label = label or utils.DTTM_ALIAS

-        db_ = self.table.database
         pdf = self.python_date_format
         is_epoch = pdf in ("epoch_s", "epoch_ms")
         if not self.expression and not time_grain and not is_epoch:
@@ -300,7 +296,7 @@
             col = literal_column(self.expression)
         else:
             col = column(self.column_name)
-        time_expr = db_.db_engine_spec.get_timestamp_expr(
+        time_expr = self.db_engine_spec.get_timestamp_expr(
             col, pdf, time_grain, self.type
         )
         return self.table.make_sqla_column_compatible(time_expr, label)
@@ -313,11 +309,7 @@ def dttm_sql_literal(
         ],
     ) -> str:
         """Convert datetime object to a SQL expression string"""
-        sql = (
-            self.table.database.db_engine_spec.convert_dttm(self.type, dttm)
-            if self.type
-            else None
-        )
+        sql = self.db_engine_spec.convert_dttm(self.type, dttm) if self.type else None

         if sql:
             return sql
@@ -527,6 +519,10 @@ class SqlaTable( # pylint: disable=too-many-public-methods,too-many-instance-at
     def __repr__(self) -> str:
         return self.name

+    @property
+    def db_engine_spec(self) -> Type[BaseEngineSpec]:
+        return self.database.db_engine_spec
+
     @property
     def changed_by_name(self) -> str:
         if not self.changed_by:
@@ -636,7 +632,7 @@ def sql_url(self) -> str:
         return self.database.sql_url + "?table_name=" + str(self.table_name)

     def external_metadata(self) -> List[Dict[str, str]]:
-        db_engine_spec = self.database.db_engine_spec
+        db_engine_spec = self.db_engine_spec
         if self.sql:
             engine = self.database.get_sqla_engine(schema=self.schema)
             sql = self.get_template_processor().process_template(self.sql)
@@ -815,9 +811,9 @@ def get_from_clause(

         from_sql = self.get_rendered_sql(template_processor)
         parsed_query = ParsedQuery(from_sql)
-        db_engine_spec = self.database.db_engine_spec
         if not (
-            parsed_query.is_unknown() or db_engine_spec.is_readonly_query(parsed_query)
+            parsed_query.is_unknown()
+            or self.db_engine_spec.is_readonly_query(parsed_query)
         ):
             raise QueryObjectValidationError(
                 _("Virtual dataset query must be read-only")
@@ -889,7 +885,7 @@ def make_sqla_column_compatible(
         :return: either a sql alchemy column or label instance if supported by engine
         """
         label_expected = label or sqla_col.name
-        db_engine_spec = self.database.db_engine_spec
+        db_engine_spec = self.db_engine_spec
         # add quotes to tables
         if db_engine_spec.allows_alias_in_select:
             label = db_engine_spec.make_label_compatible(label_expected)
@@ -909,7 +905,7 @@ def make_orderby_compatible(
         the same as a source column. In this case, we update the SELECT alias to
         another name to avoid the conflict.
         """
-        if self.database.db_engine_spec.allows_alias_to_source_column:
+        if self.db_engine_spec.allows_alias_to_source_column:
             return

         def is_alias_used_in_orderby(col: ColumnElement) -> bool:
@@ -990,7 +986,7 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
         extra_cache_keys: List[Any] = []
         template_kwargs["extra_cache_keys"] = extra_cache_keys
         template_processor = self.get_template_processor(**template_kwargs)
-        db_engine_spec = self.database.db_engine_spec
+        db_engine_spec = self.db_engine_spec
         prequeries: List[str] = []
         orderby = orderby or []
         extras = extras or {}
@@ -1469,7 +1465,7 @@ def assign_column_label(df: pd.DataFrame) -> Optional[pd.DataFrame]:
             logger.warning(
                 "Query %s on schema %s failed", sql, self.schema, exc_info=True
             )
-            db_engine_spec = self.database.db_engine_spec
+            db_engine_spec = self.db_engine_spec
             errors = [
                 dataclasses.asdict(error) for error in db_engine_spec.extract_errors(ex)
             ]
@@ -1497,7 +1493,7 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult:
         new_columns = self.external_metadata()
         metrics = []
         any_date_col = None
-        db_engine_spec = self.database.db_engine_spec
+        db_engine_spec = self.db_engine_spec
         old_columns = db.session.query(TableColumn).filter(TableColumn.table == self)

         old_columns_by_name: Dict[str, TableColumn] = {

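The shape of the models.py refactor above (the three is_* type checks funnelling through a single type_generic lookup, and columns reaching the engine spec via their table rather than via table.database directly) is roughly the following. These are hypothetical stand-in classes, and get_column_spec is simplified to return the generic type directly instead of a full column-spec object:

from enum import Enum
from typing import Optional


class GenericDataType(Enum):
    NUMERIC = 0
    STRING = 1
    TEMPORAL = 2
    BOOLEAN = 3


class FakeEngineSpec:
    _TYPES = {"BOOLEAN": GenericDataType.BOOLEAN, "BIGINT": GenericDataType.NUMERIC}

    @classmethod
    def get_column_spec(cls, native_type: str) -> Optional[GenericDataType]:
        return cls._TYPES.get(native_type)


class FakeDatabase:
    db_engine_spec = FakeEngineSpec


class FakeTable:
    def __init__(self, database: FakeDatabase) -> None:
        self.database = database

    @property
    def db_engine_spec(self):
        return self.database.db_engine_spec  # table -> database


class FakeColumn:
    def __init__(self, table: FakeTable, type_: str) -> None:
        self.table = table
        self.type = type_

    @property
    def db_engine_spec(self):
        return self.table.db_engine_spec  # column -> table -> database

    @property
    def type_generic(self) -> Optional[GenericDataType]:
        # one lookup shared by every is_* check below
        return self.db_engine_spec.get_column_spec(self.type)

    @property
    def is_boolean(self) -> bool:
        return self.type_generic == GenericDataType.BOOLEAN

    @property
    def is_numeric(self) -> bool:
        return self.type_generic == GenericDataType.NUMERIC


col = FakeColumn(FakeTable(FakeDatabase()), "BOOLEAN")
print(col.is_boolean, col.is_numeric)  # True False
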
1 change: 1 addition & 0 deletions superset/db_engine_specs/base.py
@@ -1255,6 +1255,7 @@ def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
         )

     @classmethod
+    @utils.memoized
     def get_column_spec(
         cls,
         native_type: Optional[str],

4 changes: 2 additions & 2 deletions superset/models/core.py
@@ -568,10 +568,10 @@ def get_all_schema_names(

     @property
     def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]:
-        engines = db_engine_specs.get_engine_specs()
-        return engines.get(self.backend, db_engine_specs.BaseEngineSpec)
+        return self.get_db_engine_spec_for_backend(self.backend)

     @classmethod
+    @utils.memoized
     def get_db_engine_spec_for_backend(
         cls, backend: str
     ) -> Type[db_engine_specs.BaseEngineSpec]:

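The core.py change applies the same idea one level up: Database.db_engine_spec now defers to a classmethod that is memoized per backend name, so the engine-spec registry lookup happens once per backend per process instead of on every property access. A rough sketch of the effect, again with functools.lru_cache standing in for utils.memoized and a made-up load_engine_specs() in place of db_engine_specs.get_engine_specs():

from functools import lru_cache
from typing import Dict


def load_engine_specs() -> Dict[str, str]:
    # hypothetical stand-in for the registry lookup; imagine this scanning
    # modules, which is the work the memoization avoids repeating
    print("loading engine specs...")
    return {"postgresql": "PostgresEngineSpec", "mysql": "MySQLEngineSpec"}


class FakeDatabase:
    def __init__(self, backend: str) -> None:
        self.backend = backend

    @property
    def db_engine_spec(self) -> str:
        return self.get_db_engine_spec_for_backend(self.backend)

    @classmethod
    @lru_cache(maxsize=None)  # stand-in for @utils.memoized
    def get_db_engine_spec_for_backend(cls, backend: str) -> str:
        return load_engine_specs().get(backend, "BaseEngineSpec")


databases = [FakeDatabase("postgresql") for _ in range(100)]
specs = {db.db_engine_spec for db in databases}
print(specs)  # "loading engine specs..." was printed once; all 100 objects share the cached value

Behaviour is unchanged by the memoization itself (the old property returned the same spec class); what goes away is rebuilding the backend-to-spec mapping on every access.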
