diff --git a/superset/columns/__init__.py b/superset/columns/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/columns/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/columns/models.py b/superset/columns/models.py new file mode 100644 index 0000000000000..039f73ff57579 --- /dev/null +++ b/superset/columns/models.py @@ -0,0 +1,99 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Column model. + +This model was introduced in SIP-68 (https://github.com/apache/superset/issues/14909), +and represents a "column" in a table or dataset. In addition to a column, new models for +tables, metrics, and datasets were also introduced. + +These models are not fully implemented, and shouldn't be used yet. +""" + +import sqlalchemy as sa +from flask_appbuilder import Model + +from superset.models.helpers import ( + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, +) + + +class Column( + Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, +): + """ + A "column". + + The definition of column here is overloaded: it can represent a physical column in a + database relation (table or view); a computed/derived column on a dataset; or an + aggregation expression representing a metric. + """ + + __tablename__ = "sl_columns" + + id = sa.Column(sa.Integer, primary_key=True) + + # We use ``sa.Text`` for these attributes because (1) in modern databases the + # performance is the same as ``VARCHAR``[1] and (2) because some table names can be + # **really** long (eg, Google Sheets URLs). + # + # [1] https://www.postgresql.org/docs/9.1/datatype-character.html + name = sa.Column(sa.Text) + type = sa.Column(sa.Text) + + # Columns are defined by expressions. For tables, these are the actual columns names, + # and should match the ``name`` attribute. For datasets, these can be any valid SQL + # expression. If the SQL expression is an aggregation the column is a metric, + # otherwise it's a computed column. + expression = sa.Column(sa.Text) + + # Does the expression point directly to a physical column? + is_physical = sa.Column(sa.Boolean, default=True) + + # Additional metadata describing the column. + description = sa.Column(sa.Text) + warning_text = sa.Column(sa.Text) + unit = sa.Column(sa.Text) + + # Is this a time column? Useful for plotting time series. + is_temporal = sa.Column(sa.Boolean, default=False) + + # Is this a spatial column? This could be leveraged in the future for spatial + # visualizations. + is_spatial = sa.Column(sa.Boolean, default=False) + + # Is this column a partition? Useful for scheduling queries and previewing the latest + # data. + is_partition = sa.Column(sa.Boolean, default=False) + + # Is this column an aggregation (metric)? + is_aggregation = sa.Column(sa.Boolean, default=False) + + # Assuming the column is an aggregation, is it additive? Useful for determining which + # aggregations can be done on the metric. Eg, ``COUNT(DISTINCT user_id)`` is not + # additive, so it shouldn't be used in a ``SUM``. + is_additive = sa.Column(sa.Boolean, default=False) + + # Is an increase desired? Useful for displaying the results of A/B tests, or setting + # up alerts. Eg, this is true for "revenue", but false for "latency". + is_increase_desired = sa.Column(sa.Boolean, default=True) + + # Column is managed externally and should be read-only inside Superset + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) diff --git a/superset/columns/schemas.py b/superset/columns/schemas.py new file mode 100644 index 0000000000000..5368bfbcca7ec --- /dev/null +++ b/superset/columns/schemas.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Schema for the column model. + +This model was introduced in SIP-68 (https://github.com/apache/superset/issues/14909), +and represents a "column" in a table or dataset. In addition to a column, new models for +tables, metrics, and datasets were also introduced. + +These models are not fully implemented, and shouldn't be used yet. +""" + +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema + +from superset.columns.models import Column + + +class ColumnSchema(SQLAlchemyAutoSchema): + """ + Schema for the ``Column`` model. + """ + + class Meta: # pylint: disable=too-few-public-methods + model = Column + load_instance = True + include_relationships = True diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index ca1d4bc57a022..81a53e74dae3b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines, redefined-outer-name import dataclasses import json import logging @@ -53,6 +53,7 @@ desc, Enum, ForeignKey, + inspect, Integer, or_, select, @@ -71,12 +72,14 @@ from sqlalchemy.sql.selectable import Alias, TableClause from superset import app, db, is_feature_enabled, security_manager +from superset.columns.models import Column as NewColumn from superset.common.db_query_status import QueryStatus from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric from superset.connectors.sqla.utils import ( get_physical_table_metadata, get_virtual_table_metadata, ) +from superset.datasets.models import Dataset as NewDataset from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression from superset.exceptions import QueryObjectValidationError from superset.jinja_context import ( @@ -86,8 +89,14 @@ ) from superset.models.annotations import Annotation from superset.models.core import Database -from superset.models.helpers import AuditMixinNullable, CertificationMixin, QueryResult +from superset.models.helpers import ( + AuditMixinNullable, + CertificationMixin, + clone_model, + QueryResult, +) from superset.sql_parse import ParsedQuery +from superset.tables.models import Table as NewTable from superset.typing import AdhocColumn, AdhocMetric, Metric, OrderBy, QueryObjectDict from superset.utils import core as utils from superset.utils.core import ( @@ -104,6 +113,13 @@ VIRTUAL_TABLE_ALIAS = "virtual_table" +# a non-exhaustive set of additive metrics +ADDITIVE_METRIC_TYPES = { + "count", + "sum", + "doubleSum", +} + class SqlaQuery(NamedTuple): applied_template_filters: List[str] @@ -1830,23 +1846,474 @@ def before_update( raise Exception(get_dataset_exist_error_msg(target.full_name)) @staticmethod - def update_table( - _mapper: Mapper, _connection: Connection, obj: Union[SqlMetric, TableColumn] + def update_table( # pylint: disable=unused-argument + mapper: Mapper, connection: Connection, target: Union[SqlMetric, TableColumn] ) -> None: """ Forces an update to the table's changed_on value when a metric or column on the table is updated. This busts the cache key for all charts that use the table. - :param _mapper: Unused. - :param _connection: Unused. - :param obj: The metric or column that was updated. + :param mapper: Unused. + :param connection: Unused. + :param target: The metric or column that was updated. + """ + inspector = inspect(target) + session = inspector.session + + # get DB-specific conditional quoter for expressions that point to columns or + # table names + database = ( + target.table.database + or session.query(Database).filter_by(id=target.database_id).one() + ) + engine = database.get_sqla_engine(schema=target.table.schema) + conditional_quote = engine.dialect.identifier_preparer.quote + + session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id)) + + # update ``Column`` model as well + dataset = ( + session.query(NewDataset).filter_by(sqlatable_id=target.table.id).one() + ) + + if isinstance(target, TableColumn): + columns = [ + column + for column in dataset.columns + if column.name == target.column_name + ] + if not columns: + return + + column = columns[0] + extra_json = json.loads(target.extra or "{}") + for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: + value = getattr(target, attr) + if value: + extra_json[attr] = value + + column.name = target.column_name + column.type = target.type or "Unknown" + column.expression = target.expression or conditional_quote( + target.column_name + ) + column.description = target.description + column.is_temporal = target.is_dttm + column.is_physical = target.expression is None + column.extra_json = json.dumps(extra_json) if extra_json else None + + else: # SqlMetric + columns = [ + column + for column in dataset.columns + if column.name == target.metric_name + ] + if not columns: + return + + column = columns[0] + extra_json = json.loads(target.extra or "{}") + for attr in {"verbose_name", "metric_type", "d3format"}: + value = getattr(target, attr) + if value: + extra_json[attr] = value + + is_additive = ( + target.metric_type + and target.metric_type.lower() in ADDITIVE_METRIC_TYPES + ) + + column.name = target.metric_name + column.expression = target.expression + column.warning_text = target.warning_text + column.description = target.description + column.is_additive = is_additive + column.extra_json = json.dumps(extra_json) if extra_json else None + + @staticmethod + def after_insert( # pylint: disable=too-many-locals + mapper: Mapper, connection: Connection, target: "SqlaTable", + ) -> None: + """ + Shadow write the dataset to new models. + + The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` + and ``Dataset``. In the first phase of the migration the new models are populated + whenever ``SqlaTable`` is modified (created, updated, or deleted). + + In the second phase of the migration reads will be done from the new models. + Finally, in the third phase of the migration the old models will be removed. + + For more context: https://github.com/apache/superset/issues/14909 + """ + # set permissions + security_manager.set_perm(mapper, connection, target) + + session = inspect(target).session + + # get DB-specific conditional quoter for expressions that point to columns or + # table names + database = ( + target.database + or session.query(Database).filter_by(id=target.database_id).one() + ) + engine = database.get_sqla_engine(schema=target.schema) + conditional_quote = engine.dialect.identifier_preparer.quote + + # create columns + columns = [] + for column in target.columns: + # ``is_active`` might be ``None`` at this point, but it defaults to ``True``. + if column.is_active is False: + continue + + extra_json = json.loads(column.extra or "{}") + for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: + value = getattr(column, attr) + if value: + extra_json[attr] = value + + columns.append( + NewColumn( + name=column.column_name, + type=column.type or "Unknown", + expression=column.expression + or conditional_quote(column.column_name), + description=column.description, + is_temporal=column.is_dttm, + is_aggregation=False, + is_physical=column.expression is None, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ), + ) + + # create metrics + for metric in target.metrics: + extra_json = json.loads(metric.extra or "{}") + for attr in {"verbose_name", "metric_type", "d3format"}: + value = getattr(metric, attr) + if value: + extra_json[attr] = value + + is_additive = ( + metric.metric_type + and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES + ) + + columns.append( + NewColumn( + name=metric.metric_name, + type="Unknown", # figuring this out would require a type inferrer + expression=metric.expression, + warning_text=metric.warning_text, + description=metric.description, + is_aggregation=True, + is_additive=is_additive, + is_physical=False, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ), + ) + + # physical dataset + tables = [] + if target.sql is None: + physical_columns = [column for column in columns if column.is_physical] + + # create table + table = NewTable( + name=target.table_name, + schema=target.schema, + catalog=None, # currently not supported + database_id=target.database_id, + columns=physical_columns, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + tables.append(table) + + # virtual dataset + else: + # mark all columns as virtual (not physical) + for column in columns: + column.is_physical = False + + # find referenced tables + parsed = ParsedQuery(target.sql) + referenced_tables = parsed.tables + + # predicate for finding the referenced tables + predicate = or_( + *[ + and_( + NewTable.schema == (table.schema or target.schema), + NewTable.name == table.table, + ) + for table in referenced_tables + ] + ) + tables = session.query(NewTable).filter(predicate).all() + + # create the new dataset + dataset = NewDataset( + sqlatable_id=target.id, + name=target.table_name, + expression=target.sql or conditional_quote(target.table_name), + tables=tables, + columns=columns, + is_physical=target.sql is None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + session.add(dataset) + + @staticmethod + def after_delete( # pylint: disable=unused-argument + mapper: Mapper, connection: Connection, target: "SqlaTable", + ) -> None: + """ + Shadow write the dataset to new models. + + The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` + and ``Dataset``. In the first phase of the migration the new models are populated + whenever ``SqlaTable`` is modified (created, updated, or deleted). + + In the second phase of the migration reads will be done from the new models. + Finally, in the third phase of the migration the old models will be removed. + + For more context: https://github.com/apache/superset/issues/14909 + """ + session = inspect(target).session + dataset = ( + session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none() + ) + if dataset: + session.delete(dataset) + + @staticmethod + def after_update( # pylint: disable=too-many-branches, too-many-locals, too-many-statements + mapper: Mapper, connection: Connection, target: "SqlaTable", + ) -> None: + """ + Shadow write the dataset to new models. + + The ``SqlaTable`` model is currently being migrated to two new models, ``Table`` + and ``Dataset``. In the first phase of the migration the new models are populated + whenever ``SqlaTable`` is modified (created, updated, or deleted). + + In the second phase of the migration reads will be done from the new models. + Finally, in the third phase of the migration the old models will be removed. + + For more context: https://github.com/apache/superset/issues/14909 """ - db.session.execute(update(SqlaTable).where(SqlaTable.id == obj.table.id)) + inspector = inspect(target) + session = inspector.session + + # double-check that ``UPDATE``s are actually pending (this method is called even + # for instances that have no net changes to their column-based attributes) + if not session.is_modified(target, include_collections=True): + return + + # set permissions + security_manager.set_perm(mapper, connection, target) + + dataset = ( + session.query(NewDataset).filter_by(sqlatable_id=target.id).one_or_none() + ) + if not dataset: + return + + # get DB-specific conditional quoter for expressions that point to columns or + # table names + database = ( + target.database + or session.query(Database).filter_by(id=target.database_id).one() + ) + engine = database.get_sqla_engine(schema=target.schema) + conditional_quote = engine.dialect.identifier_preparer.quote + + # update columns + if inspector.attrs.columns.history.has_changes(): + # handle deleted columns + if inspector.attrs.columns.history.deleted: + column_names = { + column.column_name + for column in inspector.attrs.columns.history.deleted + } + dataset.columns = [ + column + for column in dataset.columns + if column.name not in column_names + ] + + # handle inserted columns + for column in inspector.attrs.columns.history.added: + # ``is_active`` might be ``None``, but it defaults to ``True``. + if column.is_active is False: + continue + + extra_json = json.loads(column.extra or "{}") + for attr in { + "groupby", + "filterable", + "verbose_name", + "python_date_format", + }: + value = getattr(column, attr) + if value: + extra_json[attr] = value + + dataset.columns.append( + NewColumn( + name=column.column_name, + type=column.type or "Unknown", + expression=column.expression + or conditional_quote(column.column_name), + description=column.description, + is_temporal=column.is_dttm, + is_aggregation=False, + is_physical=column.expression is None, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + ) + + # update metrics + if inspector.attrs.metrics.history.has_changes(): + # handle deleted metrics + if inspector.attrs.metrics.history.deleted: + column_names = { + metric.metric_name + for metric in inspector.attrs.metrics.history.deleted + } + dataset.columns = [ + column + for column in dataset.columns + if column.name not in column_names + ] + + # handle inserted metrics + for metric in inspector.attrs.metrics.history.added: + extra_json = json.loads(metric.extra or "{}") + for attr in {"verbose_name", "metric_type", "d3format"}: + value = getattr(metric, attr) + if value: + extra_json[attr] = value + + is_additive = ( + metric.metric_type + and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES + ) + + dataset.columns.append( + NewColumn( + name=metric.metric_name, + type="Unknown", + expression=metric.expression, + warning_text=metric.warning_text, + description=metric.description, + is_aggregation=True, + is_additive=is_additive, + is_physical=False, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + ) + + # physical dataset + if target.sql is None: + physical_columns = [ + column for column in dataset.columns if column.is_physical + ] + + # if the table name changed we should create a new table instance, instead + # of reusing the original one + if ( + inspector.attrs.table_name.history.has_changes() + or inspector.attrs.schema.history.has_changes() + or inspector.attrs.database_id.history.has_changes() + ): + # does the dataset point to an existing table? + table = ( + session.query(NewTable) + .filter_by( + database_id=target.database_id, + schema=target.schema, + name=target.table_name, + ) + .first() + ) + if not table: + # create new columns + physical_columns = [ + clone_model(column, ignore=["uuid"]) + for column in physical_columns + ] + + # create new table + table = NewTable( + name=target.table_name, + schema=target.schema, + catalog=None, + database_id=target.database_id, + columns=physical_columns, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + dataset.tables = [table] + elif dataset.tables: + table = dataset.tables[0] + table.columns = physical_columns + + # virtual dataset + else: + # mark all columns as virtual (not physical) + for column in dataset.columns: + column.is_physical = False + + # update referenced tables if SQL changed + if inspector.attrs.sql.history.has_changes(): + parsed = ParsedQuery(target.sql) + referenced_tables = parsed.tables + + predicate = or_( + *[ + and_( + NewTable.schema == (table.schema or target.schema), + NewTable.name == table.table, + ) + for table in referenced_tables + ] + ) + dataset.tables = session.query(NewTable).filter(predicate).all() + + # update other attributes + dataset.name = target.table_name + dataset.expression = target.sql or conditional_quote(target.table_name) + dataset.is_physical = target.sql is None -sa.event.listen(SqlaTable, "after_insert", security_manager.set_perm) -sa.event.listen(SqlaTable, "after_update", security_manager.set_perm) sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update) +sa.event.listen(SqlaTable, "after_insert", SqlaTable.after_insert) +sa.event.listen(SqlaTable, "after_delete", SqlaTable.after_delete) +sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update) sa.event.listen(SqlMetric, "after_update", SqlaTable.update_table) sa.event.listen(TableColumn, "after_update", SqlaTable.update_table) diff --git a/superset/datasets/models.py b/superset/datasets/models.py new file mode 100644 index 0000000000000..56a6fbf4000e3 --- /dev/null +++ b/superset/datasets/models.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Dataset model. + +This model was introduced in SIP-68 (https://github.com/apache/superset/issues/14909), +and represents a "dataset" -- either a physical table or a virtual. In addition to a +dataset, new models for columns, metrics, and tables were also introduced. + +These models are not fully implemented, and shouldn't be used yet. +""" + +from typing import List + +import sqlalchemy as sa +from flask_appbuilder import Model +from sqlalchemy.orm import relationship + +from superset.columns.models import Column +from superset.models.helpers import ( + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, +) +from superset.tables.models import Table + +column_association_table = sa.Table( + "sl_dataset_columns", + Model.metadata, # pylint: disable=no-member + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), + sa.Column("column_id", sa.ForeignKey("sl_columns.id")), +) + +table_association_table = sa.Table( + "sl_dataset_tables", + Model.metadata, # pylint: disable=no-member + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), + sa.Column("table_id", sa.ForeignKey("sl_tables.id")), +) + + +class Dataset(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): + """ + A table/view in a database. + """ + + __tablename__ = "sl_datasets" + + id = sa.Column(sa.Integer, primary_key=True) + + # A temporary column, used for shadow writing to the new model. Once the ``SqlaTable`` + # model has been deleted this column can be removed. + sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True) + + # We use ``sa.Text`` for these attributes because (1) in modern databases the + # performance is the same as ``VARCHAR``[1] and (2) because some table names can be + # **really** long (eg, Google Sheets URLs). + # + # [1] https://www.postgresql.org/docs/9.1/datatype-character.html + name = sa.Column(sa.Text) + + expression = sa.Column(sa.Text) + + # n:n relationship + tables: List[Table] = relationship("Table", secondary=table_association_table) + + # The relationship between datasets and columns is 1:n, but we use a many-to-many + # association to differentiate between the relationship between tables and columns. + columns: List[Column] = relationship( + "Column", secondary=column_association_table, cascade="all, delete" + ) + + # Does the dataset point directly to a ``Table``? + is_physical = sa.Column(sa.Boolean, default=False) + + # Column is managed externally and should be read-only inside Superset + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) diff --git a/superset/datasets/schemas.py b/superset/datasets/schemas.py index 06b13a0e121ef..775798d274fa7 100644 --- a/superset/datasets/schemas.py +++ b/superset/datasets/schemas.py @@ -21,6 +21,9 @@ from flask_babel import lazy_gettext as _ from marshmallow import fields, pre_load, Schema, ValidationError from marshmallow.validate import Length +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema + +from superset.datasets.models import Dataset get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} get_export_ids_schema = {"type": "array", "items": {"type": "integer"}} @@ -209,3 +212,14 @@ def fix_extra(self, data: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]: version = fields.String(required=True) database_uuid = fields.UUID(required=True) data = fields.URL() + + +class DatasetSchema(SQLAlchemyAutoSchema): + """ + Schema for the ``Dataset`` model. + """ + + class Meta: # pylint: disable=too-few-public-methods + model = Dataset + load_instance = True + include_relationships = True diff --git a/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py new file mode 100644 index 0000000000000..e6d4537272ff3 --- /dev/null +++ b/superset/migrations/versions/b8d3a24d9131_new_dataset_models.py @@ -0,0 +1,598 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=too-few-public-methods +"""New dataset models + +Revision ID: b8d3a24d9131 +Revises: 5afbb1a5849b +Create Date: 2021-11-11 16:41:53.266965 + +""" + +import json +from typing import Any, Dict, List, Optional, Type +from uuid import uuid4 + +import sqlalchemy as sa +from alembic import op +from sqlalchemy import and_, inspect, or_ +from sqlalchemy.engine import create_engine, Engine +from sqlalchemy.engine.url import make_url, URL +from sqlalchemy.exc import ArgumentError +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import backref, relationship, Session +from sqlalchemy.schema import UniqueConstraint +from sqlalchemy_utils import UUIDType + +from superset import app, db, db_engine_specs +from superset.connectors.sqla.models import ADDITIVE_METRIC_TYPES +from superset.extensions import encrypted_field_factory, security_manager +from superset.sql_parse import ParsedQuery +from superset.utils.memoized import memoized + +# revision identifiers, used by Alembic. +revision = "b8d3a24d9131" +down_revision = "5afbb1a5849b" + +Base = declarative_base() +custom_password_store = app.config["SQLALCHEMY_CUSTOM_PASSWORD_STORE"] +DB_CONNECTION_MUTATOR = app.config["DB_CONNECTION_MUTATOR"] + + +class Database(Base): + + __tablename__ = "dbs" + __table_args__ = (UniqueConstraint("database_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + database_name = sa.Column(sa.String(250), unique=True, nullable=False) + sqlalchemy_uri = sa.Column(sa.String(1024), nullable=False) + password = sa.Column(encrypted_field_factory.create(sa.String(1024))) + impersonate_user = sa.Column(sa.Boolean, default=False) + encrypted_extra = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) + extra = sa.Column( + sa.Text, + default=json.dumps( + dict( + metadata_params={}, + engine_params={}, + metadata_cache_timeout={}, + schemas_allowed_for_file_upload=[], + ) + ), + ) + server_cert = sa.Column(encrypted_field_factory.create(sa.Text), nullable=True) + + @property + def sqlalchemy_uri_decrypted(self) -> str: + try: + url = make_url(self.sqlalchemy_uri) + except (ArgumentError, ValueError): + return "dialect://invalid_uri" + if custom_password_store: + url.password = custom_password_store(url) + else: + url.password = self.password + return str(url) + + @property + def backend(self) -> str: + sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) + return sqlalchemy_url.get_backend_name() # pylint: disable=no-member + + @classmethod + @memoized + def get_db_engine_spec_for_backend( + cls, backend: str + ) -> Type[db_engine_specs.BaseEngineSpec]: + engines = db_engine_specs.get_engine_specs() + return engines.get(backend, db_engine_specs.BaseEngineSpec) + + @property + def db_engine_spec(self) -> Type[db_engine_specs.BaseEngineSpec]: + return self.get_db_engine_spec_for_backend(self.backend) + + def get_extra(self) -> Dict[str, Any]: + return self.db_engine_spec.get_extra_params(self) + + def get_effective_user( + self, object_url: URL, user_name: Optional[str] = None, + ) -> Optional[str]: + effective_username = None + if self.impersonate_user: + effective_username = object_url.username + if user_name: + effective_username = user_name + + return effective_username + + def get_encrypted_extra(self) -> Dict[str, Any]: + return json.loads(self.encrypted_extra) if self.encrypted_extra else {} + + @memoized(watch=("impersonate_user", "sqlalchemy_uri_decrypted", "extra")) + def get_sqla_engine(self, schema: Optional[str] = None) -> Engine: + extra = self.get_extra() + sqlalchemy_url = make_url(self.sqlalchemy_uri_decrypted) + self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) + effective_username = self.get_effective_user(sqlalchemy_url, "admin") + # If using MySQL or Presto for example, will set url.username + self.db_engine_spec.modify_url_for_impersonation( + sqlalchemy_url, self.impersonate_user, effective_username + ) + + params = extra.get("engine_params", {}) + connect_args = params.get("connect_args", {}) + if self.impersonate_user: + self.db_engine_spec.update_impersonation_config( + connect_args, str(sqlalchemy_url), effective_username + ) + + if connect_args: + params["connect_args"] = connect_args + + params.update(self.get_encrypted_extra()) + + if DB_CONNECTION_MUTATOR: + sqlalchemy_url, params = DB_CONNECTION_MUTATOR( + sqlalchemy_url, + params, + effective_username, + security_manager, + "migration", + ) + + return create_engine(sqlalchemy_url, **params) + + +class TableColumn(Base): + + __tablename__ = "table_columns" + __table_args__ = (UniqueConstraint("table_id", "column_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) + is_active = sa.Column(sa.Boolean, default=True) + extra = sa.Column(sa.Text) + column_name = sa.Column(sa.String(255), nullable=False) + type = sa.Column(sa.String(32)) + expression = sa.Column(sa.Text) + description = sa.Column(sa.Text) + is_dttm = sa.Column(sa.Boolean, default=False) + filterable = sa.Column(sa.Boolean, default=True) + groupby = sa.Column(sa.Boolean, default=True) + verbose_name = sa.Column(sa.String(1024)) + python_date_format = sa.Column(sa.String(255)) + + +class SqlMetric(Base): + + __tablename__ = "sql_metrics" + __table_args__ = (UniqueConstraint("table_id", "metric_name"),) + + id = sa.Column(sa.Integer, primary_key=True) + table_id = sa.Column(sa.Integer, sa.ForeignKey("tables.id")) + extra = sa.Column(sa.Text) + metric_type = sa.Column(sa.String(32)) + metric_name = sa.Column(sa.String(255), nullable=False) + expression = sa.Column(sa.Text, nullable=False) + warning_text = sa.Column(sa.Text) + description = sa.Column(sa.Text) + d3format = sa.Column(sa.String(128)) + verbose_name = sa.Column(sa.String(1024)) + + +class SqlaTable(Base): + + __tablename__ = "tables" + __table_args__ = (UniqueConstraint("database_id", "schema", "table_name"),) + + def fetch_columns_and_metrics(self, session: Session) -> None: + self.columns = session.query(TableColumn).filter( + TableColumn.table_id == self.id + ) + self.metrics = session.query(SqlMetric).filter(TableColumn.table_id == self.id) + + id = sa.Column(sa.Integer, primary_key=True) + columns: List[TableColumn] = [] + column_class = TableColumn + metrics: List[SqlMetric] = [] + metric_class = SqlMetric + + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + backref=backref("tables", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + schema = sa.Column(sa.String(255)) + table_name = sa.Column(sa.String(250), nullable=False) + sql = sa.Column(sa.Text) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) + + +table_column_association_table = sa.Table( + "sl_table_columns", + Base.metadata, + sa.Column("table_id", sa.ForeignKey("sl_tables.id")), + sa.Column("column_id", sa.ForeignKey("sl_columns.id")), +) + +dataset_column_association_table = sa.Table( + "sl_dataset_columns", + Base.metadata, + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), + sa.Column("column_id", sa.ForeignKey("sl_columns.id")), +) + +dataset_table_association_table = sa.Table( + "sl_dataset_tables", + Base.metadata, + sa.Column("dataset_id", sa.ForeignKey("sl_datasets.id")), + sa.Column("table_id", sa.ForeignKey("sl_tables.id")), +) + + +class NewColumn(Base): + + __tablename__ = "sl_columns" + + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Text) + type = sa.Column(sa.Text) + expression = sa.Column(sa.Text) + is_physical = sa.Column(sa.Boolean, default=True) + description = sa.Column(sa.Text) + warning_text = sa.Column(sa.Text) + is_temporal = sa.Column(sa.Boolean, default=False) + is_aggregation = sa.Column(sa.Boolean, default=False) + is_additive = sa.Column(sa.Boolean, default=False) + is_spatial = sa.Column(sa.Boolean, default=False) + is_partition = sa.Column(sa.Boolean, default=False) + is_increase_desired = sa.Column(sa.Boolean, default=True) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) + extra_json = sa.Column(sa.Text, default="{}") + + +class NewTable(Base): + + __tablename__ = "sl_tables" + __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),) + + id = sa.Column(sa.Integer, primary_key=True) + name = sa.Column(sa.Text) + schema = sa.Column(sa.Text) + catalog = sa.Column(sa.Text) + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + backref=backref("new_tables", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + columns: List[NewColumn] = relationship( + "NewColumn", secondary=table_column_association_table, cascade="all, delete" + ) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) + + +class NewDataset(Base): + + __tablename__ = "sl_datasets" + + id = sa.Column(sa.Integer, primary_key=True) + sqlatable_id = sa.Column(sa.Integer, nullable=True, unique=True) + name = sa.Column(sa.Text) + expression = sa.Column(sa.Text) + tables: List[NewTable] = relationship( + "NewTable", secondary=dataset_table_association_table + ) + columns: List[NewColumn] = relationship( + "NewColumn", secondary=dataset_column_association_table, cascade="all, delete" + ) + is_physical = sa.Column(sa.Boolean, default=False) + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) + + +def after_insert(target: SqlaTable) -> None: # pylint: disable=too-many-locals + """ + Copy old datasets to the new models. + """ + session = inspect(target).session + + # get DB-specific conditional quoter for expressions that point to columns or + # table names + database = ( + target.database + or session.query(Database).filter_by(id=target.database_id).one() + ) + engine = database.get_sqla_engine(schema=target.schema) + conditional_quote = engine.dialect.identifier_preparer.quote + + # create columns + columns = [] + for column in target.columns: + # ``is_active`` might be ``None`` at this point, but it defaults to ``True``. + if column.is_active is False: + continue + + extra_json = json.loads(column.extra or "{}") + for attr in {"groupby", "filterable", "verbose_name", "python_date_format"}: + value = getattr(column, attr) + if value: + extra_json[attr] = value + + columns.append( + NewColumn( + name=column.column_name, + type=column.type or "Unknown", + expression=column.expression or conditional_quote(column.column_name), + description=column.description, + is_temporal=column.is_dttm, + is_aggregation=False, + is_physical=column.expression is None or column.expression == "", + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ), + ) + + # create metrics + for metric in target.metrics: + extra_json = json.loads(metric.extra or "{}") + for attr in {"verbose_name", "metric_type", "d3format"}: + value = getattr(metric, attr) + if value: + extra_json[attr] = value + + is_additive = ( + metric.metric_type and metric.metric_type.lower() in ADDITIVE_METRIC_TYPES + ) + + columns.append( + NewColumn( + name=metric.metric_name, + type="Unknown", # figuring this out would require a type inferrer + expression=metric.expression, + warning_text=metric.warning_text, + description=metric.description, + is_aggregation=True, + is_additive=is_additive, + is_physical=False, + is_spatial=False, + is_partition=False, + is_increase_desired=True, + extra_json=json.dumps(extra_json) if extra_json else None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ), + ) + + # physical dataset + tables = [] + if target.sql is None: + physical_columns = [column for column in columns if column.is_physical] + + # create table + table = NewTable( + name=target.table_name, + schema=target.schema, + catalog=None, # currently not supported + database_id=target.database_id, + columns=physical_columns, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + tables.append(table) + + # virtual dataset + else: + # mark all columns as virtual (not physical) + for column in columns: + column.is_physical = False + + # find referenced tables + parsed = ParsedQuery(target.sql) + referenced_tables = parsed.tables + + # predicate for finding the referenced tables + predicate = or_( + *[ + and_( + NewTable.schema == (table.schema or target.schema), + NewTable.name == table.table, + ) + for table in referenced_tables + ] + ) + tables = session.query(NewTable).filter(predicate).all() + + # create the new dataset + dataset = NewDataset( + sqlatable_id=target.id, + name=target.table_name, + expression=target.sql or conditional_quote(target.table_name), + tables=tables, + columns=columns, + is_physical=target.sql is None, + is_managed_externally=target.is_managed_externally, + external_url=target.external_url, + ) + session.add(dataset) + + +def upgrade(): + # Create tables for the new models. + op.create_table( + "sl_columns", + # AuditMixinNullable + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + # ExtraJSONMixin + sa.Column("extra_json", sa.Text(), nullable=True), + # ImportExportMixin + sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), + # Column + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("name", sa.TEXT(), nullable=False), + sa.Column("type", sa.TEXT(), nullable=False), + sa.Column("expression", sa.TEXT(), nullable=False), + sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=True,), + sa.Column("description", sa.TEXT(), nullable=True), + sa.Column("warning_text", sa.TEXT(), nullable=True), + sa.Column("unit", sa.TEXT(), nullable=True), + sa.Column("is_temporal", sa.BOOLEAN(), nullable=False), + sa.Column("is_spatial", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column("is_partition", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column("is_aggregation", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column("is_additive", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column("is_increase_desired", sa.BOOLEAN(), nullable=False, default=True,), + sa.Column( + "is_managed_externally", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + sa.Column("external_url", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + with op.batch_alter_table("sl_columns") as batch_op: + batch_op.create_unique_constraint("uq_sl_columns_uuid", ["uuid"]) + + op.create_table( + "sl_tables", + # AuditMixinNullable + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + # ExtraJSONMixin + sa.Column("extra_json", sa.Text(), nullable=True), + # ImportExportMixin + sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), + # Table + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("database_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("catalog", sa.TEXT(), nullable=True), + sa.Column("schema", sa.TEXT(), nullable=True), + sa.Column("name", sa.TEXT(), nullable=False), + sa.Column( + "is_managed_externally", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + sa.Column("external_url", sa.Text(), nullable=True), + sa.ForeignKeyConstraint(["database_id"], ["dbs.id"], name="sl_tables_ibfk_1"), + sa.PrimaryKeyConstraint("id"), + ) + with op.batch_alter_table("sl_tables") as batch_op: + batch_op.create_unique_constraint("uq_sl_tables_uuid", ["uuid"]) + + op.create_table( + "sl_table_columns", + sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["column_id"], ["sl_columns.id"], name="sl_table_columns_ibfk_2" + ), + sa.ForeignKeyConstraint( + ["table_id"], ["sl_tables.id"], name="sl_table_columns_ibfk_1" + ), + ) + + op.create_table( + "sl_datasets", + # AuditMixinNullable + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + # ExtraJSONMixin + sa.Column("extra_json", sa.Text(), nullable=True), + # ImportExportMixin + sa.Column("uuid", UUIDType(binary=True), primary_key=False, default=uuid4), + # Dataset + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("sqlatable_id", sa.INTEGER(), nullable=True), + sa.Column("name", sa.TEXT(), nullable=False), + sa.Column("expression", sa.TEXT(), nullable=False), + sa.Column("is_physical", sa.BOOLEAN(), nullable=False, default=False,), + sa.Column( + "is_managed_externally", + sa.Boolean(), + nullable=False, + server_default=sa.false(), + ), + sa.Column("external_url", sa.Text(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + with op.batch_alter_table("sl_datasets") as batch_op: + batch_op.create_unique_constraint("uq_sl_datasets_uuid", ["uuid"]) + batch_op.create_unique_constraint( + "uq_sl_datasets_sqlatable_id", ["sqlatable_id"] + ) + + op.create_table( + "sl_dataset_columns", + sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("column_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["column_id"], ["sl_columns.id"], name="sl_dataset_columns_ibfk_2" + ), + sa.ForeignKeyConstraint( + ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_columns_ibfk_1" + ), + ) + + op.create_table( + "sl_dataset_tables", + sa.Column("dataset_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.Column("table_id", sa.INTEGER(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint( + ["dataset_id"], ["sl_datasets.id"], name="sl_dataset_tables_ibfk_1" + ), + sa.ForeignKeyConstraint( + ["table_id"], ["sl_tables.id"], name="sl_dataset_tables_ibfk_2" + ), + ) + + # migrate existing datasets to the new models + bind = op.get_bind() + session = db.Session(bind=bind) # pylint: disable=no-member + + datasets = session.query(SqlaTable).all() + for dataset in datasets: + dataset.fetch_columns_and_metrics(session) + after_insert(target=dataset) + + +def downgrade(): + op.drop_table("sl_dataset_columns") + op.drop_table("sl_dataset_tables") + op.drop_table("sl_datasets") + op.drop_table("sl_table_columns") + op.drop_table("sl_tables") + op.drop_table("sl_columns") diff --git a/superset/models/helpers.py b/superset/models/helpers.py index be02997a21633..f1adadfbc453f 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -29,6 +29,7 @@ import sqlalchemy as sa import yaml from flask import escape, g, Markup +from flask_appbuilder import Model from flask_appbuilder.models.decorators import renders from flask_appbuilder.models.mixins import AuditMixin from flask_appbuilder.security.sqla.models import User @@ -510,3 +511,22 @@ def certification_details(self) -> Optional[str]: @property def warning_markdown(self) -> Optional[str]: return self.get_extra_dict().get("warning_markdown") + + +def clone_model( + target: Model, ignore: Optional[List[str]] = None, **kwargs: Any +) -> Model: + """ + Clone a SQLAlchemy model. + """ + ignore = ignore or [] + + table = target.__table__ + data = { + attr: getattr(target, attr) + for attr in table.columns.keys() + if attr not in table.primary_key.columns.keys() and attr not in ignore + } + data.update(kwargs) + + return target.__class__(**data) diff --git a/superset/security/manager.py b/superset/security/manager.py index 91b203e83f774..d3cae1f4d01b1 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -950,6 +950,7 @@ def set_perm( # pylint: disable=unused-argument .where(link_table.c.id == target.id) .values(perm=target.get_perm()) ) + target.perm = target.get_perm() if ( hasattr(target, "schema_perm") @@ -960,6 +961,7 @@ def set_perm( # pylint: disable=unused-argument .where(link_table.c.id == target.id) .values(schema_perm=target.get_schema_perm()) ) + target.schema_perm = target.get_schema_perm() pvm_names = [] if target.__tablename__ in {"dbs", "clusters"}: diff --git a/superset/tables/__init__.py b/superset/tables/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/tables/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/tables/models.py b/superset/tables/models.py new file mode 100644 index 0000000000000..e2489445c686b --- /dev/null +++ b/superset/tables/models.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Table model. + +This model was introduced in SIP-68 (https://github.com/apache/superset/issues/14909), +and represents a "table" in a given database -- either a physical table or a view. In +addition to a table, new models for columns, metrics, and datasets were also introduced. + +These models are not fully implemented, and shouldn't be used yet. +""" + +from typing import List + +import sqlalchemy as sa +from flask_appbuilder import Model +from sqlalchemy.orm import backref, relationship +from sqlalchemy.schema import UniqueConstraint + +from superset.columns.models import Column +from superset.models.core import Database +from superset.models.helpers import ( + AuditMixinNullable, + ExtraJSONMixin, + ImportExportMixin, +) + +association_table = sa.Table( + "sl_table_columns", + Model.metadata, # pylint: disable=no-member + sa.Column("table_id", sa.ForeignKey("sl_tables.id")), + sa.Column("column_id", sa.ForeignKey("sl_columns.id")), +) + + +class Table(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin): + """ + A table/view in a database. + """ + + __tablename__ = "sl_tables" + + # Note this uniqueness constraint is not part of the physical schema, i.e., it does + # not exist in the migrations. The reason it does not physically exist is MySQL, + # PostgreSQL, etc. have a different interpretation of uniqueness when it comes to NULL + # which is problematic given the catalog and schema are optional. + __table_args__ = (UniqueConstraint("database_id", "catalog", "schema", "name"),) + + id = sa.Column(sa.Integer, primary_key=True) + + database_id = sa.Column(sa.Integer, sa.ForeignKey("dbs.id"), nullable=False) + database: Database = relationship( + "Database", + # TODO (betodealmeida): rename the backref to ``tables`` once we get rid of the + # old models. + backref=backref("new_tables", cascade="all, delete-orphan"), + foreign_keys=[database_id], + ) + + # We use ``sa.Text`` for these attributes because (1) in modern databases the + # performance is the same as ``VARCHAR``[1] and (2) because some table names can be + # **really** long (eg, Google Sheets URLs). + # + # [1] https://www.postgresql.org/docs/9.1/datatype-character.html + catalog = sa.Column(sa.Text) + schema = sa.Column(sa.Text) + name = sa.Column(sa.Text) + + # The relationship between tables and columns is 1:n, but we use a many-to-many + # association to differentiate between the relationship between datasets and + # columns. + columns: List[Column] = relationship( + "Column", secondary=association_table, cascade="all, delete" + ) + + # Column is managed externally and should be read-only inside Superset + is_managed_externally = sa.Column(sa.Boolean, nullable=False, default=False) + external_url = sa.Column(sa.Text, nullable=True) diff --git a/superset/tables/schemas.py b/superset/tables/schemas.py new file mode 100644 index 0000000000000..701a1359ba003 --- /dev/null +++ b/superset/tables/schemas.py @@ -0,0 +1,40 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Schema for table model. + +This model was introduced in SIP-68 (https://github.com/apache/superset/issues/14909), +and represents a "table" in a given database -- either a physical table or a view. In +addition to a table, new models for columns, metrics, and datasets were also introduced. + +These models are not fully implemented, and shouldn't be used yet. +""" + +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema + +from superset.tables.models import Table + + +class TableSchema(SQLAlchemyAutoSchema): + """ + Schema for the ``Table`` model. + """ + + class Meta: # pylint: disable=too-few-public-methods + model = Table + load_instance = True + include_relationships = True diff --git a/tests/integration_tests/dashboards/filter_sets/conftest.py b/tests/integration_tests/dashboards/filter_sets/conftest.py index d07f869f6a551..b7a28273b0a7e 100644 --- a/tests/integration_tests/dashboards/filter_sets/conftest.py +++ b/tests/integration_tests/dashboards/filter_sets/conftest.py @@ -21,7 +21,7 @@ import pytest -from superset import security_manager as sm +from superset import db, security_manager as sm from superset.dashboards.filter_sets.consts import ( DESCRIPTION_FIELD, JSON_METADATA_FIELD, @@ -66,20 +66,6 @@ security_manager: BaseSecurityManager = sm -# @pytest.fixture(autouse=True, scope="session") -# def setup_sample_data() -> Any: -# pass - - -@pytest.fixture(autouse=True) -def expire_on_commit_true() -> Generator[None, None, None]: - ctx: AppContext - with app.app_context() as ctx: - ctx.app.appbuilder.get_session.configure(expire_on_commit=False) - yield - ctx.app.appbuilder.get_session.configure(expire_on_commit=True) - - @pytest.fixture(autouse=True, scope="module") def test_users() -> Generator[Dict[str, int], None, None]: usernames = [ @@ -92,17 +78,14 @@ def test_users() -> Generator[Dict[str, int], None, None]: filter_set_role = build_filter_set_role() admin_role: Role = security_manager.find_role("Admin") usernames_to_ids = create_test_users(admin_role, filter_set_role, usernames) - yield usernames_to_ids - ctx: AppContext - delete_users(usernames_to_ids) + yield usernames_to_ids + delete_users(usernames_to_ids) def delete_users(usernames_to_ids: Dict[str, int]) -> None: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - for username in usernames_to_ids.keys(): - session.delete(security_manager.find_user(username)) - session.commit() + for username in usernames_to_ids.keys(): + db.session.delete(security_manager.find_user(username)) + db.session.commit() def create_test_users( @@ -150,106 +133,86 @@ def client() -> Generator[FlaskClient[Any], None, None]: @pytest.fixture -def dashboard() -> Generator[Dashboard, None, None]: - dashboard: Dashboard - slice_: Slice - datasource: SqlaTable - database: Database - session: Session - try: - with app.app_context() as ctx: - dashboard_owner_user = security_manager.find_user(DASHBOARD_OWNER_USERNAME) - database = create_database("test_database_filter_sets") - datasource = create_datasource_table( - name="test_datasource", database=database, owners=[dashboard_owner_user] - ) - slice_ = create_slice( - datasource=datasource, name="test_slice", owners=[dashboard_owner_user] - ) - dashboard = create_dashboard( - dashboard_title="test_dashboard", - published=True, - slices=[slice_], - owners=[dashboard_owner_user], - ) - session = ctx.app.appbuilder.get_session - session.add(dashboard) - session.commit() - yield dashboard - except Exception as ex: - print(str(ex)) - finally: - with app.app_context() as ctx: - session = ctx.app.appbuilder.get_session - try: - dashboard.owners = [] - slice_.owners = [] - datasource.owners = [] - session.merge(dashboard) - session.merge(slice_) - session.merge(datasource) - session.commit() - session.delete(dashboard) - session.delete(slice_) - session.delete(datasource) - session.delete(database) - session.commit() - except Exception as ex: - print(str(ex)) +def dashboard(app_context) -> Generator[Dashboard, None, None]: + dashboard_owner_user = security_manager.find_user(DASHBOARD_OWNER_USERNAME) + database = create_database("test_database_filter_sets") + datasource = create_datasource_table( + name="test_datasource", database=database, owners=[dashboard_owner_user] + ) + slice_ = create_slice( + datasource=datasource, name="test_slice", owners=[dashboard_owner_user] + ) + dashboard = create_dashboard( + dashboard_title="test_dashboard", + published=True, + slices=[slice_], + owners=[dashboard_owner_user], + ) + db.session.add(dashboard) + db.session.commit() + + yield dashboard + + db.session.delete(dashboard) + db.session.delete(slice_) + db.session.delete(datasource) + db.session.delete(database) + db.session.commit() @pytest.fixture -def dashboard_id(dashboard) -> int: - return dashboard.id +def dashboard_id(dashboard: Dashboard) -> Generator[int, None, None]: + yield dashboard.id @pytest.fixture def filtersets( dashboard_id: int, test_users: Dict[str, int], dumped_valid_json_metadata: str ) -> Generator[Dict[str, List[FilterSet]], None, None]: - try: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - first_filter_set = FilterSet( - name="filter_set_1_of_" + str(dashboard_id), - dashboard_id=dashboard_id, - json_metadata=dumped_valid_json_metadata, - owner_id=dashboard_id, - owner_type="Dashboard", - ) - second_filter_set = FilterSet( - name="filter_set_2_of_" + str(dashboard_id), - json_metadata=dumped_valid_json_metadata, - dashboard_id=dashboard_id, - owner_id=dashboard_id, - owner_type="Dashboard", - ) - third_filter_set = FilterSet( - name="filter_set_3_of_" + str(dashboard_id), - json_metadata=dumped_valid_json_metadata, - dashboard_id=dashboard_id, - owner_id=test_users[FILTER_SET_OWNER_USERNAME], - owner_type="User", - ) - forth_filter_set = FilterSet( - name="filter_set_4_of_" + str(dashboard_id), - json_metadata=dumped_valid_json_metadata, - dashboard_id=dashboard_id, - owner_id=test_users[FILTER_SET_OWNER_USERNAME], - owner_type="User", - ) - session.add(first_filter_set) - session.add(second_filter_set) - session.add(third_filter_set) - session.add(forth_filter_set) - session.commit() - yv = { - "Dashboard": [first_filter_set, second_filter_set], - FILTER_SET_OWNER_USERNAME: [third_filter_set, forth_filter_set], - } - yield yv - except Exception as ex: - print(str(ex)) + first_filter_set = FilterSet( + name="filter_set_1_of_" + str(dashboard_id), + dashboard_id=dashboard_id, + json_metadata=dumped_valid_json_metadata, + owner_id=dashboard_id, + owner_type="Dashboard", + ) + second_filter_set = FilterSet( + name="filter_set_2_of_" + str(dashboard_id), + json_metadata=dumped_valid_json_metadata, + dashboard_id=dashboard_id, + owner_id=dashboard_id, + owner_type="Dashboard", + ) + third_filter_set = FilterSet( + name="filter_set_3_of_" + str(dashboard_id), + json_metadata=dumped_valid_json_metadata, + dashboard_id=dashboard_id, + owner_id=test_users[FILTER_SET_OWNER_USERNAME], + owner_type="User", + ) + fourth_filter_set = FilterSet( + name="filter_set_4_of_" + str(dashboard_id), + json_metadata=dumped_valid_json_metadata, + dashboard_id=dashboard_id, + owner_id=test_users[FILTER_SET_OWNER_USERNAME], + owner_type="User", + ) + db.session.add(first_filter_set) + db.session.add(second_filter_set) + db.session.add(third_filter_set) + db.session.add(fourth_filter_set) + db.session.commit() + + yield { + "Dashboard": [first_filter_set, second_filter_set], + FILTER_SET_OWNER_USERNAME: [third_filter_set, fourth_filter_set], + } + + db.session.delete(first_filter_set) + db.session.delete(second_filter_set) + db.session.delete(third_filter_set) + db.session.delete(fourth_filter_set) + db.session.commit() @pytest.fixture @@ -299,8 +262,8 @@ def valid_filter_set_data_for_update( @pytest.fixture -def not_exists_dashboard(dashboard_id: int) -> int: - return dashboard_id + 1 +def not_exists_dashboard_id(dashboard_id: int) -> Generator[int, None, None]: + yield dashboard_id + 1 @pytest.fixture diff --git a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py index cbdaef9b95a01..fcd2923fb8696 100644 --- a/tests/integration_tests/dashboards/filter_sets/create_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/create_api_tests.py @@ -94,14 +94,14 @@ def test_with_id_field__400( def test_with_dashboard_not_exists__404( self, - not_exists_dashboard: int, + not_exists_dashboard_id: int, valid_filter_set_data_for_create: Dict[str, Any], client: FlaskClient[Any], ): # act login(client, "admin") response = call_create_filter_set( - client, not_exists_dashboard, valid_filter_set_data_for_create + client, not_exists_dashboard_id, valid_filter_set_data_for_create ) # assert diff --git a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py index 34f52011d812b..8e7e0bcb6004e 100644 --- a/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/delete_api_tests.py @@ -61,7 +61,7 @@ def test_with_dashboard_exists_filterset_not_exists__200( def test_with_dashboard_not_exists_filterset_not_exists__404( self, - not_exists_dashboard: int, + not_exists_dashboard_id: int, filtersets: Dict[str, List[FilterSet]], client: FlaskClient[Any], ): @@ -70,14 +70,14 @@ def test_with_dashboard_not_exists_filterset_not_exists__404( filter_set_id = max(collect_all_ids(filtersets)) + 1 response = call_delete_filter_set( - client, {"id": filter_set_id}, not_exists_dashboard + client, {"id": filter_set_id}, not_exists_dashboard_id ) # assert assert response.status_code == 404 def test_with_dashboard_not_exists_filterset_exists__404( self, - not_exists_dashboard: int, + not_exists_dashboard_id: int, dashboard_based_filter_set_dict: Dict[str, Any], client: FlaskClient[Any], ): @@ -86,7 +86,7 @@ def test_with_dashboard_not_exists_filterset_exists__404( # act response = call_delete_filter_set( - client, dashboard_based_filter_set_dict, not_exists_dashboard + client, dashboard_based_filter_set_dict, not_exists_dashboard_id ) # assert assert response.status_code == 404 diff --git a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py index a1ad55aa2a8bd..0d36a0a593e5b 100644 --- a/tests/integration_tests/dashboards/filter_sets/get_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/get_api_tests.py @@ -37,13 +37,13 @@ class TestGetFilterSetsApi: def test_with_dashboard_not_exists__404( - self, not_exists_dashboard: int, client: FlaskClient[Any], + self, not_exists_dashboard_id: int, client: FlaskClient[Any], ): # arrange login(client, "admin") # act - response = call_get_filter_sets(client, not_exists_dashboard) + response = call_get_filter_sets(client, not_exists_dashboard_id) # assert assert response.status_code == 404 diff --git a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py index ccd5ae83afba1..4096e100994f8 100644 --- a/tests/integration_tests/dashboards/filter_sets/update_api_tests.py +++ b/tests/integration_tests/dashboards/filter_sets/update_api_tests.py @@ -85,7 +85,7 @@ def test_with_dashboard_exists_filterset_not_exists__404( def test_with_dashboard_not_exists_filterset_not_exists__404( self, - not_exists_dashboard: int, + not_exists_dashboard_id: int, filtersets: Dict[str, List[FilterSet]], client: FlaskClient[Any], ): @@ -94,14 +94,14 @@ def test_with_dashboard_not_exists_filterset_not_exists__404( filter_set_id = max(collect_all_ids(filtersets)) + 1 response = call_update_filter_set( - client, {"id": filter_set_id}, {}, not_exists_dashboard + client, {"id": filter_set_id}, {}, not_exists_dashboard_id ) # assert assert response.status_code == 404 def test_with_dashboard_not_exists_filterset_exists__404( self, - not_exists_dashboard: int, + not_exists_dashboard_id: int, dashboard_based_filter_set_dict: Dict[str, Any], client: FlaskClient[Any], ): @@ -110,7 +110,7 @@ def test_with_dashboard_not_exists_filterset_exists__404( # act response = call_update_filter_set( - client, dashboard_based_filter_set_dict, {}, not_exists_dashboard + client, dashboard_based_filter_set_dict, {}, not_exists_dashboard_id ) # assert assert response.status_code == 404 diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py index b67c60ca0736f..b160a56a33fbf 100644 --- a/tests/integration_tests/dashboards/superset_factory_util.py +++ b/tests/integration_tests/dashboards/superset_factory_util.py @@ -20,7 +20,7 @@ from flask_appbuilder import Model from flask_appbuilder.security.sqla.models import User -from superset import appbuilder +from superset import db from superset.connectors.sqla.models import SqlaTable, sqlatable_user from superset.models.core import Database from superset.models.dashboard import ( @@ -38,7 +38,7 @@ logger = logging.getLogger(__name__) -session = appbuilder.get_session +session = db.session inserted_dashboards_ids = [] inserted_databases_ids = [] @@ -192,9 +192,11 @@ def delete_all_inserted_objects() -> None: def delete_all_inserted_dashboards(): try: - dashboards_to_delete: List[Dashboard] = session.query(Dashboard).filter( - Dashboard.id.in_(inserted_dashboards_ids) - ).all() + dashboards_to_delete: List[Dashboard] = ( + session.query(Dashboard) + .filter(Dashboard.id.in_(inserted_dashboards_ids)) + .all() + ) for dashboard in dashboards_to_delete: try: delete_dashboard(dashboard, False) @@ -239,9 +241,9 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: def delete_all_inserted_slices(): try: - slices_to_delete: List[Slice] = session.query(Slice).filter( - Slice.id.in_(inserted_slices_ids) - ).all() + slices_to_delete: List[Slice] = ( + session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all() + ) for slice in slices_to_delete: try: delete_slice(slice, False) @@ -270,9 +272,11 @@ def delete_slice_users_associations(slice_: Slice) -> None: def delete_all_inserted_tables(): try: - tables_to_delete: List[SqlaTable] = session.query(SqlaTable).filter( - SqlaTable.id.in_(inserted_sqltables_ids) - ).all() + tables_to_delete: List[SqlaTable] = ( + session.query(SqlaTable) + .filter(SqlaTable.id.in_(inserted_sqltables_ids)) + .all() + ) for table in tables_to_delete: try: delete_sqltable(table, False) @@ -303,9 +307,11 @@ def delete_table_users_associations(table: SqlaTable) -> None: def delete_all_inserted_dbs(): try: - dbs_to_delete: List[Database] = session.query(Database).filter( - Database.id.in_(inserted_databases_ids) - ).all() + dbs_to_delete: List[Database] = ( + session.query(Database) + .filter(Database.id.in_(inserted_databases_ids)) + .all() + ) for db in dbs_to_delete: try: delete_database(db, False) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index aaf76338ef01a..eeda824500fe6 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -500,6 +500,7 @@ def test_create_dataset_validate_uniqueness(self): "message": {"table_name": ["Dataset energy_usage already exists"]} } + @unittest.skip("test is failing stochastically") def test_create_dataset_same_name_different_schema(self): if backend() == "sqlite": # sqlite doesn't support schemas diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index baa5fe35d7320..7323c39f017c8 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -52,6 +52,10 @@ public_role_like_gamma, public_role_like_test_role, ) +from tests.integration_tests.fixtures.birth_names_dashboard import ( + load_birth_names_dashboard_with_slices, + load_birth_names_data, +) from tests.integration_tests.fixtures.world_bank_dashboard import ( load_world_bank_dashboard_with_slices, load_world_bank_data, @@ -224,7 +228,7 @@ def test_set_perm_sqla_table(self): ) # database change - new_db = Database(sqlalchemy_uri="some_uri", database_name="tmp_db") + new_db = Database(sqlalchemy_uri="sqlite://", database_name="tmp_db") session.add(new_db) stored_table.database = ( session.query(Database).filter_by(database_name="tmp_db").one() @@ -358,9 +362,7 @@ def test_set_perm_druid_cluster(self): def test_set_perm_database(self): session = db.session - database = Database( - database_name="tmp_database", sqlalchemy_uri="sqlite://test" - ) + database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") session.add(database) stored_db = ( @@ -411,9 +413,7 @@ def test_hybrid_perm_druid_cluster(self): db.session.commit() def test_hybrid_perm_database(self): - database = Database( - database_name="tmp_database3", sqlalchemy_uri="sqlite://test" - ) + database = Database(database_name="tmp_database3", sqlalchemy_uri="sqlite://") db.session.add(database) @@ -437,9 +437,7 @@ def test_hybrid_perm_database(self): def test_set_perm_slice(self): session = db.session - database = Database( - database_name="tmp_database", sqlalchemy_uri="sqlite://test" - ) + database = Database(database_name="tmp_database", sqlalchemy_uri="sqlite://") table = SqlaTable(table_name="tmp_perm_table", database=database) session.add(database) session.add(table) @@ -573,6 +571,7 @@ def test_gamma_user_schema_access_to_charts(self): ) # wb_health_population slice, has access self.assertNotIn("Girl Name Cloud", data) # birth_names slice, no access + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @pytest.mark.usefixtures("public_role_like_gamma") def test_public_sync_role_data_perms(self): """ diff --git a/tests/unit_tests/columns/__init__.py b/tests/unit_tests/columns/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/columns/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py new file mode 100644 index 0000000000000..36c6b9b4e7301 --- /dev/null +++ b/tests/unit_tests/columns/test_models.py @@ -0,0 +1,53 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, unused-argument + +from sqlalchemy.orm.session import Session + + +def test_column_model(app_context: None, session: Session) -> None: + """ + Test basic attributes of a ``Column``. + """ + from superset.columns.models import Column + + engine = session.get_bind() + Column.metadata.create_all(engine) # pylint: disable=no-member + + column = Column(name="ds", type="TIMESTAMP", expression="ds",) + + session.add(column) + session.flush() + + assert column.id == 1 + assert column.uuid is not None + + assert column.name == "ds" + assert column.type == "TIMESTAMP" + assert column.expression == "ds" + + # test that default values are set correctly + assert column.description is None + assert column.warning_text is None + assert column.unit is None + assert column.is_temporal is False + assert column.is_spatial is False + assert column.is_partition is False + assert column.is_aggregation is False + assert column.is_additive is False + assert column.is_increase_desired is True diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index e622c55c3bc27..0aa0f67a07690 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -22,15 +22,14 @@ from sqlalchemy.orm.session import Session -from superset.datasets.schemas import ImportV1DatasetSchema - -def test_import_(app_context: None, session: Session) -> None: +def test_import_dataset(app_context: None, session: Session) -> None: """ Test importing a dataset. """ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasets.commands.importers.v1.utils import import_dataset + from superset.datasets.schemas import ImportV1DatasetSchema from superset.models.core import Database engine = session.get_bind() @@ -120,11 +119,11 @@ def test_import_(app_context: None, session: Session) -> None: assert len(sqla_table.columns) == 1 assert sqla_table.columns[0].column_name == "profit" assert sqla_table.columns[0].verbose_name is None - assert sqla_table.columns[0].is_dttm is False - assert sqla_table.columns[0].is_active is True + assert sqla_table.columns[0].is_dttm is None + assert sqla_table.columns[0].is_active is None assert sqla_table.columns[0].type == "INTEGER" - assert sqla_table.columns[0].groupby is True - assert sqla_table.columns[0].filterable is True + assert sqla_table.columns[0].groupby is None + assert sqla_table.columns[0].filterable is None assert sqla_table.columns[0].expression == "revenue-expenses" assert sqla_table.columns[0].description is None assert sqla_table.columns[0].python_date_format is None @@ -139,6 +138,7 @@ def test_import_column_extra_is_string(app_context: None, session: Session) -> N """ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.datasets.commands.importers.v1.utils import import_dataset + from superset.datasets.schemas import ImportV1DatasetSchema from superset.models.core import Database engine = session.get_bind() diff --git a/tests/unit_tests/datasets/test_models.py b/tests/unit_tests/datasets/test_models.py new file mode 100644 index 0000000000000..eab0a8aa28288 --- /dev/null +++ b/tests/unit_tests/datasets/test_models.py @@ -0,0 +1,1244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, unused-argument, unused-import, too-many-locals, invalid-name, too-many-lines + +import json +from datetime import datetime, timezone + +from pytest_mock import MockFixture +from sqlalchemy.orm.session import Session + + +def test_dataset_model(app_context: None, session: Session) -> None: + """ + Test basic attributes of a ``Dataset``. + """ + from superset.columns.models import Column + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + columns=[ + Column(name="longitude", expression="longitude"), + Column(name="latitude", expression="latitude"), + ], + ) + session.add(table) + session.flush() + + dataset = Dataset( + name="positions", + expression=""" +SELECT array_agg(array[longitude,latitude]) AS position +FROM my_catalog.my_schema.my_table +""", + tables=[table], + columns=[ + Column(name="position", expression="array_agg(array[longitude,latitude])",), + ], + ) + session.add(dataset) + session.flush() + + assert dataset.id == 1 + assert dataset.uuid is not None + + assert dataset.name == "positions" + assert ( + dataset.expression + == """ +SELECT array_agg(array[longitude,latitude]) AS position +FROM my_catalog.my_schema.my_table +""" + ) + + assert [table.name for table in dataset.tables] == ["my_table"] + assert [column.name for column in dataset.columns] == ["position"] + + +def test_cascade_delete_table(app_context: None, session: Session) -> None: + """ + Test that deleting ``Table`` also deletes its columns. + """ + from superset.columns.models import Column + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Table.metadata.create_all(engine) # pylint: disable=no-member + + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + columns=[ + Column(name="longitude", expression="longitude"), + Column(name="latitude", expression="latitude"), + ], + ) + session.add(table) + session.flush() + + columns = session.query(Column).all() + assert len(columns) == 2 + + session.delete(table) + session.flush() + + # test that columns were deleted + columns = session.query(Column).all() + assert len(columns) == 0 + + +def test_cascade_delete_dataset(app_context: None, session: Session) -> None: + """ + Test that deleting ``Dataset`` also deletes its columns. + """ + from superset.columns.models import Column + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + columns=[ + Column(name="longitude", expression="longitude"), + Column(name="latitude", expression="latitude"), + ], + ) + session.add(table) + session.flush() + + dataset = Dataset( + name="positions", + expression=""" +SELECT array_agg(array[longitude,latitude]) AS position +FROM my_catalog.my_schema.my_table +""", + tables=[table], + columns=[ + Column(name="position", expression="array_agg(array[longitude,latitude])",), + ], + ) + session.add(dataset) + session.flush() + + columns = session.query(Column).all() + assert len(columns) == 3 + + session.delete(dataset) + session.flush() + + # test that dataset columns were deleted (but not table columns) + columns = session.query(Column).all() + assert len(columns) == 2 + + +def test_dataset_attributes(app_context: None, session: Session) -> None: + """ + Test that checks attributes in the dataset. + + If this check fails it means new attributes were added to ``SqlaTable``, and + ``SqlaTable.after_insert`` should be updated to handle them! + """ + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.models.core import Database + + engine = session.get_bind() + SqlaTable.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + TableColumn(column_name="user_id", type="INTEGER"), + TableColumn(column_name="revenue", type="INTEGER"), + TableColumn(column_name="expenses", type="INTEGER"), + TableColumn( + column_name="profit", type="INTEGER", expression="revenue-expenses" + ), + ] + metrics = [ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=metrics, + main_dttm_col="ds", + default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + offset=-8, + description="This is the description", + is_featured=1, + cache_timeout=3600, + schema="my_schema", + sql=None, + params=json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + ), + perm=None, + filter_select_enabled=1, + fetch_values_predicate="foo IN (1, 2)", + is_sqllab_view=0, # no longer used? + template_params=json.dumps({"answer": "42"}), + schema_perm=None, + extra=json.dumps({"warning_markdown": "*WARNING*"}), + ) + + session.add(sqla_table) + session.flush() + + dataset = session.query(SqlaTable).one() + # If this test fails because attributes changed, make sure to update + # ``SqlaTable.after_insert`` accordingly. + assert sorted(dataset.__dict__.keys()) == [ + "_sa_instance_state", + "cache_timeout", + "changed_by_fk", + "changed_on", + "columns", + "created_by_fk", + "created_on", + "database", + "database_id", + "default_endpoint", + "description", + "external_url", + "extra", + "fetch_values_predicate", + "filter_select_enabled", + "id", + "is_featured", + "is_managed_externally", + "is_sqllab_view", + "main_dttm_col", + "metrics", + "offset", + "params", + "perm", + "schema", + "schema_perm", + "sql", + "table_name", + "template_params", + "uuid", + ] + + +def test_create_physical_sqlatable(app_context: None, session: Session) -> None: + """ + Test shadow write when creating a new ``SqlaTable``. + + When a new physical ``SqlaTable`` is created, new models should also be created for + ``Dataset``, ``Table``, and ``Column``. + """ + from superset.columns.models import Column + from superset.columns.schemas import ColumnSchema + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.models import Dataset + from superset.datasets.schemas import DatasetSchema + from superset.models.core import Database + from superset.tables.models import Table + from superset.tables.schemas import TableSchema + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + TableColumn(column_name="user_id", type="INTEGER"), + TableColumn(column_name="revenue", type="INTEGER"), + TableColumn(column_name="expenses", type="INTEGER"), + TableColumn( + column_name="profit", type="INTEGER", expression="revenue-expenses" + ), + ] + metrics = [ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=metrics, + main_dttm_col="ds", + default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + offset=-8, + description="This is the description", + is_featured=1, + cache_timeout=3600, + schema="my_schema", + sql=None, + params=json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + ), + perm=None, + filter_select_enabled=1, + fetch_values_predicate="foo IN (1, 2)", + is_sqllab_view=0, # no longer used? + template_params=json.dumps({"answer": "42"}), + schema_perm=None, + extra=json.dumps({"warning_markdown": "*WARNING*"}), + ) + session.add(sqla_table) + session.flush() + + # ignore these keys when comparing results + ignored_keys = {"created_on", "changed_on", "uuid"} + + # check that columns were created + column_schema = ColumnSchema() + column_schemas = [ + {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} + for column in session.query(Column).all() + ] + assert column_schemas == [ + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "ds", + "extra_json": "{}", + "id": 1, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": False, + "is_partition": False, + "is_physical": True, + "is_spatial": False, + "is_temporal": True, + "name": "ds", + "type": "TIMESTAMP", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "user_id", + "extra_json": "{}", + "id": 2, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": False, + "is_partition": False, + "is_physical": True, + "is_spatial": False, + "is_temporal": False, + "name": "user_id", + "type": "INTEGER", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "revenue", + "extra_json": "{}", + "id": 3, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": False, + "is_partition": False, + "is_physical": True, + "is_spatial": False, + "is_temporal": False, + "name": "revenue", + "type": "INTEGER", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "expenses", + "extra_json": "{}", + "id": 4, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": False, + "is_partition": False, + "is_physical": True, + "is_spatial": False, + "is_temporal": False, + "name": "expenses", + "type": "INTEGER", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "revenue-expenses", + "extra_json": "{}", + "id": 5, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": False, + "is_partition": False, + "is_physical": False, + "is_spatial": False, + "is_temporal": False, + "name": "profit", + "type": "INTEGER", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + { + "changed_by": None, + "created_by": None, + "description": None, + "expression": "COUNT(*)", + "extra_json": "{}", + "id": 6, + "is_increase_desired": True, + "is_additive": False, + "is_aggregation": True, + "is_partition": False, + "is_physical": False, + "is_spatial": False, + "is_temporal": False, + "name": "cnt", + "type": "Unknown", + "unit": None, + "warning_text": None, + "is_managed_externally": False, + "external_url": None, + }, + ] + + # check that table was created + table_schema = TableSchema() + tables = [ + {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} + for table in session.query(Table).all() + ] + assert tables == [ + { + "extra_json": "{}", + "catalog": None, + "schema": "my_schema", + "name": "old_dataset", + "id": 1, + "database": 1, + "columns": [1, 2, 3, 4], + "created_by": None, + "changed_by": None, + "is_managed_externally": False, + "external_url": None, + } + ] + + # check that dataset was created + dataset_schema = DatasetSchema() + datasets = [ + {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} + for dataset in session.query(Dataset).all() + ] + assert datasets == [ + { + "id": 1, + "sqlatable_id": 1, + "name": "old_dataset", + "changed_by": None, + "created_by": None, + "columns": [1, 2, 3, 4, 5, 6], + "is_physical": True, + "tables": [1], + "extra_json": "{}", + "expression": "old_dataset", + "is_managed_externally": False, + "external_url": None, + } + ] + + +def test_create_virtual_sqlatable( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test shadow write when creating a new ``SqlaTable``. + + When a new virtual ``SqlaTable`` is created, new models should also be created for + ``Dataset`` and ``Column``. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + + from superset.columns.models import Column + from superset.columns.schemas import ColumnSchema + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.models import Dataset + from superset.datasets.schemas import DatasetSchema + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + # create the ``Table`` that the virtual dataset points to + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + table = Table( + name="some_table", + schema="my_schema", + catalog=None, + database=database, + columns=[ + Column(name="ds", is_temporal=True, type="TIMESTAMP"), + Column(name="user_id", type="INTEGER"), + Column(name="revenue", type="INTEGER"), + Column(name="expenses", type="INTEGER"), + ], + ) + session.add(table) + session.commit() + + # create virtual dataset + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + TableColumn(column_name="user_id", type="INTEGER"), + TableColumn(column_name="revenue", type="INTEGER"), + TableColumn(column_name="expenses", type="INTEGER"), + TableColumn( + column_name="profit", type="INTEGER", expression="revenue-expenses" + ), + ] + metrics = [ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=metrics, + main_dttm_col="ds", + default_endpoint="https://www.youtube.com/watch?v=dQw4w9WgXcQ", # not used + database=database, + offset=-8, + description="This is the description", + is_featured=1, + cache_timeout=3600, + schema="my_schema", + sql=""" +SELECT + ds, + user_id, + revenue, + expenses, + revenue - expenses AS profit +FROM + some_table""", + params=json.dumps( + {"remote_id": 64, "database_name": "examples", "import_time": 1606677834,} + ), + perm=None, + filter_select_enabled=1, + fetch_values_predicate="foo IN (1, 2)", + is_sqllab_view=0, # no longer used? + template_params=json.dumps({"answer": "42"}), + schema_perm=None, + extra=json.dumps({"warning_markdown": "*WARNING*"}), + ) + session.add(sqla_table) + session.flush() + + # ignore these keys when comparing results + ignored_keys = {"created_on", "changed_on", "uuid"} + + # check that columns were created + column_schema = ColumnSchema() + column_schemas = [ + {k: v for k, v in column_schema.dump(column).items() if k not in ignored_keys} + for column in session.query(Column).all() + ] + assert column_schemas == [ + { + "type": "TIMESTAMP", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": None, + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "ds", + "is_physical": True, + "changed_by": None, + "is_temporal": True, + "id": 1, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": None, + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "user_id", + "is_physical": True, + "changed_by": None, + "is_temporal": False, + "id": 2, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": None, + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "revenue", + "is_physical": True, + "changed_by": None, + "is_temporal": False, + "id": 3, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": None, + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "expenses", + "is_physical": True, + "changed_by": None, + "is_temporal": False, + "id": 4, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "TIMESTAMP", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "ds", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "ds", + "is_physical": False, + "changed_by": None, + "is_temporal": True, + "id": 5, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "user_id", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "user_id", + "is_physical": False, + "changed_by": None, + "is_temporal": False, + "id": 6, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "revenue", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "revenue", + "is_physical": False, + "changed_by": None, + "is_temporal": False, + "id": 7, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "expenses", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "expenses", + "is_physical": False, + "changed_by": None, + "is_temporal": False, + "id": 8, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "INTEGER", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "revenue-expenses", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "profit", + "is_physical": False, + "changed_by": None, + "is_temporal": False, + "id": 9, + "is_aggregation": False, + "external_url": None, + "is_managed_externally": False, + }, + { + "type": "Unknown", + "is_additive": False, + "extra_json": "{}", + "is_partition": False, + "expression": "COUNT(*)", + "unit": None, + "warning_text": None, + "created_by": None, + "is_increase_desired": True, + "description": None, + "is_spatial": False, + "name": "cnt", + "is_physical": False, + "changed_by": None, + "is_temporal": False, + "id": 10, + "is_aggregation": True, + "external_url": None, + "is_managed_externally": False, + }, + ] + + # check that dataset was created, and has a reference to the table + dataset_schema = DatasetSchema() + datasets = [ + {k: v for k, v in dataset_schema.dump(dataset).items() if k not in ignored_keys} + for dataset in session.query(Dataset).all() + ] + assert datasets == [ + { + "id": 1, + "sqlatable_id": 1, + "name": "old_dataset", + "changed_by": None, + "created_by": None, + "columns": [5, 6, 7, 8, 9, 10], + "is_physical": False, + "tables": [1], + "extra_json": "{}", + "external_url": None, + "is_managed_externally": False, + "expression": """ +SELECT + ds, + user_id, + revenue, + expenses, + revenue - expenses AS profit +FROM + some_table""", + } + ] + + +def test_delete_sqlatable(app_context: None, session: Session) -> None: + """ + Test that deleting a ``SqlaTable`` also deletes the corresponding ``Dataset``. + """ + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + ] + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + datasets = session.query(Dataset).all() + assert len(datasets) == 1 + + session.delete(sqla_table) + session.flush() + + # test that dataset was also deleted + datasets = session.query(Dataset).all() + assert len(datasets) == 0 + + +def test_update_sqlatable( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + ] + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + dataset = session.query(Dataset).one() + assert len(dataset.columns) == 1 + + # add a column to the original ``SqlaTable`` instance + sqla_table.columns.append(TableColumn(column_name="user_id", type="INTEGER")) + session.flush() + + # check that the column was added to the dataset + dataset = session.query(Dataset).one() + assert len(dataset.columns) == 2 + + # delete the column in the original instance + sqla_table.columns = sqla_table.columns[1:] + session.flush() + + # check that the column was also removed from the dataset + dataset = session.query(Dataset).one() + assert len(dataset.columns) == 1 + + # modify the attribute in a column + sqla_table.columns[0].is_dttm = True + session.flush() + + # check that the dataset column was modified + dataset = session.query(Dataset).one() + assert dataset.columns[0].is_temporal is True + + +def test_update_sqlatable_schema( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test that updating a ``SqlaTable`` schema also updates the corresponding ``Dataset``. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + mocker.patch("superset.datasets.dao.db.session", session) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + ] + sqla_table = SqlaTable( + table_name="old_dataset", + schema="old_schema", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + dataset = session.query(Dataset).one() + assert dataset.tables[0].schema == "old_schema" + assert dataset.tables[0].id == 1 + + sqla_table.schema = "new_schema" + session.flush() + + dataset = session.query(Dataset).one() + assert dataset.tables[0].schema == "new_schema" + assert dataset.tables[0].id == 2 + + +def test_update_sqlatable_metric( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test that updating a ``SqlaTable`` also updates the corresponding ``Dataset``. + + For this test we check that updating the SQL expression in a metric belonging to a + ``SqlaTable`` is reflected in the ``Dataset`` metric. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), + ] + metrics = [ + SqlMetric(metric_name="cnt", expression="COUNT(*)"), + ] + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=metrics, + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + # check that the metric was created + column = session.query(Column).filter_by(is_physical=False).one() + assert column.expression == "COUNT(*)" + + # change the metric definition + sqla_table.metrics[0].expression = "MAX(ds)" + session.flush() + + assert column.expression == "MAX(ds)" + + +def test_update_virtual_sqlatable_references( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test that changing the SQL of a virtual ``SqlaTable`` updates ``Dataset``. + + When the SQL is modified the list of referenced tables should be updated in the new + ``Dataset`` model. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + table1 = Table( + name="table_a", + schema="my_schema", + catalog=None, + database=database, + columns=[Column(name="a", type="INTEGER")], + ) + table2 = Table( + name="table_b", + schema="my_schema", + catalog=None, + database=database, + columns=[Column(name="b", type="INTEGER")], + ) + session.add(table1) + session.add(table2) + session.commit() + + # create virtual dataset + columns = [TableColumn(column_name="a", type="INTEGER")] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + database=database, + schema="my_schema", + sql="SELECT a FROM table_a", + ) + session.add(sqla_table) + session.flush() + + # check that new dataset has table1 + dataset = session.query(Dataset).one() + assert dataset.tables == [table1] + + # change SQL + sqla_table.sql = "SELECT a, b FROM table_a JOIN table_b" + session.flush() + + # check that new dataset has both tables + dataset = session.query(Dataset).one() + assert dataset.tables == [table1, table2] + assert dataset.expression == "SELECT a, b FROM table_a JOIN table_b" + + +def test_quote_expressions(app_context: None, session: Session) -> None: + """ + Test that expressions are quoted appropriately in columns and datasets. + """ + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="has space", type="INTEGER"), + TableColumn(column_name="no_need", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="old dataset", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + dataset = session.query(Dataset).one() + assert dataset.expression == '"old dataset"' + assert dataset.columns[0].expression == '"has space"' + assert dataset.columns[1].expression == "no_need" + + +def test_update_physical_sqlatable( + mocker: MockFixture, app_context: None, session: Session +) -> None: + """ + Test updating the table on a physical dataset. + + When updating the table on a physical dataset by pointing it somewhere else (change + in database ID, schema, or table name) we should point the ``Dataset`` to an + existing ``Table`` if possible, and create a new one otherwise. + """ + # patch session + mocker.patch( + "superset.security.SupersetSecurityManager.get_session", return_value=session + ) + mocker.patch("superset.datasets.dao.db.session", session) + + from superset.columns.models import Column + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.datasets.models import Dataset + from superset.models.core import Database + from superset.tables.models import Table + from superset.tables.schemas import TableSchema + + engine = session.get_bind() + Dataset.metadata.create_all(engine) # pylint: disable=no-member + + columns = [ + TableColumn(column_name="a", type="INTEGER"), + ] + + sqla_table = SqlaTable( + table_name="old_dataset", + columns=columns, + metrics=[], + database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"), + ) + session.add(sqla_table) + session.flush() + + # check that the table was created, and that the created dataset points to it + table = session.query(Table).one() + assert table.id == 1 + assert table.name == "old_dataset" + assert table.schema is None + assert table.database_id == 1 + + dataset = session.query(Dataset).one() + assert dataset.tables == [table] + + # point ``SqlaTable`` to a different database + new_database = Database( + database_name="my_other_database", sqlalchemy_uri="sqlite://" + ) + session.add(new_database) + session.flush() + sqla_table.database = new_database + session.flush() + + # ignore these keys when comparing results + ignored_keys = {"created_on", "changed_on", "uuid"} + + # check that the old table still exists, and that the dataset points to the newly + # created table (id=2) and column (id=2), on the new database (also id=2) + table_schema = TableSchema() + tables = [ + {k: v for k, v in table_schema.dump(table).items() if k not in ignored_keys} + for table in session.query(Table).all() + ] + assert tables == [ + { + "created_by": None, + "extra_json": "{}", + "name": "old_dataset", + "changed_by": None, + "catalog": None, + "columns": [1], + "database": 1, + "external_url": None, + "schema": None, + "id": 1, + "is_managed_externally": False, + }, + { + "created_by": None, + "extra_json": "{}", + "name": "old_dataset", + "changed_by": None, + "catalog": None, + "columns": [2], + "database": 2, + "external_url": None, + "schema": None, + "id": 2, + "is_managed_externally": False, + }, + ] + + # check that dataset now points to the new table + assert dataset.tables[0].database_id == 2 + + # point ``SqlaTable`` back + sqla_table.database_id = 1 + session.flush() + + # check that dataset points to the original table + assert dataset.tables[0].database_id == 1 diff --git a/tests/unit_tests/tables/__init__.py b/tests/unit_tests/tables/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/unit_tests/tables/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py new file mode 100644 index 0000000000000..eb1f5f4611248 --- /dev/null +++ b/tests/unit_tests/tables/test_models.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, unused-argument + +from sqlalchemy.orm.session import Session + + +def test_table_model(app_context: None, session: Session) -> None: + """ + Test basic attributes of a ``Table``. + """ + from superset.columns.models import Column + from superset.models.core import Database + from superset.tables.models import Table + + engine = session.get_bind() + Table.metadata.create_all(engine) # pylint: disable=no-member + + table = Table( + name="my_table", + schema="my_schema", + catalog="my_catalog", + database=Database(database_name="my_database", sqlalchemy_uri="test://"), + columns=[Column(name="ds", type="TIMESTAMP", expression="ds",)], + ) + session.add(table) + session.flush() + + assert table.id == 1 + assert table.uuid is not None + assert table.database_id == 1 + assert table.catalog == "my_catalog" + assert table.schema == "my_schema" + assert table.name == "my_table" + assert [column.name for column in table.columns] == ["ds"]