Skip to content

Commit

Permalink
fix: Use RLS clause instead of ID for cache key (apache#25229)
Browse files Browse the repository at this point in the history
  • Loading branch information
jfrag1 authored Sep 18, 2023
1 parent ef4b6f0 commit 72c0aa3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 12 deletions.
26 changes: 14 additions & 12 deletions superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
from superset.connectors.base.models import BaseDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable
from superset.models.core import Database
from superset.models.dashboard import Dashboard
from superset.models.sql_lab import Query
Expand Down Expand Up @@ -2083,28 +2083,30 @@ def get_rls_filters(self, table: "BaseDatasource") -> list[SqlaQuery]:
)
return query.all()

def get_rls_ids(self, table: "BaseDatasource") -> list[int]:
def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]:
"""
Retrieves the appropriate row level security filters IDs for the current user
and the passed table.
Retrieves a list RLS filters sorted by ID for
the current user and the passed table.
:param table: The table to check against
:returns: A list of IDs
:returns: A list RLS filters
"""
ids = [f.id for f in self.get_rls_filters(table)]
ids.sort() # Combinations rather than permutations
return ids
filters = self.get_rls_filters(table)
filters.sort(key=lambda f: f.id)
return filters

def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]:
return [f.get("clause", "") for f in self.get_guest_rls_filters(table)]

def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]:
rls_ids = []
rls_clauses_with_group_key = []
if datasource.is_rls_supported:
rls_ids = self.get_rls_ids(datasource)
rls_str = [str(rls_id) for rls_id in rls_ids]
rls_clauses_with_group_key = [
f"{f.clause}-{f.group_key or ''}"
for f in self.get_rls_sorted(datasource)
]
guest_rls = self.get_guest_rls_filters_str(datasource)
return guest_rls + rls_str
return guest_rls + rls_clauses_with_group_key

@staticmethod
def _get_current_epoch_time() -> float:
Expand Down
15 changes: 15 additions & 0 deletions tests/integration_tests/security/row_level_security_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,21 @@ def test_rls_filter_doesnt_alter_admin_birth_names_query(self):
assert not self.NAMES_Q_REGEX.search(sql)
assert not self.BASE_FILTER_REGEX.search(sql)

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_rls_cache_key(self):
g.user = self.get_user(username="admin")
tbl = self.get_table(name="birth_names")
clauses = security_manager.get_rls_cache_key(tbl)
assert clauses == []

g.user = self.get_user(username="gamma")
clauses = security_manager.get_rls_cache_key(tbl)
assert clauses == [
"name like 'A%' or name like 'B%'-name",
"name like 'Q%'-name",
"gender = 'boy'-gender",
]


class TestRowLevelSecurityCreateAPI(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down

0 comments on commit 72c0aa3

Please sign in to comment.