Skip to content

Commit

Permalink
feat: add event_logger to test_connection and create_database commands (
Browse files Browse the repository at this point in the history
#13468)

Co-authored-by: Beto Dealmeida <[email protected]>
  • Loading branch information
hughhhh and betodealmeida authored Mar 9, 2021
1 parent 9b8e255 commit c91c455
Show file tree
Hide file tree
Showing 6 changed files with 136 additions and 21 deletions.
16 changes: 13 additions & 3 deletions superset/databases/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.dao import DatabaseDAO
from superset.extensions import db, security_manager
from superset.extensions import db, event_logger, security_manager

logger = logging.getLogger(__name__)

Expand All @@ -50,8 +50,12 @@ def run(self) -> Model:

try:
TestConnectionDatabaseCommand(self._actor, self._properties).run()
except Exception:
except Exception as ex: # pylint: disable=broad-except
db.session.rollback()
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseConnectionFailedError()

# adding a new database we always want to force refresh schema list
Expand All @@ -63,7 +67,10 @@ def run(self) -> Model:
security_manager.add_permission_view_menu("database_access", database.perm)
db.session.commit()
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseCreateFailedError()
return database

Expand All @@ -84,4 +91,7 @@ def validate(self) -> None:
if exceptions:
exception = DatabaseInvalidError()
exception.add_list(exceptions)
event_logger.log_with_context(
action=f"db_connection_failed.{exception.__class__.__name__}"
)
raise exception
40 changes: 29 additions & 11 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@

from flask_appbuilder.security.sqla.models import User
from flask_babel import gettext as _
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError, NoSuchModuleError

from superset.commands.base import BaseCommand
Expand All @@ -32,6 +31,7 @@
)
from superset.databases.dao import DatabaseDAO
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.models.core import Database

logger = logging.getLogger(__name__)
Expand All @@ -55,24 +55,42 @@ def run(self) -> None:
impersonate_user=self._properties.get("impersonate_user", False),
encrypted_extra=self._properties.get("encrypted_extra", "{}"),
)
if database is not None:
database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)

database.set_sqlalchemy_uri(uri)
database.db_engine_spec.mutate_db_for_connection_test(database)
username = self._actor.username if self._actor is not None else None
engine = database.get_sqla_engine(user_name=username)
with closing(engine.raw_connection()) as conn:
if not engine.dialect.do_ping(conn):
raise DBAPIError(None, None, None)
except (NoSuchModuleError, ModuleNotFoundError):
driver_name = make_url(uri).drivername

