Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: cache warmup unable to login (#9597, #18933) #20387

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions docs/docs/installation/cache.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,28 @@ defined in `DATA_CACHE_CONFIG`.

## Celery beat

Superset has a Celery task that will periodically warm up the cache based on different strategies.
To use it, set `SUPERSET_CACHE_WARMUP_USER` and add the task to `CELERYBEAT_SCHEDULE` in `config.py`:

```python
# Username of the account the warm-up task logs in as. As the placeholder
# suggests, pick a user that has permission to view the dashboards to warm up.
SUPERSET_CACHE_WARMUP_USER = "user_with_permission_to_dashboards"

# NOTE: `crontab` is Celery's schedule helper
# (`from celery.schedules import crontab`).
CELERYBEAT_SCHEDULE = {
    'cache-warmup-hourly': {
        'task': 'cache-warmup',
        'schedule': crontab(minute=0, hour='*'),  # hourly
        # Keyword arguments forwarded to the warm-up task / chosen strategy.
        'kwargs': {
            'strategy_name': 'top_n_dashboards',
            'top_n': 5,
            'since': '7 days ago',
        },
    },
}
```

This will cache all the charts in the top 5 most popular dashboards every hour. For other
strategies, check the `superset/tasks/cache.py` file.

### Caching Thumbnails

This is an optional feature that can be turned on by activating its feature flag in the config:
Expand Down
3 changes: 3 additions & 0 deletions superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,6 +541,9 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]:
"CACHE_NO_NULL_WARNING": True,
}

# Cache warmup user: username of the account the cache warm-up Celery task
# runs as. The task looks this user up through the security manager and hands
# it to the webdriver that loads each URL to be warmed up.
SUPERSET_CACHE_WARMUP_USER = "admin"

# Time before selenium times out after trying to locate an element on the page and wait
# for that element to load for a screenshot.
SCREENSHOT_LOCATE_WAIT = int(timedelta(seconds=10).total_seconds())
Expand Down
85 changes: 15 additions & 70 deletions superset/tasks/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,73 +14,34 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import json
import logging
from typing import Any, Dict, List, Optional, Union
from urllib import request
from urllib.error import URLError

from celery.utils.log import get_task_logger
from sqlalchemy import and_, func

from superset import app, db
from superset import app, db, security_manager
from superset.extensions import celery_app
from superset.models.core import Log
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.models.tags import Tag, TaggedObject
from superset.utils.date_parser import parse_human_datetime
from superset.views.utils import build_extra_filters
from superset.utils.webdriver import WebDriverProxy

logger = get_task_logger(__name__)
logger.setLevel(logging.INFO)


def get_form_data(
    chart_id: int, dashboard: Optional["Dashboard"] = None
) -> Dict[str, Any]:
    """
    Build `form_data` for chart GET request from dashboard's `default_filters`.

    When a dashboard has `default_filters` they need to be added as extra
    filters in the GET request for charts.

    :param chart_id: ID of the chart the request is built for; always returned
        under the ``slice_id`` key.
    :param dashboard: Dashboard the chart is rendered in. Its ``json_metadata``
        and ``position_json`` attributes are read as JSON strings. ``None``
        (or a dashboard without metadata) yields only the ``slice_id``.
    :returns: ``{"slice_id": chart_id}``, plus an ``extra_filters`` entry when
        the dashboard's default filters produce any for this chart.
    """
    form_data: Dict[str, Any] = {"slice_id": chart_id}

    if dashboard is None or not dashboard.json_metadata:
        return form_data

    json_metadata = json.loads(dashboard.json_metadata)
    # Metadata that is valid JSON but not an object (e.g. "[]") carries no
    # filters; bail out instead of failing on `.get` below.
    if not isinstance(json_metadata, dict):
        return form_data

    # `default_filters` is itself a JSON-encoded string. Fall back to "null"
    # both when the key is missing and when it is present with a null value,
    # since `json.loads(None)` raises TypeError.
    default_filters = json.loads(json_metadata.get("default_filters") or "null")
    if not default_filters:
        return form_data

    filter_scopes = json_metadata.get("filter_scopes", {})
    layout = json.loads(dashboard.position_json or "{}")
    # Only well-formed (dict-shaped) inputs are handed to the filter builder.
    if (
        isinstance(layout, dict)
        and isinstance(filter_scopes, dict)
        and isinstance(default_filters, dict)
    ):
        extra_filters = build_extra_filters(
            layout, filter_scopes, default_filters, chart_id
        )
        if extra_filters:
            form_data["extra_filters"] = extra_filters

    return form_data


