Skip to content

Commit

Permalink
fix: add back database lookup from sip 68 revert (#22129)
Browse files Browse the repository at this point in the history
  • Loading branch information
eschutho authored Nov 15, 2022
1 parent e23efef commit 6f6cb18
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 24 deletions.
64 changes: 63 additions & 1 deletion superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,13 @@
from superset.common.utils.time_range_utils import get_since_until_from_time_range
from superset.connectors.base.models import BaseColumn, BaseDatasource, BaseMetric
from superset.connectors.sqla.utils import (
find_cached_objects_in_session,
get_columns_description,
get_physical_table_metadata,
get_virtual_table_metadata,
validate_adhoc_subquery,
)
from superset.datasets.models import Dataset as NewDataset
from superset.db_engine_specs.base import BaseEngineSpec, CTE_ALIAS, TimestampExpression
from superset.exceptions import (
AdvancedDataTypeResponseError,
Expand Down Expand Up @@ -2088,6 +2090,21 @@ def update_column( # pylint: disable=unused-argument
# table is updated. This busts the cache key for all charts that use the table.
session.execute(update(SqlaTable).where(SqlaTable.id == target.table.id))

# TODO: This shadow writing is deprecated
# if table itself has changed, shadow-writing will happen in `after_update` anyway
if target.table not in session.dirty:
dataset: NewDataset = (
session.query(NewDataset)
.filter_by(uuid=target.table.uuid)
.one_or_none()
)
# Update shadow dataset and columns
# did we find the dataset?
if not dataset:
# if dataset is not found create a new copy
target.table.write_shadow_dataset()
return

@staticmethod
def after_insert(
mapper: Mapper,
Expand All @@ -2099,6 +2116,9 @@ def after_insert(
"""
security_manager.dataset_after_insert(mapper, connection, sqla_table)

# TODO: deprecated
sqla_table.write_shadow_dataset()

@staticmethod
def after_delete(
mapper: Mapper,
Expand All @@ -2117,11 +2137,53 @@ def after_update(
sqla_table: "SqlaTable",
) -> None:
"""
Update dataset permissions after update
Update dataset permissions
"""
# set permissions
security_manager.dataset_after_update(mapper, connection, sqla_table)

# TODO: the shadow writing is deprecated
inspector = inspect(sqla_table)
session = inspector.session

# double-check that ``UPDATE``s are actually pending (this method is called even
# for instances that have no net changes to their column-based attributes)
if not session.is_modified(sqla_table, include_collections=True):
return

# find the dataset from the known instance list first
# (it could be either from a previous query or newly created)
dataset = next(
find_cached_objects_in_session(
session, NewDataset, uuids=[sqla_table.uuid]
),
None,
)
# if not found, pull from database
if not dataset:
dataset = (
session.query(NewDataset).filter_by(uuid=sqla_table.uuid).one_or_none()
)
if not dataset:
sqla_table.write_shadow_dataset()
return

def write_shadow_dataset(
self: "SqlaTable",
) -> None:
"""
This method is deprecated
"""
session = inspect(self).session
# most of the write_shadow_dataset functionality has been removed
# but leaving this portion in
# to remove later because it is adding a Database relationship to the session
# and there is some functionality that depends on this
if self.database_id and (
not self.database or self.database.id != self.database_id
):
self.database = session.query(Database).filter_by(id=self.database_id).one()


sa.event.listen(SqlaTable, "before_update", SqlaTable.before_update)
sa.event.listen(SqlaTable, "after_update", SqlaTable.after_update)
Expand Down
24 changes: 1 addition & 23 deletions superset/datasets/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,9 @@
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional

from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload

from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.dao.base import BaseDAO
Expand All @@ -37,26 +35,6 @@ class DatasetDAO(BaseDAO): # pylint: disable=too-many-public-methods
model_cls = SqlaTable
base_filter = DatasourceFilter

@classmethod
def find_by_ids(cls, model_ids: Union[List[str], List[int]]) -> List[SqlaTable]:
"""
Find a List of models by a list of ids, if defined applies `base_filter`
"""
id_col = getattr(SqlaTable, cls.id_column_name, None)
if id_col is None:
return []

# the joinedload option ensures that the database is
# available in the session later and not lazy loaded
query = (
db.session.query(SqlaTable)
.options(joinedload(SqlaTable.database))
.filter(id_col.in_(model_ids))
)
data_model = SQLAInterface(SqlaTable, db.session)
query = DatasourceFilter(cls.id_column_name, data_model).apply(query, None)
return query.all()

@staticmethod
def get_database_by_id(database_id: int) -> Optional[Database]:
try:
Expand Down

0 comments on commit 6f6cb18

Please sign in to comment.