Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(dashboard_rbac): dashboard extra jwt #13773

Closed
wants to merge 19 commits into from
Closed
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions superset-frontend/src/chart/chartAction.js
Original file line number Diff line number Diff line change
Expand Up @@ -519,8 +519,8 @@ export function redirectSQLLab(formData) {
export function refreshChart(chartKey, force, dashboardId) {
return (dispatch, getState) => {
const chart = (getState().charts || {})[chartKey];
const timeout = getState().dashboardInfo.common.conf
.SUPERSET_WEBSERVER_TIMEOUT;
const { dashboardInfo } = getState();
const timeout = dashboardInfo.common.conf.SUPERSET_WEBSERVER_TIMEOUT;

if (
!chart.latestQueryFormData ||
Expand Down
1 change: 1 addition & 0 deletions superset-frontend/src/dashboard/containers/Chart.jsx
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ function mapStateToProps(
});

formData.dashboardId = dashboardInfo.id;
formData.extra_jwt = dashboardInfo.extraJwt;

return {
chart,
Expand Down
10 changes: 9 additions & 1 deletion superset-frontend/src/dashboard/reducers/getInitialState.js
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,14 @@ import newComponentFactory from '../util/newComponentFactory';
import { TIME_RANGE } from '../../visualizations/FilterBox/FilterBox';

export default function getInitialState(bootstrapData) {
const { user_id, datasources, common, editMode, urlParams } = bootstrapData;
const {
user_id,
datasources,
common,
editMode,
urlParams,
extra_jwt,
} = bootstrapData;

const dashboard = { ...bootstrapData.dashboard_data };
let preselectFilters = {};
Expand Down Expand Up @@ -283,6 +290,7 @@ export default function getInitialState(bootstrapData) {
conf: common.conf,
},
lastModifiedTime: dashboard.last_modified_time,
extraJwt: extra_jwt,
},
dashboardFilters,
nativeFilters,
Expand Down
6 changes: 6 additions & 0 deletions superset/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
cache_manager,
celery_app,
csrf,
dashboard_jwt_manager,
db,
feature_flag_manager,
machine_auth_provider_factory,
Expand Down Expand Up @@ -534,6 +535,7 @@ def init_app_in_ctx(self) -> None:
self.configure_data_sources()
self.configure_auth_provider()
self.configure_async_queries()
self.configure_dashboard_jwt()

# Hook that provides administrators a handle on the Flask APP
# after initialization
Expand Down Expand Up @@ -698,3 +700,7 @@ def register_blueprints(self) -> None:

def setup_bundle_manifest(self) -> None:
manifest_processor.init_app(self.flask_app)

def configure_dashboard_jwt(self):
if feature_flag_manager.is_feature_enabled("DASHBOARD_RBAC"):
dashboard_jwt_manager.init_app(self.flask_app)
5 changes: 5 additions & 0 deletions superset/charts/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,11 @@ class ChartDataQueryContextSchema(Schema):

result_type = EnumField(ChartDataResultType, by_value=True)
result_format = EnumField(ChartDataResultFormat, by_value=True)
extra_jwt = fields.String(required=False, allow_none=True,
description="represents a security jwt that was "
"originally generated by the backend in "
"order to allow temporary access for that "
"chart data")

# pylint: disable=no-self-use,unused-argument
@post_load
Expand Down
3 changes: 3 additions & 0 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class QueryContext:
custom_cache_timeout: Optional[int]
result_type: ChartDataResultType
result_format: ChartDataResultFormat
extra_jwt: str

# TODO: Type datasource and query_object dictionary with TypedDict when it becomes
# a vanilla python type https://github.com/python/mypy/issues/5288
Expand All @@ -80,6 +81,7 @@ def __init__( # pylint: disable=too-many-arguments
custom_cache_timeout: Optional[int] = None,
result_type: Optional[ChartDataResultType] = None,
result_format: Optional[ChartDataResultFormat] = None,
extra_jwt: str = None,
) -> None:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
Expand All @@ -95,6 +97,7 @@ def __init__( # pylint: disable=too-many-arguments
"result_type": self.result_type,
"result_format": self.result_format,
}
self.extra_jwt = extra_jwt

def get_query_result(self, query_object: QueryObject) -> Dict[str, Any]:
"""Returns a pandas dataframe based on the query object"""
Expand Down
2 changes: 2 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1123,6 +1123,8 @@ class CeleryConfig: # pylint: disable=too-few-public-methods
SQLALCHEMY_DOCS_URL = "https://docs.sqlalchemy.org/en/13/core/engines.html"
SQLALCHEMY_DISPLAY_TEXT = "SQLAlchemy docs"

DASHBOARD_JWT_SECRET = "my secret key replace me"

# -------------------------------------------------------------------
# * WARNING: STOP EDITING HERE *
# -------------------------------------------------------------------
Expand Down
2 changes: 2 additions & 0 deletions superset/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from superset.utils.async_query_manager import AsyncQueryManager
from superset.utils.cache_manager import CacheManager
from superset.utils.dashboard_jwt_manager import DashboardJwtManager
from superset.utils.feature_flag_manager import FeatureFlagManager
from superset.utils.machine_auth import MachineAuthProviderFactory

Expand Down Expand Up @@ -100,6 +101,7 @@ def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]:
appbuilder = AppBuilder(update_perms=False)
async_query_manager = AsyncQueryManager()
cache_manager = CacheManager()
dashboard_jwt_manager = DashboardJwtManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = SQLA()
Expand Down
10 changes: 9 additions & 1 deletion superset/security/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
from superset.constants import RouteMethod
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.extensions import dashboard_jwt_manager
from superset.utils.core import DatasourceName, RowLevelSecurityFilterType

if TYPE_CHECKING:
Expand Down Expand Up @@ -987,15 +988,22 @@ def raise_for_access( # pylint: disable=too-many-arguments,too-many-branches
)

if datasource or query_context or viz:
extra_jwt=None
if query_context:
datasource = query_context.datasource
extra_jwt= query_context.extra_jwt
elif viz:
datasource = viz.datasource
extra_jwt = viz.extra_jwt

assert datasource

dashboard_data_context = dashboard_jwt_manager.parse_jwt(extra_jwt)

data_source_allowed_in_dashboard = datasource.id in dashboard_data_context.dataset_ids
if not (
self.can_access_schema(datasource)
data_source_allowed_in_dashboard
or self.can_access_schema(datasource)
or self.can_access("datasource_access", datasource.perm or "")
):
raise SupersetSecurityException(
Expand Down
36 changes: 36 additions & 0 deletions superset/utils/dashboard_jwt_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import json
from typing import Any, Dict

import jwt
from flask import Flask


class DashboardJwtDataObject:
id: int
dataset_ids: [int]

def __init__(self, id: int, dataset_ids: [int]) -> None:
super().__init__()
self.id = id
self.dataset_ids = dataset_ids


class DashboardJwtManager:
def __init__(self) -> None:
super().__init__()
self._jwt_secret: str

def init_app(self, app: Flask) -> None:
config = app.config

self._jwt_secret = config["DASHBOARD_JWT_SECRET"]

def generate_jwt(self, data: DashboardJwtDataObject) -> str:
encoded_jwt = jwt.encode(data.__dict__, self._jwt_secret, algorithm="HS256")
return encoded_jwt.decode("utf-8")

def parse_jwt(self, token: str) -> DashboardJwtDataObject:
if token:
data = jwt.decode(token, self._jwt_secret, algorithms=["HS256"])
return DashboardJwtDataObject(data["id"], dataset_ids=data["dataset_ids"])
return {}
16 changes: 15 additions & 1 deletion superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@
SupersetTemplateParamsErrorException,
SupersetTimeoutException,
)
from superset.extensions import async_query_manager, cache_manager
from superset.extensions import (
async_query_manager,
cache_manager,
dashboard_jwt_manager,
)
from superset.jinja_context import get_template_processor
from superset.models.core import Database, FavStar, Log
from superset.models.dashboard import Dashboard
Expand All @@ -102,6 +106,10 @@
from superset.utils.async_query_manager import AsyncQueryTokenException
from superset.utils.cache import etag_cache
from superset.utils.core import ReservedUrlParameters
from superset.utils.dashboard_jwt_manager import (
DashboardJwtDataObject,
DashboardJwtManager,
)
from superset.utils.dates import now_as_float
from superset.utils.decorators import check_dashboard_access
from superset.views.base import (
Expand Down Expand Up @@ -1886,6 +1894,12 @@ def dashboard( # pylint: disable=too-many-locals
"superset_can_csv": superset_can_csv,
"slice_can_edit": slice_can_edit,
},
"extra_jwt": dashboard_jwt_manager.generate_jwt(
DashboardJwtDataObject(
dashboard.id,
list(map(lambda datasource: datasource.id, dashboard.datasources)),
)
),
"datasources": data["datasources"],
}

Expand Down
2 changes: 2 additions & 0 deletions superset/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(
force: bool = False,
force_cached: bool = False,
) -> None:
self.extra_jwt = form_data.get("extra_jwt") or None
if not datasource:
raise QueryObjectValidationError(_("Viz is missing a datasource"))

Expand Down Expand Up @@ -154,6 +155,7 @@ def __init__(

self.applied_filters: List[Dict[str, str]] = []
self.rejected_filters: List[Dict[str, str]] = []
self.extra_jwt: str

@property
def force_cached(self) -> bool:
Expand Down