def get_url(chart: Slice, extra_filters: Optional[Dict[str, Any]] = None) -> str:
def get_dash_url(dashboard: Dashboard) -> str:
"""Return external URL for warming up a given chart/table cache."""
with app.test_request_context():
baseurl = (
"{SUPERSET_WEBSERVER_PROTOCOL}://"
"{SUPERSET_WEBSERVER_ADDRESS}:"
"{SUPERSET_WEBSERVER_PORT}".format(**app.config)
)
return f"{baseurl}{chart.get_explore_url(overrides=extra_filters)}"
return f"{baseurl}{dashboard.url}"


class Strategy: # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -133,9 +94,11 @@

def get_urls(self) -> List[str]:
session = db.create_scoped_session()
charts = session.query(Slice).all()
dashboards = (

Check warning on line 97 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L97

Added line #L97 was not covered by tests
session.query(Dashboard).filter(Dashboard.published.is_(True)).all()
)

return [get_url(chart) for chart in charts]
return [get_dash_url(dashboard) for dashboard in dashboards if dashboard.slices]

Check warning on line 101 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L101

Added line #L101 was not covered by tests


class TopNDashboardsStrategy(Strategy): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -164,7 +127,6 @@
self.since = parse_human_datetime(since) if since else None

def get_urls(self) -> List[str]:
urls = []
session = db.create_scoped_session()

records = (
Expand All @@ -177,12 +139,8 @@
)
dash_ids = [record.dashboard_id for record in records]
dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids)).all()
for dashboard in dashboards:
for chart in dashboard.slices:
form_data_with_filters = get_form_data(chart.id, dashboard)
urls.append(get_url(chart, form_data_with_filters))

return urls
return [get_dash_url(dashboard) for dashboard in dashboards]


class DashboardTagsStrategy(Strategy): # pylint: disable=too-few-public-methods
Expand Down Expand Up @@ -228,24 +186,7 @@
dash_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_dashboards = session.query(Dashboard).filter(Dashboard.id.in_(dash_ids))
for dashboard in tagged_dashboards:
for chart in dashboard.slices:
urls.append(get_url(chart))

# add charts that are tagged
tagged_objects = (
session.query(TaggedObject)
.filter(
and_(
TaggedObject.object_type == "chart",
TaggedObject.tag_id.in_(tag_ids),
)
)
.all()
)
chart_ids = [tagged_object.object_id for tagged_object in tagged_objects]
tagged_charts = session.query(Slice).filter(Slice.id.in_(chart_ids))
for chart in tagged_charts:
urls.append(get_url(chart))
urls.append(get_dash_url(dashboard))

return urls

Expand Down Expand Up @@ -283,10 +224,14 @@
return message

results: Dict[str, List[str]] = {"success": [], "errors": []}

user = security_manager.find_user(username=app.config["SUPERSET_CACHE_WARMUP_USER"])
wd = WebDriverProxy(app.config["WEBDRIVER_TYPE"], user=user)

Check warning on line 229 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L228-L229

Added lines #L228 - L229 were not covered by tests

for url in strategy.get_urls():
try:
logger.info("Fetching %s", url)
request.urlopen(url) # pylint: disable=consider-using-with
wd.get_screenshot(url, "grid-container")

Check warning on line 234 in superset/tasks/cache.py

View check run for this annotation

Codecov / codecov/patch

superset/tasks/cache.py#L234

Added line #L234 was not covered by tests
results["success"].append(url)
except URLError:
logger.exception("Error warming up cache!")
Expand Down
10 changes: 6 additions & 4 deletions superset/utils/screenshots.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,11 @@
self.url = url
self.screenshot: Optional[bytes] = None

def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy:
def driver(
self, window_size: Optional[WindowSize] = None, user: "User" = None
) -> WebDriverProxy:
window_size = window_size or self.window_size
return WebDriverProxy(self.driver_type, window_size)
return WebDriverProxy(self.driver_type, window_size, user)

Check warning on line 54 in superset/utils/screenshots.py

View check run for this annotation

Codecov / codecov/patch

