Skip to content

Commit

Permalink
chore: remove duplicate code in SqlaTable (#28752)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored May 29, 2024
1 parent 020c799 commit 643ee17
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 48 deletions.
45 changes: 1 addition & 44 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import dataclasses
import logging
import re
from collections import defaultdict
from collections.abc import Hashable
from dataclasses import dataclass, field
from datetime import datetime, timedelta
Expand Down Expand Up @@ -70,7 +69,7 @@
from sqlalchemy.sql.expression import Label, TextAsFrom
from sqlalchemy.sql.selectable import Alias, TableClause

from superset import app, db, is_feature_enabled, security_manager
from superset import app, db, security_manager
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.common.db_query_status import QueryStatus
from superset.connectors.sqla.utils import (
Expand Down Expand Up @@ -1603,48 +1602,6 @@ def is_alias_used_in_orderby(col: ColumnElement) -> bool:
if is_alias_used_in_orderby(col):
col.name = f"{col.name}__"

def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
current user. A custom username can be passed when the user is not present in the
Flask global namespace.
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
all_filters: list[TextClause] = []
filter_groups: dict[int | str, list[TextClause]] = defaultdict(list)
try:
for filter_ in security_manager.get_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(filter_.clause)})"
)
if filter_.group_key:
filter_groups[filter_.group_key].append(clause)
else:
all_filters.append(clause)

if is_feature_enabled("EMBEDDED_SUPERSET"):
for rule in security_manager.get_guest_rls_filters(self):
clause = self.text(
f"({template_processor.process_template(rule['clause'])})"
)
all_filters.append(clause)

grouped_filters = [or_(*clauses) for clauses in filter_groups.values()]
all_filters.extend(grouped_filters)
return all_filters
except TemplateError as ex:
raise QueryObjectValidationError(
_(
"Error in jinja expression in RLS filters: %(msg)s",
msg=ex.message,
)
) from ex

def text(self, clause: str) -> TextClause:
return self.db_engine_spec.get_text_clause(clause)

Expand Down
4 changes: 3 additions & 1 deletion superset/models/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,7 @@ def get_fetch_values_predicate(

def get_sqla_row_level_filters(
self,
template_processor: BaseTemplateProcessor,
template_processor: Optional[BaseTemplateProcessor] = None,
) -> list[TextClause]:
"""
Return the appropriate row level security filters for this table and the
Expand All @@ -815,6 +815,8 @@ def get_sqla_row_level_filters(
:param template_processor: The template processor to apply to the filters.
:returns: A list of SQL clauses to be ANDed together.
"""
template_processor = template_processor or self.get_template_processor()

all_filters: list[TextClause] = []
filter_groups: dict[Union[int, str], list[TextClause]] = defaultdict(list)
try:
Expand Down
4 changes: 1 addition & 3 deletions superset/sql_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -1269,10 +1269,8 @@ def get_rls_for_table(
if not dataset:
return None

template_processor = dataset.get_template_processor()
predicate = " AND ".join(
str(filter_)
for filter_ in dataset.get_sqla_row_level_filters(template_processor)
str(filter_) for filter_ in dataset.get_sqla_row_level_filters()
)
if not predicate:
return None
Expand Down

0 comments on commit 643ee17

Please sign in to comment.