fix: cache warmup unable to login (#9597)
enskylin committed Jun 15, 2022
1 parent ead1040 commit cfee5ac
Showing 7 changed files with 81 additions and 271 deletions.
2 changes: 2 additions & 0 deletions docs/docs/installation/cache.mdx
@@ -46,6 +46,8 @@ Superset has a Celery task that will periodically warm up the cache based on different strategies
To use it, add the following to the `CELERYBEAT_SCHEDULE` section in `config.py`:

```python
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"

CELERYBEAT_SCHEDULE = {
'cache-warmup-hourly': {
'task': 'cache-warmup',
3 changes: 3 additions & 0 deletions superset/config.py
@@ -541,6 +541,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
"CACHE_NO_NULL_WARNING": True,
}

# Cache warmup user
SUPERSET_CACHE_WARMUP_USER = "admin"

# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())
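
For context, the docs hunk above and this new `SUPERSET_CACHE_WARMUP_USER` default fit together roughly as sketched below in a local `superset_config.py`. Only the setting name and the `cache-warmup` task name come from this commit; the crontab cadence and strategy kwargs are illustrative assumptions.

```python
# Hypothetical superset_config.py override -- the schedule values are assumptions;
# only SUPERSET_CACHE_WARMUP_USER and the "cache-warmup" task name come from this commit.
from celery.schedules import crontab

# Account the warmup task authenticates as; it needs permission to view the dashboards.
SUPERSET_CACHE_WARMUP_USER = "cache_warmup"

CELERYBEAT_SCHEDULE = {
    "cache-warmup-hourly": {
        "task": "cache-warmup",  # Celery task registered in superset/tasks/cache.py
        "schedule": crontab(minute="*/30"),  # assumed cadence: every 30 minutes
        "kwargs": {
            "strategy_name": "top_n_dashboards",  # assumed strategy; see superset/tasks/cache.py
            "top_n": 5,
            "since": "7 days ago",
        },
    },
}
```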
85 changes: 15 additions & 70 deletions superset/tasks/cache.py
@@ -14,73 +14,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib import request
from urllib.error import URLError

from celery.utils.log import get_task_logger
from sqlalchemy import and_, func

from superset import app, db
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime
from superset.views.utils import build_extra_filters
from superset.utils.webdriver import WebDriverProxy

logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)


def get_form_data(
chart_id: int, dashboard: Optional[Dashboard] = None
) -> Dict[str, Any]:
"""
Build `form_data` for chart GET request from dashboard's `default_filters`.
When a dashboard has `default_filters` they need to be added as extra
filters in the GET request for charts.
"""
form_data: Dict[str, Any] = {"slice_id": chart_id}

if dashboard is None or not dashboard.json_metadata:
return form_data

json_metadata = json.loads(dashboard.json_metadata)
default_filters = json.loads(json_metadata.get("default_filters", "null"))
if not default_filters:
return form_data

filter_scopes = json_metadata.get("filter_scopes", {})
layout = json.loads(dashboard.position_json or "{}")
if (
isinstance(layout, dict)
and isinstance(filter_scopes, dict)
and isinstance(default_filters, dict)
):
extra_filters = build_extra_filters(
layout, filter_scopes, default_filters, chart_id
)
if extra_filters:
form_data["extra_filters"] = extra_filters

return form_data


def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
def get_dash_url(dashboard: Dashboard) -> str:
"""Return external URL for warming up a given chart/table cache."""
with app.test_request_context():
baseurl = (
"{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**app.config)
)
return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
return f"{baseurl}{dashboard.url}"


class Strategy: # pylint: disable=too-few-public-methods
@@ -133,9 +94,11 @@ class DummyStrategy(Strategy): # pylint: disable=too-few-public-methods

def get_urls(self) -> List[str]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
dashboards = (
session.query(Dashboard).filter(Dashboard.published.is_(True)).all()
)

return [get_url(chart) for chart in charts]
return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]


class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -164,7 +127,6 @@ def __init__(self, top_n: int = 5, since: str = "7 days ago") -> None:
self.since = parse_human_datetime(since) if since else None

def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()

records = (
@@ -177,12 +139,8 @@ def get_urls(self) -> List[str]:
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
urls.append(get_url(chart, form_data_with_filters))

return urls
return [get_dash_url(dashboard) for dashboard in dashboards]


class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
@@ -228,24 +186,7 @@ def get_urls(self) -> List[str]:
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
urls.append(get_url(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
urls.append(get_url(chart))
urls.append(get_dash_url(dashboard))

return urls

@@ -283,10 +224,14 @@ def cache_warmup(
return message

results: Dict[str, List[str]] = {"success": [], "errors": []}

user = security_manager.find_user(username=app.config["SUPERSET_CACHE_WARMUP_USER"])
wd = WebDriverProxy(app.config["WEBDRIVER_TYPE"], user=user)

for url in strategy.get_urls():
try:
logger.info("Fetching %s", url)
request.urlopen(url) # pylint: disable=consider-using-with
wd.get_screenshot(url, "grid-container")
results["success"].append(url)
except URLError:
logger.exception("Error warming up cache!")
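
The net effect in `cache_warmup` is that URLs are no longer fetched anonymously with `urlopen`: the task now looks up `SUPERSET_CACHE_WARMUP_USER`, builds a `WebDriverProxy` bound to that user, and renders each published dashboard headlessly, so warmup requests are authenticated and can actually populate the cache. A minimal sketch of invoking the task directly, assuming a running Celery worker and the configuration above:

```python
# Sketch only: triggering the warmup task by hand. Assumes a Celery worker is running
# and SUPERSET_CACHE_WARMUP_USER names an existing account with dashboard access.
from superset.tasks.cache import cache_warmup

async_result = cache_warmup.delay(
    strategy_name="top_n_dashboards",
    top_n=5,
    since="7 days ago",
)

# The task reports which dashboard URLs were warmed and which failed,
# e.g. {"success": [...], "errors": [...]}.
print(async_result.get(timeout=600))
```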
10 changes: 6 additions & 4 deletions superset/utils/screenshots.py
@@ -47,9 +47,11 @@ def __init__(self, url: str, digest: str):
self.url = url
self.screenshot: Optional[bytes] = None

def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy:
def driver(
self, window_size: Optional[WindowSize] = None, user: "User" = None
) -> WebDriverProxy:
window_size = window_size or self.window_size
return WebDriverProxy(self.driver_type, window_size)
return WebDriverProxy(self.driver_type, window_size, user)

def cache_key(
self,
@@ -70,8 +72,8 @@ def cache_key(
def get_screenshot(
self, user: "User", window_size: Optional[WindowSize] = None
) -> Optional[bytes]:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
driver = self.driver(window_size, user)
self.screenshot = driver.get_screenshot(self.url, self.element)
return self.screenshot

def get(
60 changes: 39 additions & 21 deletions superset/utils/webdriver.py
@@ -50,13 +50,29 @@ class DashboardStandaloneMode(Enum):


class WebDriverProxy:
def __init__(self, driver_type: str, window: Optional[WindowSize] = None):
def __init__(
self, driver_type: str, window: Optional[WindowSize] = None, user: "User" = None
):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"]
self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"]

def create(self) -> WebDriver:
self._user = user
self._driver = None

def __del__(self) -> None:
self._destroy()

@property
def driver(self) -> WebDriver:
if not self._driver:
self._driver = self._create()
self._driver.set_window_size(*self._window) # type: ignore
if self._user:
self._auth(self._user)
return self._driver

def _create(self) -> WebDriver:
pixel_density = current_app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
if self._driver_type == "firefox":
driver_class = firefox.webdriver.WebDriver
@@ -83,48 +99,52 @@ def create(self) -> WebDriver:

return driver_class(**kwargs)

def auth(self, user: "User") -> WebDriver:
driver = self.create()
def _auth(self, user: "User") -> WebDriver:
return machine_auth_provider_factory.instance.authenticate_webdriver(
driver, user
self.driver, user
)

@staticmethod
def destroy(driver: WebDriver, tries: int = 2) -> None:
def _destroy(self) -> None:
"""Destroy a driver"""

if not self._driver:
return

# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions

try:
retry_call(driver.close, max_tries=tries)
retry_call(
self._driver.close,
max_tries=current_app.config["SCREENSHOT_SELENIUM_RETRIES"],
)
except Exception: # pylint: disable=broad-except
pass
try:
driver.quit()
self._driver.quit()
except Exception: # pylint: disable=broad-except
pass

def get_screenshot(
self, url: str, element_name: str, user: "User"
) -> Optional[bytes]:
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
self._driver = None

def get_screenshot(self, url: str, element_name: str) -> Optional[bytes]:
self.driver.get(url)
img: Optional[bytes] = None
selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
logger.debug("Sleeping for %i seconds", selenium_headstart)
sleep(selenium_headstart)

try:
logger.debug("Wait for the presence of %s", element_name)
element = WebDriverWait(driver, self._screenshot_locate_wait).until(
element = WebDriverWait(self.driver, self._screenshot_locate_wait).until(
EC.presence_of_element_located((By.CLASS_NAME, element_name))
)
logger.debug("Wait for .loading to be done")
WebDriverWait(driver, self._screenshot_load_wait).until_not(
WebDriverWait(self.driver, self._screenshot_load_wait).until_not(
EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
)
logger.debug("Wait for chart to have content")
WebDriverWait(driver, self._screenshot_locate_wait).until(
WebDriverWait(self.driver, self._screenshot_locate_wait).until(
EC.visibility_of_all_elements_located(
(By.CLASS_NAME, "slice_container")
)
@@ -147,6 +167,4 @@ def get_screenshot(
)
except WebDriverException as ex:
logger.error(ex, exc_info=True)
finally:
self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img
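
`WebDriverProxy` now creates its Selenium driver lazily, authenticates it as the user passed to the constructor on first access of the `driver` property, and tears it down in `__del__` instead of requiring callers to call `destroy()`. A rough usage sketch under those assumptions (the dashboard URL is illustrative):

```python
# Sketch of the refactored proxy as the warmup task uses it. Assumes WEBDRIVER_TYPE is
# configured and that this runs inside a Flask app context so current_app.config resolves.
from superset import app, security_manager
from superset.utils.webdriver import WebDriverProxy

with app.app_context():
    warmup_user = security_manager.find_user(
        username=app.config["SUPERSET_CACHE_WARMUP_USER"]
    )
    proxy = WebDriverProxy(app.config["WEBDRIVER_TYPE"], user=warmup_user)

    # First access to the lazy `driver` property creates the browser, applies the
    # window size, and runs machine auth for `warmup_user`; later calls reuse it.
    png = proxy.get_screenshot(
        "http://localhost:8088/superset/dashboard/1/",  # illustrative URL
        "grid-container",
    )

    # The underlying browser is closed and quit when the proxy is garbage collected.
    del proxy
```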