except (NoSuchModuleError, ModuleNotFoundError) as ex:
event_logger.log_with_context(
action=f"test_connection_error.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseTestConnectionDriverError(
message=_("Could not load database driver: {}").format(driver_name),
message=_("Could not load database driver: {}").format(
database.db_engine_spec.__name__
),
)
except DBAPIError as ex:
event_logger.log_with_context(
action=f"test_connection_error.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
except DBAPIError:
raise DatabaseTestConnectionFailedError()
except SupersetSecurityException as ex:
event_logger.log_with_context(
action=f"test_connection_error.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseSecurityUnsafeError(message=str(ex))
except Exception:
except Exception as ex: # pylint: disable=broad-except
event_logger.log_with_context(
action=f"test_connection_error.{ex.__class__.__name__}",
engine=database.db_engine_spec.__name__,
)
raise DatabaseTestConnectionUnexpectedError()

def validate(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion superset/databases/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def get_database_by_name(database_name: str) -> Optional[Database]:
@staticmethod
def build_db_for_connection_test(
server_cert: str, extra: str, impersonate_user: bool, encrypted_extra: str
) -> Optional[Database]:
) -> Database:
return Database(
server_cert=server_cert,
extra=extra,
Expand Down
7 changes: 5 additions & 2 deletions superset/utils/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,14 +104,17 @@ def log( # pylint: disable=too-many-arguments
def log_with_context( # pylint: disable=too-many-locals
self,
action: str,
duration: timedelta,
duration: Optional[timedelta] = None,
object_ref: Optional[str] = None,
log_to_statsd: bool = True,
**payload_override: Optional[Dict[str, Any]],
) -> None:
from superset.views.core import get_form_data

referrer = request.referrer[:1000] if request.referrer else None

duration_ms = int(duration.total_seconds() * 1000) if duration else None

try:
user_id = g.user.get_id()
except Exception as ex: # pylint: disable=broad-except
Expand Down Expand Up @@ -158,7 +161,7 @@ def log_with_context( # pylint: disable=too-many-locals
records=records,
dashboard_id=dashboard_id,
slice_id=slice_id,
duration_ms=int(duration.total_seconds() * 1000),
duration_ms=duration_ms,
referrer=referrer,
)

Expand Down
4 changes: 2 additions & 2 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,7 +819,7 @@ def test_test_connection_failed(self):
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": "Could not load database driver: broken",
"message": "Could not load database driver: BaseEngineSpec",
}
self.assertEqual(response, expected_response)

Expand All @@ -834,7 +834,7 @@ def test_test_connection_failed(self):
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": "Could not load database driver: mssql+pymssql",
"message": "Could not load database driver: MssqlEngineSpec",
}
self.assertEqual(response, expected_response)

Expand Down
88 changes: 86 additions & 2 deletions tests/databases/commands_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,30 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=no-self-use, invalid-name
from unittest import mock
from unittest.mock import patch

import pytest
import yaml
from sqlalchemy.exc import DBAPIError

from superset import db, security_manager
from superset import db, event_logger, security_manager
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.commands.exceptions import DatabaseNotFoundError
from superset.databases.commands.exceptions import (
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionFailedError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.commands.export import ExportDatabasesCommand
from superset.databases.commands.importers.v1 import ImportDatabasesCommand
from superset.databases.commands.test_connection import TestConnectionDatabaseCommand
from superset.databases.schemas import DatabaseTestConnectionSchema
from superset.errors import SupersetError
from superset.exceptions import SupersetSecurityException
from superset.models.core import Database
from superset.utils.core import backend, get_example_database
from tests.base_tests import SupersetTestCase
Expand Down Expand Up @@ -508,3 +520,75 @@ def test_import_v1_rollback(self, mock_import_dataset):
# verify that the database was not added
new_num_databases = db.session.query(Database).count()
assert new_num_databases == num_databases


class TestTestConnectionDatabaseCommand(SupersetTestCase):
@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
def test_connection_db_exception(self, mock_event_logger, mock_get_sqla_engine):
"""Test to make sure event_logger is called when an exception is raised"""
database = get_example_database()
mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(
security_manager.find_user("admin"), json_payload
)

with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Unexpected error occurred, please check your logs for details"
)
mock_event_logger.assert_called()

@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
def test_connection_superset_security_connection(
self, mock_event_logger, mock_get_sqla_engine
):
"""Test to make sure event_logger is called when security
connection exc is raised"""
database = get_example_database()
mock_get_sqla_engine.side_effect = SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info", extra={})
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(
security_manager.find_user("admin"), json_payload
)

with pytest.raises(DatabaseSecurityUnsafeError) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == ("Stopped an unsafe database connection")

mock_event_logger.assert_called()

@mock.patch("superset.databases.dao.Database.get_sqla_engine")
@mock.patch(
"superset.databases.commands.test_connection.event_logger.log_with_context"
)
def test_connection_db_api_exc(self, mock_event_logger, mock_get_sqla_engine):
"""Test to make sure event_logger is called when DBAPIError is raised"""
database = get_example_database()
mock_get_sqla_engine.side_effect = DBAPIError(
statement="error", params={}, orig={}
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(
security_manager.find_user("admin"), json_payload
)

with pytest.raises(DatabaseTestConnectionFailedError) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Connection failed, please check your connection settings"
)

mock_event_logger.assert_called()

0 comments on commit c91c455

Please sign in to comment.