diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 89c640b7a6eab..b9cc817402713 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -960,13 +960,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
         if self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
 
-        engine = self.database.get_sqla_engine()
-        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
-        sql = self._apply_cte(sql, cte)
-        sql = self.mutate_query_from_config(sql)
+        with self.database.get_sqla_engine_with_context() as engine:
+            sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
+            sql = self._apply_cte(sql, cte)
+            sql = self.mutate_query_from_config(sql)
 
-        df = pd.read_sql_query(sql=sql, con=engine)
-        return df[column_name].to_list()
+            df = pd.read_sql_query(sql=sql, con=engine)
+            return df[column_name].to_list()
 
     def mutate_query_from_config(self, sql: str) -> str:
         """Apply config's SQL_QUERY_MUTATOR
diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py
index 8151bfd44b03b..05cf8cea13249 100644
--- a/superset/connectors/sqla/utils.py
+++ b/superset/connectors/sqla/utils.py
@@ -112,7 +112,6 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
         )
 
     db_engine_spec = dataset.database.db_engine_spec
-    engine = dataset.database.get_sqla_engine(schema=dataset.schema)
     sql = dataset.get_template_processor().process_template(
         dataset.sql, **dataset.template_params_dict
     )
@@ -137,13 +136,18 @@ def get_virtual_table_metadata(dataset: SqlaTable) -> List[ResultSetColumnType]:
     # TODO(villebro): refactor to use same code that's used by
     # sql_lab.py:execute_sql_statements
     try:
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            cols = result_set.columns
+        with dataset.database.get_sqla_engine_with_context(
+            schema=dataset.schema
+        ) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = dataset.database.apply_limit_to_sql(statements[0], limit=1)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                cols = result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
     return cols
@@ -155,14 +159,17 @@ def get_columns_description(
 ) -> List[ResultSetColumnType]:
     db_engine_spec = database.db_engine_spec
     try:
-        with closing(database.get_sqla_engine().raw_connection()) as conn:
-            cursor = conn.cursor()
-            query = database.apply_limit_to_sql(query, limit=1)
-            cursor.execute(query)
-            db_engine_spec.execute(cursor, query)
-            result = db_engine_spec.fetch_data(cursor, limit=1)
-            result_set = SupersetResultSet(result, cursor.description, db_engine_spec)
-            return result_set.columns
+        with database.get_sqla_engine_with_context() as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                query = database.apply_limit_to_sql(query, limit=1)
+                cursor.execute(query)
+                db_engine_spec.execute(cursor, query)
+                result = db_engine_spec.fetch_data(cursor, limit=1)
+                result_set = SupersetResultSet(
+                    result, cursor.description, db_engine_spec
+                )
+                return result_set.columns
     except Exception as ex:
         raise SupersetGenericDBErrorException(message=str(ex)) from ex
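Every call site in this patch follows the same shape: engine acquisition becomes a `with database.get_sqla_engine_with_context(...) as engine:` block, and all statements that touch the engine move inside it so the engine's lifetime ends when the block exits. A minimal sketch of the pattern, assuming `database` is a `superset.models.core.Database` instance (`qry` and `schema` are illustrative names, not part of the patch):

    import pandas as pd

    with database.get_sqla_engine_with_context(schema=schema) as engine:
        # compile and run the query while the engine is still alive
        sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
        df = pd.read_sql_query(sql=sql, con=engine)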
diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py
index 167b5657fbfdc..3393be67b7f83 100644
--- a/superset/databases/commands/test_connection.py
+++ b/superset/databases/commands/test_connection.py
@@ -90,7 +90,6 @@ def run(self) -> None:  # pylint: disable=too-many-statements
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
 
-            engine = database.get_sqla_engine()
             event_logger.log_with_context(
                 action="test_connection_attempt",
                 engine=database.db_engine_spec.__name__,
@@ -100,31 +99,32 @@ def ping(engine: Engine) -> bool:
                 with closing(engine.raw_connection()) as conn:
                     return engine.dialect.do_ping(conn)
 
-            try:
-                alive = func_timeout(
-                    int(app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds()),
-                    ping,
-                    args=(engine,),
-                )
-            except (sqlite3.ProgrammingError, RuntimeError):
-                # SQLite can't run on a separate thread, so ``func_timeout`` fails
-                # RuntimeError catches the equivalent error from duckdb.
-                alive = engine.dialect.do_ping(engine)
-            except FunctionTimedOut as ex:
-                raise SupersetTimeoutException(
-                    error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
-                    message=(
-                        "Please check your connection details and database settings, "
-                        "and ensure that your database is accepting connections, "
-                        "then try connecting again."
-                    ),
-                    level=ErrorLevel.ERROR,
-                    extra={"sqlalchemy_uri": database.sqlalchemy_uri},
-                ) from ex
-            except Exception as ex:  # pylint: disable=broad-except
-                alive = False
-                # So we stop losing the original message if any
-                ex_str = str(ex)
+            with database.get_sqla_engine_with_context() as engine:
+                try:
+                    alive = func_timeout(
+                        app.config["TEST_DATABASE_CONNECTION_TIMEOUT"].total_seconds(),
+                        ping,
+                        args=(engine,),
+                    )
+                except (sqlite3.ProgrammingError, RuntimeError):
+                    # SQLite can't run on a separate thread, so ``func_timeout`` fails
+                    # RuntimeError catches the equivalent error from duckdb.
+                    alive = engine.dialect.do_ping(engine)
+                except FunctionTimedOut as ex:
+                    raise SupersetTimeoutException(
+                        error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT,
+                        message=(
+                            "Please check your connection details and database settings, "
+                            "and ensure that your database is accepting connections, "
+                            "then try connecting again."
+                        ),
+                        level=ErrorLevel.ERROR,
+                        extra={"sqlalchemy_uri": database.sqlalchemy_uri},
+                    ) from ex
+                except Exception as ex:  # pylint: disable=broad-except
+                    alive = False
+                    # So we stop losing the original message if any
+                    ex_str = str(ex)
 
             if not alive:
                 raise DBAPIError(ex_str or None, None, None)
diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py
index a8956257fa28a..8c58ef5de0bfb 100644
--- a/superset/databases/commands/validate.py
+++ b/superset/databases/commands/validate.py
@@ -101,21 +101,22 @@ def run(self) -> None:
         database.set_sqlalchemy_uri(sqlalchemy_uri)
         database.db_engine_spec.mutate_db_for_connection_test(database)
 
-        engine = database.get_sqla_engine()
-        try:
-            with closing(engine.raw_connection()) as conn:
-                alive = engine.dialect.do_ping(conn)
-        except Exception as ex:
-            url = make_url_safe(sqlalchemy_uri)
-            context = {
-                "hostname": url.host,
-                "password": url.password,
-                "port": url.port,
-                "username": url.username,
-                "database": url.database,
-            }
-            errors = database.db_engine_spec.extract_errors(ex, context)
-            raise DatabaseTestConnectionFailedError(errors) from ex
+        alive = False
+        with database.get_sqla_engine_with_context() as engine:
+            try:
+                with closing(engine.raw_connection()) as conn:
+                    alive = engine.dialect.do_ping(conn)
+            except Exception as ex:
+                url = make_url_safe(sqlalchemy_uri)
+                context = {
+                    "hostname": url.host,
+                    "password": url.password,
+                    "port": url.port,
+                    "username": url.username,
+                    "database": url.database,
+                }
+                errors = database.db_engine_spec.extract_errors(ex, context)
+                raise DatabaseTestConnectionFailedError(errors) from ex
 
         if not alive:
             raise DatabaseOfflineError(
diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py
index ba2b7df26174a..d04763c7a8996 100644
--- a/superset/datasets/commands/importers/v1/utils.py
+++ b/superset/datasets/commands/importers/v1/utils.py
@@ -166,17 +166,26 @@ def load_data(
     if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"):
         logger.info("Loading data inside the import transaction")
         connection = session.connection()
+        df.to_sql(
+            dataset.table_name,
+            con=connection,
+            schema=dataset.schema,
+            if_exists="replace",
+            chunksize=CHUNKSIZE,
+            dtype=dtype,
+            index=False,
+            method="multi",
+        )
     else:
         logger.warning("Loading data outside the import transaction")
-        connection = database.get_sqla_engine()
-
-    df.to_sql(
-        dataset.table_name,
-        con=connection,
-        schema=dataset.schema,
-        if_exists="replace",
-        chunksize=CHUNKSIZE,
-        dtype=dtype,
-        index=False,
-        method="multi",
-    )
+        with database.get_sqla_engine_with_context() as engine:
+            df.to_sql(
+                dataset.table_name,
+                con=engine,
+                schema=dataset.schema,
+                if_exists="replace",
+                chunksize=CHUNKSIZE,
+                dtype=dtype,
+                index=False,
+                method="multi",
+            )
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 2a1363e0b6957..1781ef7ef453d 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -23,6 +23,7 @@
 from typing import (
     Any,
     Callable,
+    ContextManager,
     Dict,
     List,
     Match,
@@ -471,8 +472,16 @@ def get_engine(
         database: "Database",
         schema: Optional[str] = None,
         source: Optional[utils.QuerySource] = None,
-    ) -> Engine:
-        return database.get_sqla_engine(schema=schema, source=source)
+    ) -> ContextManager[Engine]:
+        """
+        Return an engine context manager.
+
+        >>> with DBEngineSpec.get_engine(database, schema, source) as engine:
+        ...     connection = engine.connect()
+        ...     connection.execute(sql)
+
+        """
+        return database.get_sqla_engine_with_context(schema=schema, source=source)
 
     @classmethod
     def get_timestamp_expr(
@@ -894,17 +903,17 @@ def df_to_sql(
         :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
         """
 
-        engine = cls.get_engine(database)
         to_sql_kwargs["name"] = table.table
 
         if table.schema:
             # Only add schema when it is preset and non empty.
             to_sql_kwargs["schema"] = table.schema
 
-        if engine.dialect.supports_multivalues_insert:
-            to_sql_kwargs["method"] = "multi"
+        with cls.get_engine(database) as engine:
+            if engine.dialect.supports_multivalues_insert:
+                to_sql_kwargs["method"] = "multi"
 
-        df.to_sql(con=engine, **to_sql_kwargs)
+            df.to_sql(con=engine, **to_sql_kwargs)
 
     @classmethod
     def convert_dttm(  # pylint: disable=unused-argument
@@ -1277,13 +1286,15 @@ def estimate_query_cost(
         parsed_query = sql_parse.ParsedQuery(sql)
         statements = parsed_query.get_statements()
 
-        engine = cls.get_engine(database, schema=schema, source=source)
         costs = []
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            for statement in statements:
-                processed_statement = cls.process_statement(statement, database)
-                costs.append(cls.estimate_statement_cost(processed_statement, cursor))
+        with cls.get_engine(database, schema=schema, source=source) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                for statement in statements:
+                    processed_statement = cls.process_statement(statement, database)
+                    costs.append(
+                        cls.estimate_statement_cost(processed_statement, cursor)
+                    )
         return costs
 
     @classmethod
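Because `BaseEngineSpec.get_engine` now returns `ContextManager[Engine]` instead of a bare `Engine`, subclasses can no longer read attributes such as `engine.dialect` or `engine.url` off the return value directly; the engine only exists inside the `with` block, as the doctest above shows. A sketch of the consuming side (the `default_schema_for_database` method is hypothetical, for illustration only):

    from sqlalchemy import inspect

    @classmethod
    def default_schema_for_database(cls, database: "Database") -> Optional[str]:
        # the Engine is only valid inside the block yielded by get_engine()
        with cls.get_engine(database) as engine:
            return inspect(engine).default_schema_name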
diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py
index ffeff31b17787..c689e30ec0f80 100644
--- a/superset/db_engine_specs/bigquery.py
+++ b/superset/db_engine_specs/bigquery.py
@@ -340,8 +340,12 @@ def df_to_sql(
         if not table.schema:
             raise Exception("The table schema must be defined")
 
-        engine = cls.get_engine(database)
-        to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host}
+        to_gbq_kwargs = {}
+        with cls.get_engine(database) as engine:
+            to_gbq_kwargs = {
+                "destination_table": str(table),
+                "project_id": engine.url.host,
+            }
 
         # Add credentials if they are set on the SQLAlchemy dialect.
         creds = engine.dialect.credentials_info
diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py
index fd1a2754d76bc..805a7ee400cfd 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -109,11 +109,11 @@ def extra_table_metadata(
         table_name: str,
         schema_name: Optional[str],
     ) -> Dict[str, Any]:
-        engine = cls.get_engine(database, schema=schema_name)
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            cursor.execute(f'SELECT GET_METADATA("{table_name}")')
-            results = cursor.fetchone()[0]
+        with cls.get_engine(database, schema=schema_name) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                cursor.execute(f'SELECT GET_METADATA("{table_name}")')
+                results = cursor.fetchone()[0]
 
         try:
             metadata = json.loads(results)
diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py
index b37348e911ece..3c541c357ea59 100644
--- a/superset/db_engine_specs/hive.py
+++ b/superset/db_engine_specs/hive.py
@@ -185,8 +185,6 @@ def df_to_sql(
         :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method
         """
 
-        engine = cls.get_engine(database)
-
         if to_sql_kwargs["if_exists"] == "append":
             raise SupersetException("Append operation not currently supported")
 
@@ -205,7 +203,8 @@ def df_to_sql(
             if table_exists:
                 raise SupersetException("Table already exists")
         elif to_sql_kwargs["if_exists"] == "replace":
-            engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
+            with cls.get_engine(database) as engine:
+                engine.execute(f"DROP TABLE IF EXISTS {str(table)}")
 
         def _get_hive_type(dtype: np.dtype) -> str:
             hive_type_by_dtype = {
diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py
index e959eb219506a..b513db0a61958 100644
--- a/superset/db_engine_specs/presto.py
+++ b/superset/db_engine_specs/presto.py
@@ -462,12 +462,11 @@ def get_view_names(
         ).strip()
         params = {}
 
-        engine = cls.get_engine(database, schema=schema)
-
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            cursor.execute(sql, params)
-            results = cursor.fetchall()
+        with cls.get_engine(database, schema=schema) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                cursor.execute(sql, params)
+                results = cursor.fetchall()
 
         return sorted([row[0] for row in results])
 
@@ -989,17 +988,17 @@ def get_create_view(
         # pylint: disable=import-outside-toplevel
         from pyhive.exc import DatabaseError
 
-        engine = cls.get_engine(database, schema)
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            sql = f"SHOW CREATE VIEW {schema}.{table}"
-            try:
-                cls.execute(cursor, sql)
-
-            except DatabaseError:  # not a VIEW
-                return None
-            rows = cls.fetch_data(cursor, 1)
-            return rows[0][0]
+        with cls.get_engine(database, schema=schema) as engine:
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                sql = f"SHOW CREATE VIEW {schema}.{table}"
+                try:
+                    cls.execute(cursor, sql)
+
+                except DatabaseError:  # not a VIEW
+                    return None
+                rows = cls.fetch_data(cursor, 1)
+                return rows[0][0]
 
     @classmethod
     def get_tracking_url(cls, cursor: "Cursor") -> Optional[str]:
diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py
index 91257058be75a..5d167b02d0627 100644
--- a/superset/examples/bart_lines.py
+++ b/superset/examples/bart_lines.py
@@ -29,31 +29,31 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
     tbl_name = "bart_lines"
     database = get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("bart-lines.json.gz")
-        df = pd.read_json(url, encoding="latin-1", compression="gzip")
-        df["path_json"] = df.path.map(json.dumps)
-        df["polyline"] = df.path.map(polyline.encode)
-        del df["path"]
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("bart-lines.json.gz")
+            df = pd.read_json(url, encoding="latin-1", compression="gzip")
+            df["path_json"] = df.path.map(json.dumps)
+            df["polyline"] = df.path.map(polyline.encode)
+            del df["path"]
 
-        df.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "color": String(255),
-                "name": String(255),
-                "polyline": Text,
-                "path_json": Text,
-            },
-            index=False,
-        )
+            df.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "color": String(255),
+                    "name": String(255),
+                    "polyline": Text,
+                    "path_json": Text,
+                },
+                index=False,
+            )
 
     print("Creating table {} reference".format(tbl_name))
     table = get_table_connector_registry()
diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py
index f8b8a8ecf7ca8..406a70b2cc4d5 100644
--- a/superset/examples/birth_names.py
+++ b/superset/examples/birth_names.py
@@ -76,25 +76,25 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
     pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
     pdf = pdf.head(100) if sample else pdf
 
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-
-    pdf.to_sql(
-        tbl_name,
-        database.get_sqla_engine(),
-        schema=schema,
-        if_exists="replace",
-        chunksize=500,
-        dtype={
-            # TODO(bkyryliuk): use TIMESTAMP type for presto
-            "ds": DateTime if database.backend != "presto" else String(255),
-            "gender": String(16),
-            "state": String(10),
-            "name": String(255),
-        },
-        method="multi",
-        index=False,
-    )
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+
+        pdf.to_sql(
+            tbl_name,
+            engine,
+            schema=schema,
+            if_exists="replace",
+            chunksize=500,
+            dtype={
+                # TODO(bkyryliuk): use TIMESTAMP type for presto
+                "ds": DateTime if database.backend != "presto" else String(255),
+                "gender": String(16),
+                "state": String(10),
+                "name": String(255),
+            },
+            method="multi",
+            index=False,
+        )
     print("Done loading table!")
     print("-" * 80)
 
@@ -104,8 +104,8 @@ def load_birth_names(
 ) -> None:
     """Loading birth name dataset from a zip file in the repo"""
     database = get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
 
     tbl_name = "birth_names"
     table_exists = database.has_table_by_name(tbl_name, schema=schema)
diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py
index 302b55180ea84..4331033ca8369 100644
--- a/superset/examples/country_map.py
+++ b/superset/examples/country_map.py
@@ -39,38 +39,39 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
     """Loading data for map with country map"""
     tbl_name = "birth_france_by_region"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("birth_france_data_for_country_map.csv")
-        data = pd.read_csv(url, encoding="utf-8")
-        data["dttm"] = datetime.datetime.now().date()
-        data.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "DEPT_ID": String(10),
-                "2003": BigInteger,
-                "2004": BigInteger,
-                "2005": BigInteger,
-                "2006": BigInteger,
-                "2007": BigInteger,
-                "2008": BigInteger,
-                "2009": BigInteger,
-                "2010": BigInteger,
-                "2011": BigInteger,
-                "2012": BigInteger,
-                "2013": BigInteger,
-                "2014": BigInteger,
-                "dttm": Date(),
-            },
-            index=False,
-        )
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
+
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("birth_france_data_for_country_map.csv")
+            data = pd.read_csv(url, encoding="utf-8")
+            data["dttm"] = datetime.datetime.now().date()
+            data.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "DEPT_ID": String(10),
+                    "2003": BigInteger,
+                    "2004": BigInteger,
+                    "2005": BigInteger,
+                    "2006": BigInteger,
+                    "2007": BigInteger,
+                    "2008": BigInteger,
+                    "2009": BigInteger,
+                    "2010": BigInteger,
+                    "2011": BigInteger,
+                    "2012": BigInteger,
+                    "2013": BigInteger,
+                    "2014": BigInteger,
+                    "dttm": Date(),
+                },
+                index=False,
+            )
     print("Done loading table!")
     print("-" * 80)
diff --git a/superset/examples/energy.py b/superset/examples/energy.py
index 72b22525f2760..6688e5d08844d 100644
--- a/superset/examples/energy.py
+++ b/superset/examples/energy.py
@@ -41,24 +41,25 @@ def load_energy(
     """Loads an energy related dataset to use with sankey and graphs"""
     tbl_name = "energy_usage"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("energy.json.gz")
-        pdf = pd.read_json(url, compression="gzip")
-        pdf = pdf.head(100) if sample else pdf
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={"source": String(255), "target": String(255), "value": Float()},
-            index=False,
-            method="multi",
-        )
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
+
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("energy.json.gz")
+            pdf = pd.read_json(url, compression="gzip")
+            pdf = pdf.head(100) if sample else pdf
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={"source": String(255), "target": String(255), "value": Float()},
+                index=False,
+                method="multi",
+            )
 
     print("Creating table [wb_health_population] reference")
     table = get_table_connector_registry()
diff --git a/superset/examples/flights.py b/superset/examples/flights.py
index 1389c65c9a901..7c8f9802988bd 100644
--- a/superset/examples/flights.py
+++ b/superset/examples/flights.py
@@ -27,35 +27,37 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
     """Loading random time series data from a zip file in the repo"""
     tbl_name = "flights"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        flight_data_url = get_example_url("flight_data.csv.gz")
-        pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
+        if not only_metadata and (not table_exists or force):
+            flight_data_url = get_example_url("flight_data.csv.gz")
+            pdf = pd.read_csv(flight_data_url, encoding="latin-1", compression="gzip")
 
-        # Loading airports info to join and get lat/long
-        airports_url = get_example_url("airports.csv.gz")
-        airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
-        airports = airports.set_index("IATA_CODE")
+            # Loading airports info to join and get lat/long
+            airports_url = get_example_url("airports.csv.gz")
+            airports = pd.read_csv(airports_url, encoding="latin-1", compression="gzip")
+            airports = airports.set_index("IATA_CODE")
 
-        pdf[  # pylint: disable=unsupported-assignment-operation,useless-suppression
-            "ds"
-        ] = (pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str))
-        pdf.ds = pd.to_datetime(pdf.ds)
-        pdf.drop(columns=["DAY", "MONTH", "YEAR"])
-        pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
-        pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={"ds": DateTime},
-            index=False,
-        )
+            pdf[  # pylint: disable=unsupported-assignment-operation,useless-suppression
+                "ds"
+            ] = (
+                pdf.YEAR.map(str) + "-0" + pdf.MONTH.map(str) + "-0" + pdf.DAY.map(str)
+            )
+            pdf.ds = pd.to_datetime(pdf.ds)
+            pdf.drop(columns=["DAY", "MONTH", "YEAR"])
+            pdf = pdf.join(airports, on="ORIGIN_AIRPORT", rsuffix="_ORIG")
+            pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={"ds": DateTime},
+                index=False,
+            )
 
     table = get_table_connector_registry()
     tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py
index 76f51a615951f..88b45548f48dc 100644
--- a/superset/examples/long_lat.py
+++ b/superset/examples/long_lat.py
@@ -39,49 +39,51 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
     """Loading lat/long data from a csv file in the repo"""
     tbl_name = "long_lat"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("san_francisco.csv.gz")
-        pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
-        start = datetime.datetime.now().replace(
-            hour=0, minute=0, second=0, microsecond=0
-        )
-        pdf["datetime"] = [
-            start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
-            for i in range(len(pdf))
-        ]
-        pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
-        pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
-        pdf["geohash"] = pdf[["LAT", "LON"]].apply(lambda x: geohash.encode(*x), axis=1)
-        pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "longitude": Float(),
-                "latitude": Float(),
-                "number": Float(),
-                "street": String(100),
-                "unit": String(10),
-                "city": String(50),
-                "district": String(50),
-                "region": String(50),
-                "postcode": Float(),
-                "id": String(100),
-                "datetime": DateTime(),
-                "occupancy": Float(),
-                "radius_miles": Float(),
-                "geohash": String(12),
-                "delimited": String(60),
-            },
-            index=False,
-        )
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("san_francisco.csv.gz")
+            pdf = pd.read_csv(url, encoding="utf-8", compression="gzip")
+            start = datetime.datetime.now().replace(
+                hour=0, minute=0, second=0, microsecond=0
+            )
+            pdf["datetime"] = [
+                start + datetime.timedelta(hours=i * 24 / (len(pdf) - 1))
+                for i in range(len(pdf))
+            ]
+            pdf["occupancy"] = [random.randint(1, 6) for _ in range(len(pdf))]
+            pdf["radius_miles"] = [random.uniform(1, 3) for _ in range(len(pdf))]
+            pdf["geohash"] = pdf[["LAT", "LON"]].apply(
+                lambda x: geohash.encode(*x), axis=1
+            )
+            pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "longitude": Float(),
+                    "latitude": Float(),
+                    "number": Float(),
+                    "street": String(100),
+                    "unit": String(10),
+                    "city": String(50),
+                    "district": String(50),
+                    "region": String(50),
+                    "postcode": Float(),
+                    "id": String(100),
+                    "datetime": DateTime(),
+                    "occupancy": Float(),
+                    "radius_miles": Float(),
+                    "geohash": String(12),
+                    "delimited": String(60),
+                },
+                index=False,
+            )
 
     print("Done loading table!")
     print("-" * 80)
diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py
index 62e16d2cb0881..b030bcdb0f23c 100644
--- a/superset/examples/multiformat_time_series.py
+++ b/superset/examples/multiformat_time_series.py
@@ -39,41 +39,41 @@ def load_multiformat_time_series(  # pylint: disable=too-many-locals
     """Loading time series data from a zip file in the repo"""
     tbl_name = "multiformat_time_series"
    database = get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("multiformat_time_series.json.gz")
-        pdf = pd.read_json(url, compression="gzip")
-        # TODO(bkyryliuk): move load examples data into the pytest fixture
-        if database.backend == "presto":
-            pdf.ds = pd.to_datetime(pdf.ds, unit="s")
-            pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
-            pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
-            pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
-        else:
-            pdf.ds = pd.to_datetime(pdf.ds, unit="s")
-            pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("multiformat_time_series.json.gz")
+            pdf = pd.read_json(url, compression="gzip")
+            # TODO(bkyryliuk): move load examples data into the pytest fixture
+            if database.backend == "presto":
+                pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+                pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d")
+                pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
+                pdf.ds2 = pdf.ds2.dt.strftime("%Y-%m-%d %H:%M%:%S")
+            else:
+                pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+                pdf.ds2 = pd.to_datetime(pdf.ds2, unit="s")
 
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "ds": String(255) if database.backend == "presto" else Date,
-                "ds2": String(255) if database.backend == "presto" else DateTime,
-                "epoch_s": BigInteger,
-                "epoch_ms": BigInteger,
-                "string0": String(100),
-                "string1": String(100),
-                "string2": String(100),
-                "string3": String(100),
-            },
-            index=False,
-        )
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "ds": String(255) if database.backend == "presto" else Date,
+                    "ds2": String(255) if database.backend == "presto" else DateTime,
+                    "epoch_s": BigInteger,
+                    "epoch_ms": BigInteger,
+                    "string0": String(100),
+                    "string1": String(100),
+                    "string2": String(100),
+                    "string3": String(100),
+                },
+                index=False,
+            )
     print("Done loading table!")
     print("-" * 80)
diff --git a/superset/examples/paris.py b/superset/examples/paris.py
index c323007028523..a54a3706b13c0 100644
--- a/superset/examples/paris.py
+++ b/superset/examples/paris.py
@@ -28,29 +28,29 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
     tbl_name = "paris_iris_mapping"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("paris_iris.json.gz")
-        df = pd.read_json(url, compression="gzip")
-        df["features"] = df.features.map(json.dumps)
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("paris_iris.json.gz")
+            df = pd.read_json(url, compression="gzip")
+            df["features"] = df.features.map(json.dumps)
 
-        df.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "color": String(255),
-                "name": String(255),
-                "features": Text,
-                "type": Text,
-            },
-            index=False,
-        )
+            df.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "color": String(255),
+                    "name": String(255),
+                    "features": Text,
+                    "type": Text,
+                },
+                index=False,
+            )
 
     print("Creating table {} reference".format(tbl_name))
     table = get_table_connector_registry()
diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py
index 4a2628df7a074..9a296ec2c4713 100644
--- a/superset/examples/random_time_series.py
+++ b/superset/examples/random_time_series.py
@@ -37,28 +37,28 @@ def load_random_time_series_data(
     """Loading random time series data from a zip file in the repo"""
     tbl_name = "random_time_series"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("random_time_series.json.gz")
-        pdf = pd.read_json(url, compression="gzip")
-        if database.backend == "presto":
-            pdf.ds = pd.to_datetime(pdf.ds, unit="s")
-            pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
-        else:
-            pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("random_time_series.json.gz")
+            pdf = pd.read_json(url, compression="gzip")
+            if database.backend == "presto":
+                pdf.ds = pd.to_datetime(pdf.ds, unit="s")
+                pdf.ds = pdf.ds.dt.strftime("%Y-%m-%d %H:%M%:%S")
+            else:
+                pdf.ds = pd.to_datetime(pdf.ds, unit="s")
 
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={"ds": DateTime if database.backend != "presto" else String(255)},
-            index=False,
-        )
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={"ds": DateTime if database.backend != "presto" else String(255)},
+                index=False,
+            )
     print("Done loading table!")
     print("-" * 80)
diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py
index 71ba34401af92..6011b82b09651 100644
--- a/superset/examples/sf_population_polygons.py
+++ b/superset/examples/sf_population_polygons.py
@@ -30,29 +30,29 @@ def load_sf_population_polygons(
 ) -> None:
     tbl_name = "sf_population_polygons"
     database = database_utils.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("sf_population.json.gz")
-        df = pd.read_json(url, compression="gzip")
-        df["contour"] = df.contour.map(json.dumps)
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("sf_population.json.gz")
+            df = pd.read_json(url, compression="gzip")
+            df["contour"] = df.contour.map(json.dumps)
 
-        df.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=500,
-            dtype={
-                "zipcode": BigInteger,
-                "population": BigInteger,
-                "contour": Text,
-                "area": Float,
-            },
-            index=False,
-        )
+            df.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=500,
+                dtype={
+                    "zipcode": BigInteger,
+                    "population": BigInteger,
+                    "contour": Text,
+                    "area": Float,
+                },
+                index=False,
+            )
 
     print("Creating table {} reference".format(tbl_name))
     table = get_table_connector_registry()
diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py
index aa4f404ccb0fe..551741bf7d17b 100644
--- a/superset/examples/supported_charts_dashboard.py
+++ b/superset/examples/supported_charts_dashboard.py
@@ -453,11 +453,11 @@ def load_supported_charts_dashboard() -> None:
     """Loading a dashboard featuring supported charts"""
 
     database = get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
+    with database.get_sqla_engine_with_context() as engine:
+        schema = inspect(engine).default_schema_name
 
-    tbl_name = "birth_names"
-    table_exists = database.has_table_by_name(tbl_name, schema=schema)
+        tbl_name = "birth_names"
+        table_exists = database.has_table_by_name(tbl_name, schema=schema)
 
     if table_exists:
         table = get_table_connector_registry()
diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py
index 4a18f806eae56..b65ad68d1af62 100644
--- a/superset/examples/world_bank.py
+++ b/superset/examples/world_bank.py
@@ -51,37 +51,38 @@ def load_world_bank_health_n_pop(  # pylint: disable=too-many-locals, too-many-s
     """Loads the world bank health dataset, slices and a dashboard"""
     tbl_name = "wb_health_population"
     database = superset.utils.database.get_example_database()
-    engine = database.get_sqla_engine()
-    schema = inspect(engine).default_schema_name
-    table_exists = database.has_table_by_name(tbl_name)
+    with database.get_sqla_engine_with_context() as engine:
 
-    if not only_metadata and (not table_exists or force):
-        url = get_example_url("countries.json.gz")
-        pdf = pd.read_json(url, compression="gzip")
-        pdf.columns = [col.replace(".", "_") for col in pdf.columns]
-        if database.backend == "presto":
-            pdf.year = pd.to_datetime(pdf.year)
-            pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
-        else:
-            pdf.year = pd.to_datetime(pdf.year)
-        pdf = pdf.head(100) if sample else pdf
+        schema = inspect(engine).default_schema_name
+        table_exists = database.has_table_by_name(tbl_name)
 
-        pdf.to_sql(
-            tbl_name,
-            engine,
-            schema=schema,
-            if_exists="replace",
-            chunksize=50,
-            dtype={
-                # TODO(bkyryliuk): use TIMESTAMP type for presto
-                "year": DateTime if database.backend != "presto" else String(255),
-                "country_code": String(3),
-                "country_name": String(255),
-                "region": String(255),
-            },
-            method="multi",
-            index=False,
-        )
+        if not only_metadata and (not table_exists or force):
+            url = get_example_url("countries.json.gz")
+            pdf = pd.read_json(url, compression="gzip")
+            pdf.columns = [col.replace(".", "_") for col in pdf.columns]
+            if database.backend == "presto":
+                pdf.year = pd.to_datetime(pdf.year)
+                pdf.year = pdf.year.dt.strftime("%Y-%m-%d %H:%M%:%S")
+            else:
+                pdf.year = pd.to_datetime(pdf.year)
+            pdf = pdf.head(100) if sample else pdf
+
+            pdf.to_sql(
+                tbl_name,
+                engine,
+                schema=schema,
+                if_exists="replace",
+                chunksize=50,
+                dtype={
+                    # TODO(bkyryliuk): use TIMESTAMP type for presto
+                    "year": DateTime if database.backend != "presto" else String(255),
+                    "country_code": String(3),
+                    "country_name": String(255),
+                    "region": String(255),
+                },
+                method="multi",
+                index=False,
+            )
 
     print("Creating table [wb_health_population] reference")
     table = get_table_connector_registry()
diff --git a/superset/models/core.py b/superset/models/core.py
index 86b9eb1bde759..020f04f28af02 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -369,12 +369,9 @@ def get_sqla_engine_with_context(
         nullpool: bool = True,
         source: Optional[utils.QuerySource] = None,
     ) -> Engine:
-        try:
-            yield self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
-        except Exception as ex:
-            raise self.db_engine_spec.get_dbapi_mapped_exception(ex)
+        yield self._get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
 
-    def get_sqla_engine(
+    def _get_sqla_engine(
         self,
         schema: Optional[str] = None,
         nullpool: bool = True,
@@ -392,7 +389,7 @@ def get_sqla_engine(
         )
 
         masked_url = self.get_password_masked_url(sqlalchemy_url)
-        logger.debug("Database.get_sqla_engine(). Masked URL: %s", str(masked_url))
+        logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
 
         params = extra.get("engine_params", {})
         if nullpool:
@@ -442,7 +439,7 @@ def get_df(  # pylint: disable=too-many-locals
         mutator: Optional[Callable[[pd.DataFrame], None]] = None,
     ) -> pd.DataFrame:
         sqls = self.db_engine_spec.parse_sql(sql)
-        engine = self.get_sqla_engine(schema)
+        engine = self._get_sqla_engine(schema)
 
         def needs_conversion(df_series: pd.Series) -> bool:
             return (
@@ -487,7 +484,7 @@ def _log_query(sql: str) -> None:
         return df
 
     def compile_sqla_query(self, qry: Select, schema: Optional[str] = None) -> str:
-        engine = self.get_sqla_engine(schema=schema)
+        engine = self._get_sqla_engine(schema=schema)
 
         sql = str(qry.compile(engine, compile_kwargs={"literal_binds": True}))
 
@@ -508,7 +505,7 @@ def select_star(  # pylint: disable=too-many-arguments
         cols: Optional[List[Dict[str, Any]]] = None,
     ) -> str:
         """Generates a ``select *`` statement in the proper dialect"""
-        eng = self.get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
+        eng = self._get_sqla_engine(schema=schema, source=utils.QuerySource.SQL_LAB)
         return self.db_engine_spec.select_star(
             self,
             table_name,
@@ -533,7 +530,7 @@ def safe_sqlalchemy_uri(self) -> str:
 
     @property
     def inspector(self) -> Inspector:
-        engine = self.get_sqla_engine()
+        engine = self._get_sqla_engine()
         return sqla.inspect(engine)
 
     @cache_util.memoized_func(
@@ -674,7 +671,7 @@ def get_table(self, table_name: str, schema: Optional[str] = None) -> Table:
             meta,
             schema=schema or None,
             autoload=True,
-            autoload_with=self.get_sqla_engine(),
+            autoload_with=self._get_sqla_engine(),
         )
 
     def get_table_comment(
@@ -765,11 +762,11 @@ def get_perm(self) -> str:
         return self.perm  # type: ignore
 
     def has_table(self, table: Table) -> bool:
-        engine = self.get_sqla_engine()
+        engine = self._get_sqla_engine()
         return engine.has_table(table.table_name, table.schema or None)
 
     def has_table_by_name(self, table_name: str, schema: Optional[str] = None) -> bool:
-        engine = self.get_sqla_engine()
+        engine = self._get_sqla_engine()
         return engine.has_table(table_name, schema)
 
     @classmethod
@@ -788,7 +785,7 @@ def _has_view(
         return view_name in view_names
 
     def has_view(self, view_name: str, schema: Optional[str] = None) -> bool:
-        engine = self.get_sqla_engine()
+        engine = self._get_sqla_engine()
         return engine.run_callable(self._has_view, engine.dialect, view_name, schema)
 
     def has_view_by_name(self, view_name: str, schema: Optional[str] = None) -> bool:
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 57567e61641c4..a98d76e58162e 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -224,8 +224,9 @@ def charts(self) -> List[str]:
     @property
     def sqla_metadata(self) -> None:
         # pylint: disable=no-member
-        meta = MetaData(bind=self.get_sqla_engine())
-        meta.reflect()
+        with self.get_sqla_engine_with_context() as engine:
+            meta = MetaData(bind=engine)
+            meta.reflect()
 
     @property
     def status(self) -> utils.DashboardStatus:
diff --git a/superset/models/filter_set.py b/superset/models/filter_set.py
index 2d3b218793dcf..4bbef264900d6 100644
--- a/superset/models/filter_set.py
+++ b/superset/models/filter_set.py
@@ -55,8 +55,9 @@ def url(self) -> str:
     @property
     def sqla_metadata(self) -> None:
         # pylint: disable=no-member
-        meta = MetaData(bind=self.get_sqla_engine())
-        meta.reflect()
+        with self.get_sqla_engine_with_context() as engine:
+            meta = MetaData(bind=engine)
+            meta.reflect()
 
     @property
     def changed_by_name(self) -> str:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index cb314de80275c..1aa195fed22f4 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -1281,13 +1281,13 @@ def values_for_column(self, column_name: str, limit: int = 10000) -> List[Any]:
         if limit:
             qry = qry.limit(limit)
 
-        engine = self.database.get_sqla_engine()  # type: ignore
-        sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
-        sql = self._apply_cte(sql, cte)
-        sql = self.mutate_query_from_config(sql)
+        with self.database.get_sqla_engine_with_context() as engine:  # type: ignore
+            sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
+            sql = self._apply_cte(sql, cte)
+            sql = self.mutate_query_from_config(sql)
 
-        df = pd.read_sql_query(sql=sql, con=engine)
-        return df[column_name].to_list()
+            df = pd.read_sql_query(sql=sql, con=engine)
+            return df[column_name].to_list()
 
     def get_timestamp_expression(
         self,
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 96afc7f51ed90..6d9903c8f0009 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -463,61 +463,66 @@ def execute_sql_statements(  # pylint: disable=too-many-arguments, too-many-loca
             )
         )
 
-    engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-    # Sharing a single connection and cursor across the
-    # execution of all statements (if many)
-    with closing(engine.raw_connection()) as conn:
-        # closing the connection closes the cursor as well
-        cursor = conn.cursor()
-        cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
-        if cancel_query_id is not None:
-            query.set_extra_json_key(cancel_query_key, cancel_query_id)
-            session.commit()
-        statement_count = len(statements)
-        for i, statement in enumerate(statements):
-            # Check if stopped
-            session.refresh(query)
-            if query.status == QueryStatus.STOPPED:
-                payload.update({"status": query.status})
-                return payload
-
-            # For CTAS we create the table only on the last statement
-            apply_ctas = query.select_as_cta and (
-                query.ctas_method == CtasMethod.VIEW
-                or (query.ctas_method == CtasMethod.TABLE and i == len(statements) - 1)
-            )
-
-            # Run statement
-            msg = f"Running statement {i+1} out of {statement_count}"
-            logger.info("Query %s: %s", str(query_id), msg)
-            query.set_extra_json_key("progress", msg)
-            session.commit()
-            try:
-                result_set = execute_sql_statement(
-                    statement,
-                    query,
-                    session,
-                    cursor,
-                    log_params,
-                    apply_ctas,
-                )
-            except SqlLabQueryStoppedException:
-                payload.update({"status": QueryStatus.STOPPED})
-                return payload
-            except Exception as ex:  # pylint: disable=broad-except
-                msg = str(ex)
-                prefix_message = (
-                    f"[Statement {i+1} out of {statement_count}]"
-                    if statement_count > 1
-                    else ""
+    with database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        # Sharing a single connection and cursor across the
+        # execution of all statements (if many)
+        with closing(engine.raw_connection()) as conn:
+            # closing the connection closes the cursor as well
+            cursor = conn.cursor()
+            cancel_query_id = db_engine_spec.get_cancel_query_id(cursor, query)
+            if cancel_query_id is not None:
+                query.set_extra_json_key(cancel_query_key, cancel_query_id)
+                session.commit()
+            statement_count = len(statements)
+            for i, statement in enumerate(statements):
+                # Check if stopped
+                session.refresh(query)
+                if query.status == QueryStatus.STOPPED:
+                    payload.update({"status": query.status})
+                    return payload
+
+                # For CTAS we create the table only on the last statement
+                apply_ctas = query.select_as_cta and (
+                    query.ctas_method == CtasMethod.VIEW
+                    or (
+                        query.ctas_method == CtasMethod.TABLE
+                        and i == len(statements) - 1
+                    )
                 )
-                payload = handle_query_error(
-                    ex, query, session, payload, prefix_message
-                )
-                return payload
 
-        # Commit the connection so CTA queries will create the table.
-        conn.commit()
+                # Run statement
+                msg = f"Running statement {i+1} out of {statement_count}"
+                logger.info("Query %s: %s", str(query_id), msg)
+                query.set_extra_json_key("progress", msg)
+                session.commit()
+                try:
+                    result_set = execute_sql_statement(
+                        statement,
+                        query,
+                        session,
+                        cursor,
+                        log_params,
+                        apply_ctas,
+                    )
+                except SqlLabQueryStoppedException:
+                    payload.update({"status": QueryStatus.STOPPED})
+                    return payload
+                except Exception as ex:  # pylint: disable=broad-except
+                    msg = str(ex)
+                    prefix_message = (
+                        f"[Statement {i+1} out of {statement_count}]"
+                        if statement_count > 1
+                        else ""
+                    )
+                    payload = handle_query_error(
+                        ex, query, session, payload, prefix_message
+                    )
+                    return payload
+
+            # Commit the connection so CTA queries will create the table.
+            conn.commit()
 
     # Success, updating the query entry in database
     query.rows = result_set.size
@@ -622,10 +627,11 @@ def cancel_query(query: Query) -> bool:
     if cancel_query_id is None:
         return False
 
-    engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB)
-
-    with closing(engine.raw_connection()) as conn:
-        with closing(conn.cursor()) as cursor:
-            return query.database.db_engine_spec.cancel_query(
-                cursor, query, cancel_query_id
-            )
+    with query.database.get_sqla_engine_with_context(
+        query.schema, source=QuerySource.SQL_LAB
+    ) as engine:
+        with closing(engine.raw_connection()) as conn:
+            with closing(conn.cursor()) as cursor:
+                return query.database.db_engine_spec.cancel_query(
+                    cursor, query, cancel_query_id
+                )
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 70b324c900736..37375e484dec5 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -162,16 +162,18 @@ def validate(
         statements = parsed_query.get_statements()
 
         logger.info("Validating %i statement(s)", len(statements))
-        engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB)
-        # Sharing a single connection and cursor across the
-        # execution of all statements (if many)
-        annotations: List[SQLValidationAnnotation] = []
-        with closing(engine.raw_connection()) as conn:
-            cursor = conn.cursor()
-            for statement in parsed_query.get_statements():
-                annotation = cls.validate_statement(statement, database, cursor)
-                if annotation:
-                    annotations.append(annotation)
-        logger.debug("Validation found %i error(s)", len(annotations))
+        with database.get_sqla_engine_with_context(
+            schema, source=QuerySource.SQL_LAB
+        ) as engine:
+            # Sharing a single connection and cursor across the
+            # execution of all statements (if many)
+            annotations: List[SQLValidationAnnotation] = []
+            with closing(engine.raw_connection()) as conn:
+                cursor = conn.cursor()
+                for statement in parsed_query.get_statements():
+                    annotation = cls.validate_statement(statement, database, cursor)
+                    if annotation:
+                        annotations.append(annotation)
+            logger.debug("Validation found %i error(s)", len(annotations))
 
         return annotations
diff --git a/superset/utils/core.py b/superset/utils/core.py
index cd992250ee7d6..26b9cf0cf0efc 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1278,8 +1278,8 @@ def get_example_default_schema() -> Optional[str]:
     Return the default schema of the examples database, if any.
     """
     database = get_example_database()
-    engine = database.get_sqla_engine()
-    return inspect(engine).default_schema_name
+    with database.get_sqla_engine_with_context() as engine:
+        return inspect(engine).default_schema_name
 
 
 def backend() -> str:
diff --git a/superset/utils/mock_data.py b/superset/utils/mock_data.py
index 904f7ee42e88a..4b156cc10c10d 100644
--- a/superset/utils/mock_data.py
+++ b/superset/utils/mock_data.py
@@ -187,29 +187,29 @@ def add_data(
     database = get_example_database()
     table_exists = database.has_table_by_name(table_name)
-    engine = database.get_sqla_engine()
 
-    if columns is None:
-        if not table_exists:
-            raise Exception(
-                f"The table {table_name} does not exist. To create it you need to "
-                "pass a list of column names and types."
-            )
+    with database.get_sqla_engine_with_context() as engine:
+        if columns is None:
+            if not table_exists:
+                raise Exception(
+                    f"The table {table_name} does not exist. To create it you need to "
+                    "pass a list of column names and types."
+                )
 
-        inspector = inspect(engine)
-        columns = inspector.get_columns(table_name)
+            inspector = inspect(engine)
+            columns = inspector.get_columns(table_name)
 
-    # create table if needed
-    column_objects = get_column_objects(columns)
-    metadata = MetaData()
-    table = Table(table_name, metadata, *column_objects)
-    metadata.create_all(engine)
+        # create table if needed
+        column_objects = get_column_objects(columns)
+        metadata = MetaData()
+        table = Table(table_name, metadata, *column_objects)
+        metadata.create_all(engine)
 
-    if not append:
-        engine.execute(table.delete())
+        if not append:
+            engine.execute(table.delete())
 
-    data = generate_data(columns, num_rows)
-    engine.execute(table.insert(), data)
+        data = generate_data(columns, num_rows)
+        engine.execute(table.insert(), data)
 
 
 def get_column_objects(columns: List[ColumnInfo]) -> List[Column]:
diff --git a/superset/views/core.py b/superset/views/core.py
index cc1865452a856..3252bb92371a7 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -1395,11 +1395,11 @@ def testconn(self) -> FlaskResponse:
             )
             database.set_sqlalchemy_uri(uri)
             database.db_engine_spec.mutate_db_for_connection_test(database)
-            engine = database.get_sqla_engine()
 
-            with closing(engine.raw_connection()) as conn:
-                if engine.dialect.do_ping(conn):
-                    return json_success('"OK"')
+            with database.get_sqla_engine_with_context() as engine:
+                with closing(engine.raw_connection()) as conn:
+                    if engine.dialect.do_ping(conn):
+                        return json_success('"OK"')
 
             raise DBAPIError(None, None, None)
         except CertificateException as ex:
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index efbc6bf7f07d3..8908c3e22782f 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -171,7 +171,7 @@ def __call__(self) -> Database:
         return self._db
 
     def _load_lazy_data_to_decouple_from_session(self) -> None:
-        self._db.get_sqla_engine()  # type: ignore
+        self._db._get_sqla_engine()  # type: ignore
         self._db.backend  # type: ignore
 
     def remove(self) -> None:
@@ -336,37 +336,38 @@ def physical_dataset():
     from superset.connectors.sqla.utils import get_identifier_quoter
 
     example_database = get_example_database()
-    engine = example_database.get_sqla_engine()
-    quoter = get_identifier_quoter(engine.name)
-    # sqlite can only execute one statement at a time
-    engine.execute(
-        f"""
-        CREATE TABLE IF NOT EXISTS physical_dataset(
-          col1 INTEGER,
-          col2 VARCHAR(255),
-          col3 DECIMAL(4,2),
-          col4 VARCHAR(255),
-          col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
-          col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
-          {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
-        );
-        """
-    )
-    engine.execute(
+
+    with example_database.get_sqla_engine_with_context() as engine:
+        quoter = get_identifier_quoter(engine.name)
+        # sqlite can only execute one statement at a time
+        engine.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS physical_dataset(
+              col1 INTEGER,
+              col2 VARCHAR(255),
+              col3 DECIMAL(4,2),
+              col4 VARCHAR(255),
+              col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+              col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
+              {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
+            );
+            """
+        )
+        engine.execute(
+            """
+            INSERT INTO physical_dataset values
+            (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
+            (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
+            (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
+            (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
+            (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
+            (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
+            (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
+            (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
+            (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
+            (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
         """
-        INSERT INTO physical_dataset values
-        (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
-        (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
-        (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
-        (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
-        (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
-        (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
-        (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
-        (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
-        (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
-        (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
-        """
-    )
+        )
 
     dataset = SqlaTable(
         table_name="physical_dataset",
diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py
index 4426fa756ff55..64c9b260c4ab9 100644
--- a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -641,7 +641,7 @@ def test_import_v1_rollback(self, mock_import_dataset):
 
 
 class TestTestConnectionDatabaseCommand(SupersetTestCase):
-    @mock.patch("superset.databases.dao.Database.get_sqla_engine")
+    @mock.patch("superset.databases.dao.Database._get_sqla_engine")
     @mock.patch(
         "superset.databases.commands.test_connection.event_logger.log_with_context"
     )
@@ -664,7 +664,7 @@ def test_connection_db_exception(
         )
         mock_event_logger.assert_called()
 
-    @mock.patch("superset.databases.dao.Database.get_sqla_engine")
+    @mock.patch("superset.databases.dao.Database._get_sqla_engine")
     @mock.patch(
         "superset.databases.commands.test_connection.event_logger.log_with_context"
     )
@@ -713,7 +713,7 @@ def test_connection_do_ping_timeout(
             == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
         )
 
-    @mock.patch("superset.databases.dao.Database.get_sqla_engine")
+    @mock.patch("superset.databases.dao.Database._get_sqla_engine")
     @mock.patch(
         "superset.databases.commands.test_connection.event_logger.log_with_context"
     )
@@ -738,7 +738,7 @@ def test_connection_superset_security_connection(
         mock_event_logger.assert_called()
 
-    @mock.patch("superset.databases.dao.Database.get_sqla_engine")
+    @mock.patch("superset.databases.dao.Database._get_sqla_engine")
     @mock.patch(
         "superset.databases.commands.test_connection.event_logger.log_with_context"
     )
diff --git a/tests/integration_tests/db_engine_specs/bigquery_tests.py b/tests/integration_tests/db_engine_specs/bigquery_tests.py
index 1825f9587cc80..4e09077f5ccee 100644
--- a/tests/integration_tests/db_engine_specs/bigquery_tests.py
+++ b/tests/integration_tests/db_engine_specs/bigquery_tests.py
@@ -227,8 +227,10 @@ def test_df_to_sql(self, mock_get_engine):
             return_value="account_info"
         )
 
-        mock_get_engine.return_value.url.host = "google-host"
-        mock_get_engine.return_value.dialect.credentials_info = "secrets"
+        mock_get_engine.return_value.__enter__.return_value.url.host = "google-host"
+        mock_get_engine.return_value.__enter__.return_value.dialect.credentials_info = (
+            "secrets"
+        )
 
         BigQueryEngineSpec.df_to_sql(
             database=database,
diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py
index 991d3f759c5ee..366648effa988 100644
--- a/tests/integration_tests/db_engine_specs/hive_tests.py
+++ b/tests/integration_tests/db_engine_specs/hive_tests.py
@@ -204,7 +204,9 @@ def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g):
     mock_database = mock.MagicMock()
     mock_database.get_df.return_value.empty = False
     mock_execute = mock.MagicMock(return_value=True)
-    mock_database.get_sqla_engine.return_value.execute = mock_execute
+    mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
+        mock_execute
+    )
     table_name = "foobar"
 
     with app.app_context():
@@ -229,7 +231,9 @@ def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g):
     mock_database = mock.MagicMock()
     mock_database.get_df.return_value.empty = False
     mock_execute = mock.MagicMock(return_value=True)
-    mock_database.get_sqla_engine.return_value.execute = mock_execute
+    mock_database.get_sqla_engine_with_context.return_value.__enter__.return_value.execute = (
+        mock_execute
+    )
     table_name = "foobar"
     schema = "schema"
database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock( return_value=[["a", "b,", "c"], ["d", "e"]] ) + schema = "schema" result = PrestoEngineSpec.get_view_names(database, mock.Mock(), schema) mock_execute.assert_called_once_with( @@ -60,10 +61,10 @@ def test_get_view_names_with_schema(self): def test_get_view_names_without_schema(self): database = mock.MagicMock() mock_execute = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = mock.MagicMock( return_value=[["a", "b,", "c"], ["d", "e"]] ) result = PrestoEngineSpec.get_view_names(database, mock.Mock(), None) @@ -821,13 +822,13 @@ def test_get_create_view(self): mock_execute = mock.MagicMock() mock_fetchall = mock.MagicMock(return_value=[["a", "b,", "c"], ["d", "e"]]) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.fetchall = ( mock_fetchall ) - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.poll.return_value = ( False ) schema = "schema" @@ -839,7 +840,7 @@ def test_get_create_view(self): def test_get_create_view_exception(self): mock_execute = mock.MagicMock(side_effect=Exception()) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) schema = "schema" @@ -852,7 +853,7 @@ def test_get_create_view_database_error(self): mock_execute = mock.MagicMock(side_effect=DatabaseError()) database = mock.MagicMock() - database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value.execute = ( + database.get_sqla_engine_with_context.return_value.__enter__.return_value.raw_connection.return_value.cursor.return_value.execute = ( mock_execute ) schema = "schema" diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index 9368df7614a9f..78178bcde7551 100644 --- a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -51,8 +51,8 @@ def load_unicode_data(): yield with app.app_context(): - engine = get_example_database().get_sqla_engine() - engine.execute("DROP TABLE IF EXISTS unicode_test") + with get_example_database().get_sqla_engine_with_context() as engine: + engine.execute("DROP TABLE IF EXISTS 
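A note on the test hunks above: once production code consumes the engine through a
``with`` statement, a plain ``MagicMock`` has to be configured through the
context-manager protocol, which is why every patched path in these tests gains
``.__enter__.return_value``. A minimal, runnable sketch of the pattern (the names
here are illustrative, not Superset code):

    from unittest import mock

    database = mock.MagicMock()
    # The object bound by ``with ... as engine:`` is the mock's __enter__
    # return value, not the context manager object itself.
    engine = database.get_sqla_engine_with_context.return_value.__enter__.return_value
    engine.execute.return_value = "OK"

    with database.get_sqla_engine_with_context() as eng:
        assert eng.execute("SELECT 1") == "OK"

The same shape explains the inline variant used in sqllab_tests.py further down,
where the chain is spelled ``get_sqla_engine_with_context().__enter__()``.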
diff --git a/tests/integration_tests/fixtures/world_bank_dashboard.py b/tests/integration_tests/fixtures/world_bank_dashboard.py
index e29962a8c9787..561bbe10b2709 100644
--- a/tests/integration_tests/fixtures/world_bank_dashboard.py
+++ b/tests/integration_tests/fixtures/world_bank_dashboard.py
@@ -64,8 +64,8 @@ def load_world_bank_data():
     yield
     with app.app_context():
-        engine = get_example_database().get_sqla_engine()
-        engine.execute("DROP TABLE IF EXISTS wb_health_population")
+        with get_example_database().get_sqla_engine_with_context() as engine:
+            engine.execute("DROP TABLE IF EXISTS wb_health_population")
 
 
 @pytest.fixture()
diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py
index 3e13664b63e36..f187eadfbb27a 100644
--- a/tests/integration_tests/model_tests.py
+++ b/tests/integration_tests/model_tests.py
@@ -164,7 +164,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
             database_name="test_database", sqlalchemy_uri=uri, extra=extra
         )
         model.impersonate_user = True
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
 
         assert str(call_args[0][0]) == "presto://gamma@localhost"
@@ -177,7 +177,7 @@ def test_impersonate_user_presto(self, mocked_create_engine):
         }
 
         model.impersonate_user = False
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
         assert str(call_args[0][0]) == "presto://localhost"
 
@@ -197,7 +197,7 @@ def test_impersonate_user_trino(self, mocked_create_engine):
             database_name="test_database", sqlalchemy_uri="trino://localhost"
         )
         model.impersonate_user = True
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
 
         assert str(call_args[0][0]) == "trino://localhost"
@@ -209,7 +209,7 @@
         )
 
         model.impersonate_user = True
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
 
         assert (
@@ -242,7 +242,7 @@ def test_impersonate_user_hive(self, mocked_create_engine):
             database_name="test_database", sqlalchemy_uri=uri, extra=extra
         )
         model.impersonate_user = True
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
 
         assert str(call_args[0][0]) == "hive://localhost"
@@ -255,7 +255,7 @@ def test_impersonate_user_hive(self, mocked_create_engine):
         }
 
         model.impersonate_user = False
-        model.get_sqla_engine()
+        model._get_sqla_engine()
         call_args = mocked_create_engine.call_args
         assert str(call_args[0][0]) == "hive://localhost"
 
@@ -380,21 +380,7 @@ def test_get_sqla_engine(self, mocked_create_engine):
         )
         mocked_create_engine.side_effect = Exception()
         with self.assertRaises(SupersetException):
-            model.get_sqla_engine()
-
-    # todo(hughhh): update this test
-    # @mock.patch("superset.models.core.create_engine")
-    # def test_get_sqla_engine_with_context(self, mocked_create_engine):
-    #     model = Database(
-    #         database_name="test_database",
-    #         sqlalchemy_uri="mysql://root@localhost",
-    #     )
-    #     model.db_engine_spec.get_dbapi_exception_mapping = mock.Mock(
-    #         return_value={Exception: SupersetException}
-    #     )
-    #     mocked_create_engine.side_effect = Exception()
-    #     with self.assertRaises(SupersetException):
-    #         model.get_sqla_engine()
+            model._get_sqla_engine()
 
 
 class TestSqlaTableModel(SupersetTestCase):
diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py
index ff4c74fa45fba..d2f6e7108d42a 100644
--- a/tests/integration_tests/sql_validator_tests.py
+++ b/tests/integration_tests/sql_validator_tests.py
@@ -174,7 +174,9 @@ class TestPrestoValidator(SupersetTestCase):
     def setUp(self):
         self.validator = PrestoDBSQLValidator
         self.database = MagicMock()
-        self.database_engine = self.database.get_sqla_engine.return_value
+        self.database_engine = (
+            self.database.get_sqla_engine_with_context.return_value.__enter__.return_value
+        )
         self.database_conn = self.database_engine.raw_connection.return_value
         self.database_cursor = self.database_conn.cursor.return_value
         self.database_cursor.poll.return_value = None
diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py
index bee9b08114a40..ed37eece96711 100644
--- a/tests/integration_tests/sqllab_tests.py
+++ b/tests/integration_tests/sqllab_tests.py
@@ -733,7 +733,7 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = False
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
+        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
             mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -786,7 +786,7 @@ def test_execute_sql_statements_no_results_backend(
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = True
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
+        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
            mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
@@ -836,7 +836,7 @@ def test_execute_sql_statements_ctas(
         mock_query = mock.MagicMock()
         mock_query.database.allow_run_async = False
         mock_cursor = mock.MagicMock()
-        mock_query.database.get_sqla_engine.return_value.raw_connection.return_value.cursor.return_value = (
+        mock_query.database.get_sqla_engine_with_context().__enter__().raw_connection().cursor.return_value = (
             mock_cursor
         )
         mock_query.database.db_engine_spec.run_multiple_statements_as_one = False
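Every call-site change in this patch hangs off the new context manager on the
`Database` model. A minimal sketch of the shape it takes, assuming it simply wraps
the now-private `_get_sqla_engine` factory (the actual body in
`superset/models/core.py` may differ; the class below is an illustrative stand-in,
not Superset code):

    from contextlib import contextmanager
    from typing import Any, Iterator, Optional


    class Database:  # illustrative stand-in for superset.models.core.Database
        @contextmanager
        def get_sqla_engine_with_context(
            self, schema: Optional[str] = None
        ) -> Iterator[Any]:
            # Yield the engine built by the now-private factory; funnelling
            # every caller through ``with`` gives the model a single place to
            # attach setup/teardown behaviour later without touching call
            # sites again.
            yield self._get_sqla_engine(schema=schema)

        def _get_sqla_engine(self, schema: Optional[str] = None) -> Any:
            # The real method wraps sqlalchemy.create_engine(); elided here.
            raise NotImplementedError

Callers then hold the engine only for the duration of the ``with`` block, which is
the shape every hunk in this diff converges on.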