diff --git a/UPDATING.md b/UPDATING.md index b65c19d6c712a..a18974bb33317 100644 --- a/UPDATING.md +++ b/UPDATING.md @@ -23,6 +23,8 @@ assists people when migrating to a new version. ## Next +* [10567](https://github.com/apache/incubator-superset/pull/10567): Default WEBDRIVER_OPTION_ARGS are Chrome-specific. If you're using FF, should be `--headless` only + * [10241](https://github.com/apache/incubator-superset/pull/10241): change on Alpha role, users started to have access to "Annotation Layers", "Css Templates" and "Import Dashboards". * [10324](https://github.com/apache/incubator-superset/pull/10324): Facebook Prophet has been introduced as an optional dependency to add support for timeseries forecasting in the chart data API. To enable this feature, install Superset with the optional dependency `prophet` or directly `pip install fbprophet`. diff --git a/scripts/tests/run.sh b/scripts/tests/run.sh index 98206f458745b..95c609a810ffb 100755 --- a/scripts/tests/run.sh +++ b/scripts/tests/run.sh @@ -26,8 +26,8 @@ function reset_db() { echo -------------------- echo Reseting test DB echo -------------------- - docker-compose stop superset-tests-worker - RESET_DB_CMD="psql \"postgresql://superset:superset@127.0.0.1:5432\" <<-EOF + docker-compose stop superset-tests-worker superset || true + RESET_DB_CMD="psql \"postgresql://${DB_USER}:${DB_PASSWORD}@127.0.0.1:5432\" <<-EOF DROP DATABASE IF EXISTS ${DB_NAME}; CREATE DATABASE ${DB_NAME}; \\c ${DB_NAME} @@ -53,10 +53,6 @@ function test_init() { echo Superset init echo -------------------- superset init - echo -------------------- - echo Load examples - echo -------------------- - pytest -s tests/load_examples_test.py } # @@ -142,5 +138,5 @@ fi if [ $RUN_TESTS -eq 1 ] then - pytest -x -s --ignore=load_examples_test "${TEST_MODULE}" + pytest -x -s "${TEST_MODULE}" fi diff --git a/superset/app.py b/superset/app.py index b64ca69c5b64d..11cb004daaa1c 100644 --- a/superset/app.py +++ b/superset/app.py @@ -36,6 +36,7 @@ db, feature_flag_manager, jinja_context_manager, + machine_auth_provider_factory, manifest_processor, migrate, results_backend_manager, @@ -468,6 +469,7 @@ def init_app_in_ctx(self) -> None: self.configure_fab() self.configure_url_map_converters() self.configure_data_sources() + self.configure_auth_provider() # Hook that provides administrators a handle on the Flask APP # after initialization @@ -499,6 +501,9 @@ def init_app(self) -> None: self.post_init() + def configure_auth_provider(self) -> None: + machine_auth_provider_factory.init_app(self.flask_app) + def setup_event_logger(self) -> None: _event_logger["event_logger"] = get_event_logger_from_cfg_value( self.flask_app.config.get("EVENT_LOGGER", DBEventLogger()) diff --git a/superset/config.py b/superset/config.py index cfef8c2af70ba..ff4796d6d455b 100644 --- a/superset/config.py +++ b/superset/config.py @@ -761,6 +761,11 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # * Emails are sent using dry-run mode (logging only) SCHEDULED_EMAIL_DEBUG_MODE = False +# This auth provider is used by background (offline) tasks that need to access +# protected resources. Can be overridden by end users in order to support +# custom auth mechanisms +MACHINE_AUTH_PROVIDER_CLASS = "superset.utils.machine_auth.MachineAuthProvider" + # Email reports - minimum time resolution (in minutes) for the crontab EMAIL_REPORTS_CRON_RESOLUTION = 15 @@ -795,9 +800,22 @@ class CeleryConfig: # pylint: disable=too-few-public-methods # Window size - this will impact the rendering of the data WEBDRIVER_WINDOW = {"dashboard": (1600, 2000), "slice": (3000, 1200)} +# An optional override to the default auth hook used to provide auth to the +# offline webdriver +WEBDRIVER_AUTH_FUNC = None + # Any config options to be passed as-is to the webdriver WEBDRIVER_CONFIGURATION: Dict[Any, Any] = {} +# Additional args to be passed as arguments to the config object +# Note: these options are Chrome-specific. For FF, these should +# only include the "--headless" arg +WEBDRIVER_OPTION_ARGS = [ + "--force-device-scale-factor=2.0", + "--high-dpi-support=2.0", + "--headless", +] + # The base URL to query for accessing the user interface WEBDRIVER_BASEURL = "http://0.0.0.0:8080/" # The base URL for the email report hyperlinks. diff --git a/superset/extensions.py b/superset/extensions.py index 7cafef61a4aad..06d55c8a17247 100644 --- a/superset/extensions.py +++ b/superset/extensions.py @@ -34,6 +34,7 @@ from superset.utils.cache_manager import CacheManager from superset.utils.feature_flag_manager import FeatureFlagManager +from superset.utils.machine_auth import MachineAuthProviderFactory if TYPE_CHECKING: from superset.jinja_context import ( # pylint: disable=unused-import @@ -139,6 +140,7 @@ def get_manifest_files(self, bundle: str, asset_type: str) -> List[str]: event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) feature_flag_manager = FeatureFlagManager() jinja_context_manager = JinjaContextManager() +machine_auth_provider_factory = MachineAuthProviderFactory() manifest_processor = UIManifestProcessor(APP_DIR) migrate = Migrate() results_backend_manager = ResultsBackendManager() diff --git a/superset/tasks/schedules.py b/superset/tasks/schedules.py index 9ebdcfec3170f..c38f261097c17 100644 --- a/superset/tasks/schedules.py +++ b/superset/tasks/schedules.py @@ -28,7 +28,6 @@ Callable, Dict, Iterator, - List, NamedTuple, Optional, Tuple, @@ -42,17 +41,16 @@ import simplejson as json from celery.app.task import Task from dateutil.tz import tzlocal -from flask import current_app, render_template, Response, session, url_for +from flask import current_app, render_template, url_for from flask_babel import gettext as __ -from flask_login import login_user from retry.api import retry_call from selenium.common.exceptions import WebDriverException from selenium.webdriver import chrome, firefox +from selenium.webdriver.remote.webdriver import WebDriver from sqlalchemy.exc import NoSuchColumnError, ResourceClosedError -from werkzeug.http import parse_cookie from superset import app, db, security_manager, thumbnail_cache -from superset.extensions import celery_app +from superset.extensions import celery_app, machine_auth_provider_factory from superset.models.alerts import Alert, AlertLog from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -66,7 +64,7 @@ from superset.sql_parse import ParsedQuery from superset.tasks.slack_util import deliver_slack_msg from superset.utils.core import get_email_address_list, send_email_smtp -from superset.utils.screenshots import ChartScreenshot +from superset.utils.screenshots import ChartScreenshot, WebDriverProxy from superset.utils.urls import get_url_path # pylint: disable=too-few-public-methods @@ -74,6 +72,7 @@ if TYPE_CHECKING: # pylint: disable=unused-import from werkzeug.datastructures import TypeConversionDict + from flask_appbuilder.security.sqla.models import User # Globals @@ -191,27 +190,6 @@ def _generate_report_content( return ReportContent(body, data, images, slack_message, screenshot) -def _get_auth_cookies() -> List["TypeConversionDict[Any, Any]"]: - # Login with the user specified to get the reports - with app.test_request_context(): - user = security_manager.find_user(config["EMAIL_REPORTS_USER"]) - login_user(user) - - # A mock response object to get the cookie information from - response = Response() - app.session_interface.save_session(app, session, response) - - cookies = [] - - # Set the cookies in the driver - for name, value in response.headers: - if name.lower() == "set-cookie": - cookie = parse_cookie(value) - cookies.append(cookie["session"]) - - return cookies - - def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str: with app.test_request_context(): base_url = ( @@ -220,44 +198,14 @@ def _get_url_path(view: str, user_friendly: bool = False, **kwargs: Any) -> str: return urllib.parse.urljoin(str(base_url), url_for(view, **kwargs)) -def create_webdriver() -> Union[ - chrome.webdriver.WebDriver, firefox.webdriver.WebDriver -]: - # Create a webdriver for use in fetching reports - if config["EMAIL_REPORTS_WEBDRIVER"] == "firefox": - driver_class = firefox.webdriver.WebDriver - options = firefox.options.Options() - elif config["EMAIL_REPORTS_WEBDRIVER"] == "chrome": - driver_class = chrome.webdriver.WebDriver - options = chrome.options.Options() - - options.add_argument("--headless") - - # Prepare args for the webdriver init - kwargs = dict(options=options) - kwargs.update(config["WEBDRIVER_CONFIGURATION"]) - - # Initialize the driver - driver = driver_class(**kwargs) - - # Some webdrivers need an initial hit to the welcome URL - # before we set the cookie - welcome_url = _get_url_path("Superset.welcome") - - # Hit the welcome URL and check if we were asked to login - driver.get(welcome_url) - elements = driver.find_elements_by_id("loginbox") - - # This indicates that we were not prompted for a login box. - if not elements: - return driver +def create_webdriver() -> WebDriver: + return WebDriverProxy(driver_type=config["EMAIL_REPORTS_WEBDRIVER"]).auth( + get_reports_user() + ) - # Set the cookies in the driver - for cookie in _get_auth_cookies(): - info = dict(name="session", value=cookie) - driver.add_cookie(info) - return driver +def get_reports_user() -> "User": + return security_manager.find_user(config["EMAIL_REPORTS_USER"]) def destroy_webdriver( @@ -364,12 +312,15 @@ def _get_slice_data(slc: Slice, delivery_type: EmailDeliveryType) -> ReportConte "Superset.slice", slice_id=slc.id, user_friendly=True ) - cookies = {} - for cookie in _get_auth_cookies(): - cookies["session"] = cookie + # Login on behalf of the "reports" user in order to get cookies to deal with auth + auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies( + get_reports_user() + ) + # Build something like "session=cool_sess.val;other-cookie=awesome_other_cookie" + cookie_str = ";".join([f"{key}={val}" for key, val in auth_cookies.items()]) opener = urllib.request.build_opener() - opener.addheaders.append(("Cookie", f"session={cookies['session']}")) + opener.addheaders.append(("Cookie", cookie_str)) response = opener.open(slice_url) if response.getcode() != 200: raise URLError(response.getcode()) diff --git a/superset/tasks/thumbnails.py b/superset/tasks/thumbnails.py index efa704e755dc7..bf7bdc562d0b7 100644 --- a/superset/tasks/thumbnails.py +++ b/superset/tasks/thumbnails.py @@ -18,18 +18,17 @@ """Utility functions used across Superset""" import logging -from typing import Optional, Tuple +from typing import Optional from flask import current_app from superset import app, security_manager, thumbnail_cache from superset.extensions import celery_app from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot +from superset.utils.webdriver import WindowSize logger = logging.getLogger(__name__) -WindowSize = Tuple[int, int] - @celery_app.task(name="cache_chart_thumbnail", soft_time_limit=300) def cache_chart_thumbnail( diff --git a/superset/utils/machine_auth.py b/superset/utils/machine_auth.py new file mode 100644 index 0000000000000..3bc8afa35cdfe --- /dev/null +++ b/superset/utils/machine_auth.py @@ -0,0 +1,113 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import importlib +import logging +from typing import Callable, Dict, TYPE_CHECKING + +from flask import current_app, Flask, request, Response, session +from flask_login import login_user +from selenium.webdriver.remote.webdriver import WebDriver +from werkzeug.http import parse_cookie + +from superset.utils.urls import headless_url + +logger = logging.getLogger(__name__) + +if TYPE_CHECKING: + # pylint: disable=unused-import + from flask_appbuilder.security.sqla.models import User + + +class MachineAuthProvider: + def __init__( + self, auth_webdriver_func_override: Callable[[WebDriver, "User"], WebDriver] + ): + # This is here in order to allow for the authenticate_webdriver func to be + # overridden via config, as opposed to the entire provider implementation + self._auth_webdriver_func_override = auth_webdriver_func_override + + def authenticate_webdriver(self, driver: WebDriver, user: "User",) -> WebDriver: + """ + Default AuthDriverFuncType type that sets a session cookie flask-login style + :return: The WebDriver passed in (fluent) + """ + # Short-circuit this method if we have an override configured + if self._auth_webdriver_func_override: + return self._auth_webdriver_func_override(driver, user) + + # Setting cookies requires doing a request first + driver.get(headless_url("/login/")) + + if user: + cookies = self.get_auth_cookies(user) + elif request.cookies: + cookies = request.cookies + else: + cookies = {} + + for cookie_name, cookie_val in cookies.items(): + driver.add_cookie(dict(name=cookie_name, value=cookie_val)) + + return driver + + @staticmethod + def get_auth_cookies(user: "User") -> Dict[str, str]: + # Login with the user specified to get the reports + with current_app.test_request_context("/login"): + login_user(user) + # A mock response object to get the cookie information from + response = Response() + current_app.session_interface.save_session(current_app, session, response) + + cookies = {} + + # Grab any "set-cookie" headers from the login response + for name, value in response.headers: + if name.lower() == "set-cookie": + # This yields a MultiDict, which is ordered -- something like + # MultiDict([('session', 'value-we-want), ('HttpOnly', ''), etc... + # Therefore, we just need to grab the first tuple and add it to our + # final dict + cookie = parse_cookie(value) + cookie_tuple = list(cookie.items())[0] + cookies[cookie_tuple[0]] = cookie_tuple[1] + + return cookies + + +class MachineAuthProviderFactory: + def __init__(self) -> None: + self._auth_provider = None + + def init_app(self, app: Flask) -> None: + auth_provider_fqclass = app.config["MACHINE_AUTH_PROVIDER_CLASS"] + auth_provider_classname = auth_provider_fqclass[ + auth_provider_fqclass.rfind(".") + 1 : + ] + auth_provider_module_name = auth_provider_fqclass[ + 0 : auth_provider_fqclass.rfind(".") + ] + auth_provider_class = getattr( + importlib.import_module(auth_provider_module_name), auth_provider_classname + ) + + self._auth_provider = auth_provider_class(app.config["WEBDRIVER_AUTH_FUNC"]) + + @property + def instance(self) -> MachineAuthProvider: + return self._auth_provider # type: ignore diff --git a/superset/utils/screenshots.py b/superset/utils/screenshots.py index 68cfd9c08c0fa..9ac2b805a6c8e 100644 --- a/superset/utils/screenshots.py +++ b/superset/utils/screenshots.py @@ -15,23 +15,13 @@ # specific language governing permissions and limitations # under the License. import logging -import time from io import BytesIO -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Optional, TYPE_CHECKING, Union -from flask import current_app, request, Response, session -from flask_login import login_user -from retry.api import retry_call -from selenium.common.exceptions import TimeoutException, WebDriverException -from selenium.webdriver import chrome, firefox -from selenium.webdriver.common.by import By -from selenium.webdriver.remote.webdriver import WebDriver -from selenium.webdriver.support import expected_conditions as EC -from selenium.webdriver.support.ui import WebDriverWait -from werkzeug.http import parse_cookie +from flask import current_app from superset.utils.hashing import md5_sha_from_dict -from superset.utils.urls import headless_url +from superset.utils.webdriver import WebDriverProxy, WindowSize logger = logging.getLogger(__name__) @@ -45,140 +35,6 @@ from flask_appbuilder.security.sqla.models import User from flask_caching import Cache -# Time in seconds, we will wait for the page to load and render -SELENIUM_CHECK_INTERVAL = 2 -SELENIUM_RETRIES = 5 -SELENIUM_HEADSTART = 3 - -WindowSize = Tuple[int, int] - - -def get_auth_cookies(user: "User") -> List[Dict[Any, Any]]: - # Login with the user specified to get the reports - with current_app.test_request_context("/login"): - login_user(user) - # A mock response object to get the cookie information from - response = Response() - current_app.session_interface.save_session(current_app, session, response) - - cookies = [] - - # Set the cookies in the driver - for name, value in response.headers: - if name.lower() == "set-cookie": - cookie = parse_cookie(value) - cookies.append(cookie["session"]) - return cookies - - -def auth_driver(driver: WebDriver, user: "User") -> WebDriver: - """ - Default AuthDriverFuncType type that sets a session cookie flask-login style - :return: WebDriver - """ - if user: - # Set the cookies in the driver - for cookie in get_auth_cookies(user): - info = dict(name="session", value=cookie) - driver.add_cookie(info) - elif request.cookies: - cookies = request.cookies - for k, v in cookies.items(): - cookie = dict(name=k, value=v) - driver.add_cookie(cookie) - return driver - - -class AuthWebDriverProxy: - def __init__( - self, - driver_type: str, - window: Optional[WindowSize] = None, - auth_func: Optional[ - Callable[..., Any] - ] = None, # pylint: disable=bad-whitespace - ): - self._driver_type = driver_type - self._window: WindowSize = window or (800, 600) - config_auth_func = current_app.config.get("WEBDRIVER_AUTH_FUNC", auth_driver) - self._auth_func = auth_func or config_auth_func - - def create(self) -> WebDriver: - if self._driver_type == "firefox": - driver_class = firefox.webdriver.WebDriver - options = firefox.options.Options() - elif self._driver_type == "chrome": - driver_class = chrome.webdriver.WebDriver - options = chrome.options.Options() - arg: str = f"--window-size={self._window[0]},{self._window[1]}" - options.add_argument(arg) - # TODO: 2 lines attempting retina PPI don't seem to be working - options.add_argument("--force-device-scale-factor=2.0") - options.add_argument("--high-dpi-support=2.0") - else: - raise Exception(f"Webdriver name ({self._driver_type}) not supported") - # Prepare args for the webdriver init - options.add_argument("--headless") - kwargs: Dict[Any, Any] = dict(options=options) - kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"]) - logger.info("Init selenium driver") - return driver_class(**kwargs) - - def auth(self, user: "User") -> WebDriver: - # Setting cookies requires doing a request first - driver = self.create() - driver.get(headless_url("/login/")) - return self._auth_func(driver, user) - - @staticmethod - def destroy(driver: WebDriver, tries: int = 2) -> None: - """Destroy a driver""" - # This is some very flaky code in selenium. Hence the retries - # and catch-all exceptions - try: - retry_call(driver.close, tries=tries) - except Exception: # pylint: disable=broad-except - pass - try: - driver.quit() - except Exception: # pylint: disable=broad-except - pass - - def get_screenshot( - self, - url: str, - element_name: str, - user: "User", - retries: int = SELENIUM_RETRIES, - ) -> Optional[bytes]: - driver = self.auth(user) - driver.set_window_size(*self._window) - driver.get(url) - img: Optional[bytes] = None - logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART) - time.sleep(SELENIUM_HEADSTART) - try: - logger.debug("Wait for the presence of %s", element_name) - element = WebDriverWait( - driver, current_app.config["SCREENSHOT_LOCATE_WAIT"] - ).until(EC.presence_of_element_located((By.CLASS_NAME, element_name))) - logger.debug("Wait for .loading to be done") - WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not( - EC.presence_of_all_elements_located((By.CLASS_NAME, "loading")) - ) - logger.info("Taking a PNG screenshot") - img = element.screenshot_as_png - except TimeoutException: - logger.error("Selenium timed out") - except WebDriverException as ex: - logger.error(ex) - # Some webdrivers do not support screenshots for elements. - # In such cases, take a screenshot of the entire page. - img = driver.screenshot() # pylint: disable=no-member - finally: - self.destroy(driver, retries) - return img - class BaseScreenshot: driver_type = current_app.config.get("EMAIL_REPORTS_WEBDRIVER", "chrome") @@ -192,9 +48,9 @@ def __init__(self, url: str, digest: str): self.url = url self.screenshot: Optional[bytes] = None - def driver(self, window_size: Optional[WindowSize] = None) -> AuthWebDriverProxy: + def driver(self, window_size: Optional[WindowSize] = None) -> WebDriverProxy: window_size = window_size or self.window_size - return AuthWebDriverProxy(self.driver_type, window_size) + return WebDriverProxy(self.driver_type, window_size) def cache_key( self, diff --git a/superset/utils/webdriver.py b/superset/utils/webdriver.py new file mode 100644 index 0000000000000..cb8527c47b25a --- /dev/null +++ b/superset/utils/webdriver.py @@ -0,0 +1,131 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +import time +from typing import Any, Dict, Optional, Tuple, TYPE_CHECKING + +from flask import current_app +from retry.api import retry_call +from selenium.common.exceptions import TimeoutException, WebDriverException +from selenium.webdriver import chrome, firefox +from selenium.webdriver.common.by import By +from selenium.webdriver.remote.webdriver import WebDriver +from selenium.webdriver.support import expected_conditions as EC +from selenium.webdriver.support.ui import WebDriverWait + +from superset.extensions import machine_auth_provider_factory + +WindowSize = Tuple[int, int] +logger = logging.getLogger(__name__) + +# Time in seconds, we will wait for the page to load and render +SELENIUM_CHECK_INTERVAL = 2 +SELENIUM_RETRIES = 5 +SELENIUM_HEADSTART = 3 + + +if TYPE_CHECKING: + # pylint: disable=unused-import + from flask_appbuilder.security.sqla.models import User + + +class WebDriverProxy: + def __init__( + self, driver_type: str, window: Optional[WindowSize] = None, + ): + self._driver_type = driver_type + self._window: WindowSize = window or (800, 600) + self._screenshot_locate_wait = current_app.config["SCREENSHOT_LOCATE_WAIT"] + self._screenshot_load_wait = current_app.config["SCREENSHOT_LOAD_WAIT"] + + def create(self) -> WebDriver: + if self._driver_type == "firefox": + driver_class = firefox.webdriver.WebDriver + options = firefox.options.Options() + elif self._driver_type == "chrome": + driver_class = chrome.webdriver.WebDriver + options = chrome.options.Options() + options.add_argument(f"--window-size={self._window[0]},{self._window[1]}") + else: + raise Exception(f"Webdriver name ({self._driver_type}) not supported") + # Prepare args for the webdriver init + + # Add additional configured options + for arg in current_app.config["WEBDRIVER_OPTION_ARGS"]: + options.add_argument(arg) + + kwargs: Dict[Any, Any] = dict(options=options) + kwargs.update(current_app.config["WEBDRIVER_CONFIGURATION"]) + logger.info("Init selenium driver") + + return driver_class(**kwargs) + + def auth(self, user: "User") -> WebDriver: + driver = self.create() + return machine_auth_provider_factory.instance.authenticate_webdriver( + driver, user + ) + + @staticmethod + def destroy(driver: WebDriver, tries: int = 2) -> None: + """Destroy a driver""" + # This is some very flaky code in selenium. Hence the retries + # and catch-all exceptions + try: + retry_call(driver.close, tries=tries) + except Exception: # pylint: disable=broad-except + pass + try: + driver.quit() + except Exception: # pylint: disable=broad-except + pass + + def get_screenshot( + self, + url: str, + element_name: str, + user: "User", + retries: int = SELENIUM_RETRIES, + ) -> Optional[bytes]: + driver = self.auth(user) + driver.set_window_size(*self._window) + driver.get(url) + img: Optional[bytes] = None + logger.debug("Sleeping for %i seconds", SELENIUM_HEADSTART) + time.sleep(SELENIUM_HEADSTART) + try: + logger.debug("Wait for the presence of %s", element_name) + element = WebDriverWait(driver, self._screenshot_locate_wait).until( + EC.presence_of_element_located((By.CLASS_NAME, element_name)) + ) + logger.debug("Wait for .loading to be done") + WebDriverWait(driver, self._screenshot_load_wait).until_not( + EC.presence_of_all_elements_located((By.CLASS_NAME, "loading")) + ) + logger.info("Taking a PNG screenshot or url %s", url) + img = element.screenshot_as_png + except TimeoutException: + logger.error("Selenium timed out requesting url %s", url) + except WebDriverException as ex: + logger.error(ex) + # Some webdrivers do not support screenshots for elements. + # In such cases, take a screenshot of the entire page. + img = driver.screenshot() # pylint: disable=no-member + finally: + self.destroy(driver, retries) + return img diff --git a/tests/base_tests.py b/tests/base_tests.py index 8f708a5df6dd2..8448e08c55841 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -100,6 +100,7 @@ def create_user_with_roles(username: str, roles: List[str]): assert user_to_create user_to_create.roles = [security_manager.find_role(r) for r in roles] db.session.commit() + return user_to_create @staticmethod def create_user( diff --git a/tests/schedules_test.py b/tests/schedules_test.py index 549a0cd98e773..77f70703c3a01 100644 --- a/tests/schedules_test.py +++ b/tests/schedules_test.py @@ -40,8 +40,7 @@ ) from superset.models.slice import Slice from tests.base_tests import SupersetTestCase - -from .utils import read_fixture +from tests.utils import read_fixture class TestSchedules(SupersetTestCase): @@ -172,7 +171,6 @@ def test_create_driver(self, mock_driver_class): mock_driver_class.return_value = mock_driver mock_driver.find_elements_by_id.side_effect = [True, False] - create_webdriver() create_webdriver() mock_driver.add_cookie.assert_called_once() diff --git a/tests/thumbnails_tests.py b/tests/thumbnails_tests.py index 36126e5adb7cf..fb1fd689aa909 100644 --- a/tests/thumbnails_tests.py +++ b/tests/thumbnails_tests.py @@ -16,7 +16,6 @@ # under the License. # from superset import db # from superset.models.dashboard import Dashboard -import subprocess import urllib.request from unittest import skipUnless from unittest.mock import patch @@ -24,15 +23,11 @@ from flask_testing import LiveServerTestCase from sqlalchemy.sql import func -import tests.test_app from superset import db, is_feature_enabled, security_manager, thumbnail_cache +from superset.extensions import machine_auth_provider_factory from superset.models.dashboard import Dashboard from superset.models.slice import Slice -from superset.utils.screenshots import ( - ChartScreenshot, - DashboardScreenshot, - get_auth_cookies, -) +from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot from superset.utils.urls import get_url_path from tests.test_app import app @@ -45,10 +40,7 @@ def create_app(self): def url_open_auth(self, username: str, url: str): admin_user = security_manager.find_user(username=username) - cookies = {} - for cookie in get_auth_cookies(admin_user): - cookies["session"] = cookie - + cookies = machine_auth_provider_factory.instance.get_auth_cookies(admin_user) opener = urllib.request.build_opener() opener.addheaders.append(("Cookie", f"session={cookies['session']}")) return opener.open(f"{self.get_server_url()}/{url}") diff --git a/tests/util/__init__.py b/tests/util/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/tests/util/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/util/machine_auth_tests.py b/tests/util/machine_auth_tests.py new file mode 100644 index 0000000000000..1bc08e8eb5d6c --- /dev/null +++ b/tests/util/machine_auth_tests.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import call, Mock, patch + +from superset.extensions import machine_auth_provider_factory +from tests.base_tests import SupersetTestCase + + +class MachineAuthProviderTests(SupersetTestCase): + def test_get_auth_cookies(self): + user = self.get_user("admin") + auth_cookies = machine_auth_provider_factory.instance.get_auth_cookies(user) + self.assertIsNotNone(auth_cookies["session"]) + + @patch("superset.utils.machine_auth.MachineAuthProvider.get_auth_cookies") + def test_auth_driver_user(self, get_auth_cookies): + user = self.get_user("admin") + driver = Mock() + get_auth_cookies.return_value = { + "session": "session_val", + "other_cookie": "other_val", + } + machine_auth_provider_factory.instance.authenticate_webdriver(driver, user) + driver.add_cookie.assert_has_calls( + [ + call({"name": "session", "value": "session_val"}), + call({"name": "other_cookie", "value": "other_val"}), + ] + ) + + @patch("superset.utils.machine_auth.request") + def test_auth_driver_request(self, request): + driver = Mock() + request.cookies = {"session": "session_val", "other_cookie": "other_val"} + machine_auth_provider_factory.instance.authenticate_webdriver(driver, None) + driver.add_cookie.assert_has_calls( + [ + call({"name": "session", "value": "session_val"}), + call({"name": "other_cookie", "value": "other_val"}), + ] + )