superset/utils/screenshots.py#L54

Added line #L54 was not covered by tests

def cache_key(
self,
Expand All @@ -70,8 +72,8 @@
def get_screenshot(
self, user: "User", window_size: Optional[WindowSize] = None
) -> Optional[bytes]:
driver = self.driver(window_size)
self.screenshot = driver.get_screenshot(self.url, self.element, user)
driver = self.driver(window_size, user)
self.screenshot = driver.get_screenshot(self.url, self.element)

Check warning on line 76 in superset/utils/screenshots.py

View check run for this annotation

Codecov / codecov/patch

superset/utils/screenshots.py#L75-L76

Added lines #L75 - L76 were not covered by tests
return self.screenshot

def get(
Expand Down
60 changes: 39 additions & 21 deletions superset/utils/webdriver.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,29 @@


class WebDriverProxy:
def __init__(self, driver_type: str, window: Optional[WindowSize] = None):
def __init__(
self, driver_type: str, window: Optional[WindowSize] = None, user: "User" = None
):
self._driver_type = driver_type
self._window: WindowSize = window or (800, 600)
self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"]
self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"]

def create(self) -> WebDriver:
self._user = user
self._driver = None

def __del__(self) -> None:
self._destroy()

@property
def driver(self) -> WebDriver:
if not self._driver:
self._driver = self._create()
self._driver.set_window_size(*self._window) # type: ignore
if self._user:
self._auth(self._user)
return self._driver

def _create(self) -> WebDriver:
pixel_density = current_app.config["WEBDRIVER_WINDOW"].get("pixel_density", 1)
if self._driver_type == "firefox":
driver_class = firefox.webdriver.WebDriver
Expand All @@ -83,48 +99,52 @@

return driver_class(**kwargs)

def auth(self, user: "User") -> WebDriver:
driver = self.create()
def _auth(self, user: "User") -> WebDriver:
return machine_auth_provider_factory.instance.authenticate_webdriver(
driver, user
self.driver, user
)

@staticmethod
def destroy(driver: WebDriver, tries: int = 2) -> None:
def _destroy(self) -> None:
"""Destroy a driver"""

if not self._driver:
return

Check warning on line 111 in superset/utils/webdriver.py

View check run for this annotation

Codecov / codecov/patch

superset/utils/webdriver.py#L111

Added line #L111 was not covered by tests

# This is some very flaky code in selenium. Hence the retries
# and catch-all exceptions

try:
retry_call(driver.close, max_tries=tries)
retry_call(
self._driver.close,
max_tries=current_app.config["SCREENSHOT_SELENIUM_RETRIES"],
)
except Exception: # pylint: disable=broad-except
pass
try:
driver.quit()
self._driver.quit()
except Exception: # pylint: disable=broad-except
pass

def get_screenshot(
self, url: str, element_name: str, user: "User"
) -> Optional[bytes]:
driver = self.auth(user)
driver.set_window_size(*self._window)
driver.get(url)
self._driver = None

def get_screenshot(self, url: str, element_name: str) -> Optional[bytes]:
self.driver.get(url)
img: Optional[bytes] = None
selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
logger.debug("Sleeping for %i seconds", selenium_headstart)
sleep(selenium_headstart)

try:
logger.debug("Wait for the presence of %s", element_name)
element = WebDriverWait(driver, self._screenshot_locate_wait).until(
element = WebDriverWait(self.driver, self._screenshot_locate_wait).until(
EC.presence_of_element_located((By.CLASS_NAME, element_name))
)
logger.debug("Wait for .loading to be done")
WebDriverWait(driver, self._screenshot_load_wait).until_not(
WebDriverWait(self.driver, self._screenshot_load_wait).until_not(
EC.presence_of_all_elements_located((By.CLASS_NAME, "loading"))
)
logger.debug("Wait for chart to have content")
WebDriverWait(driver, self._screenshot_locate_wait).until(
WebDriverWait(self.driver, self._screenshot_locate_wait).until(
EC.visibility_of_all_elements_located(
(By.CLASS_NAME, "slice_container")
)
Expand All @@ -147,6 +167,4 @@
)
except WebDriverException as ex:
logger.error(ex, exc_info=True)
finally:
self.destroy(driver, current_app.config["SCREENSHOT_SELENIUM_RETRIES"])
return img
Loading