From 6b33a0851e9b63a329cca93b992a6e4e76a85a16 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Tue, 20 Apr 2021 14:18:28 +1200 Subject: [PATCH 1/2] fix(hive): Use parquet rather than textfile when uploading CSV files --- superset/db_engine_specs/base.py | 83 ++++-------- superset/db_engine_specs/bigquery.py | 54 +++++--- superset/db_engine_specs/hive.py | 160 +++++++++++------------- superset/views/database/views.py | 113 ++++++++--------- tests/csv_upload_tests.py | 33 ++--- tests/db_engine_specs/bigquery_tests.py | 34 +++-- tests/db_engine_specs/hive_tests.py | 122 ++++-------------- 7 files changed, 238 insertions(+), 361 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 5cdb2a0307059..6d95b73afa5f8 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -613,50 +613,41 @@ def set_or_update_query_limit(cls, sql: str, limit: int) -> str: parsed_query = sql_parse.ParsedQuery(sql) return parsed_query.set_or_update_query_limit(limit) - @staticmethod - def csv_to_df(**kwargs: Any) -> pd.DataFrame: - """Read csv into Pandas DataFrame - :param kwargs: params to be passed to DataFrame.read_csv - :return: Pandas DataFrame containing data from csv - """ - kwargs["encoding"] = "utf-8" - kwargs["iterator"] = True - chunks = pd.read_csv(**kwargs) - df = pd.concat(chunk for chunk in chunks) - return df - - @classmethod - def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None: - """Upload data from a Pandas DataFrame to a database. For - regular engines this calls the DataFrame.to_sql() method. Can be - overridden for engines that don't work well with to_sql(), e.g. - BigQuery. - :param df: Dataframe with data to be uploaded - :param kwargs: kwargs to be passed to to_sql() method - """ - df.to_sql(**kwargs) - @classmethod - def create_table_from_csv( # pylint: disable=too-many-arguments + def df_to_sql( cls, - filename: str, - table: Table, database: "Database", - csv_to_df_kwargs: Dict[str, Any], - df_to_sql_kwargs: Dict[str, Any], + table: Table, + df: pd.DataFrame, + to_sql_kwargs: Dict[str, Any], ) -> None: """ - Create table from contents of a csv. Note: this method does not create - metadata for the table. + Upload data from a Pandas DataFrame to a database. + + For regular engines this calls the `pandas.DataFrame.to_sql` method. Can be + overridden for engines that don't work well with this method, e.g. Hive and + BigQuery. + + Note this method does not create metadata for the table. + + :param database: The database to upload the data to + :param table: The table to upload the data to + :param df: The dataframe with data to be uploaded + :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method """ - df = cls.csv_to_df(filepath_or_buffer=filename, **csv_to_df_kwargs) + engine = cls.get_engine(database) + to_sql_kwargs["name"] = table.table + if table.schema: - # only add schema when it is preset and non empty - df_to_sql_kwargs["schema"] = table.schema + + # Only add schema when it is preset and non empty. 
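# --- Illustrative sketch (assumed names, not taken from the patch): for a typical
# SQLAlchemy dialect the new BaseEngineSpec.df_to_sql contract above reduces to a
# plain pandas DataFrame.to_sql call; the connection URL and table names are assumptions.
import pandas as pd
from sqlalchemy import create_engine

engine = create_engine("postgresql://localhost/example")  # assumed connection
df = pd.DataFrame({"name": ["alice", "bob"], "value": [1, 2]})
df.to_sql(
    name="my_table",      # table.table
    con=engine,
    schema="public",      # only set when table.schema is present and non-empty
    if_exists="replace",  # from to_sql_kwargs
    index=False,
    method="multi",       # added when the dialect supports multi-value INSERTs
)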
+ to_sql_kwargs["schema"] = table.schema + if engine.dialect.supports_multivalues_insert: - df_to_sql_kwargs["method"] = "multi" - cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs) + to_sql_kwargs["method"] = "multi" + + df.to_sql(con=engine, **to_sql_kwargs) @classmethod def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: @@ -669,28 +660,6 @@ def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: """ return None - @classmethod - def create_table_from_excel( # pylint: disable=too-many-arguments - cls, - filename: str, - table: Table, - database: "Database", - excel_to_df_kwargs: Dict[str, Any], - df_to_sql_kwargs: Dict[str, Any], - ) -> None: - """ - Create table from contents of a excel. Note: this method does not create - metadata for the table. - """ - df = pd.read_excel(io=filename, **excel_to_df_kwargs) - engine = cls.get_engine(database) - if table.schema: - # only add schema when it is preset and non empty - df_to_sql_kwargs["schema"] = table.schema - if engine.dialect.supports_multivalues_insert: - df_to_sql_kwargs["method"] = "multi" - cls.df_to_sql(df=df, con=engine, **df_to_sql_kwargs) - @classmethod def get_all_datasource_names( cls, database: "Database", datasource_type: str diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index 59f61d14ce58f..fd34bc9887cb3 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -26,6 +26,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.errors import SupersetErrorType +from superset.sql_parse import Table from superset.utils import core as utils if TYPE_CHECKING: @@ -228,16 +229,26 @@ def epoch_ms_to_dttm(cls) -> str: return "TIMESTAMP_MILLIS({col})" @classmethod - def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None: + def df_to_sql( + cls, + database: "Database", + table: Table, + df: pd.DataFrame, + to_sql_kwargs: Dict[str, Any], + ) -> None: """ - Upload data from a Pandas DataFrame to BigQuery. Calls - `DataFrame.to_gbq()` which requires `pandas_gbq` to be installed. + Upload data from a Pandas DataFrame to a database. - :param df: Dataframe with data to be uploaded - :param kwargs: kwargs to be passed to to_gbq() method. Requires that `schema`, - `name` and `con` are present in kwargs. `name` and `schema` are combined - and passed to `to_gbq()` as `destination_table`. + Calls `pandas_gbq.DataFrame.to_gbq` which requires `pandas_gbq` to be installed. + + Note this method does not create metadata for the table. + + :param database: The database to upload the data to + :param table: The table to upload the data to + :param df: The dataframe with data to be uploaded + :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method """ + try: import pandas_gbq from google.oauth2 import service_account @@ -248,22 +259,25 @@ def df_to_sql(cls, df: pd.DataFrame, **kwargs: Any) -> None: "to upload data to BigQuery" ) - if not ("name" in kwargs and "schema" in kwargs and "con" in kwargs): - raise Exception("name, schema and con need to be defined in kwargs") + if not table.schema: + raise Exception("The table schema must be defined") - gbq_kwargs = {} - gbq_kwargs["project_id"] = kwargs["con"].engine.url.host - gbq_kwargs["destination_table"] = f"{kwargs.pop('schema')}.{kwargs.pop('name')}" + engine = cls.get_engine(database) + to_gbq_kwargs = {"destination_table": str(table), "project_id": engine.url.host} + + # Add credentials if they are set on the SQLAlchemy dialect. 
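# --- Illustrative sketch of the pandas_gbq call this spec builds towards; the
# project id and destination table below are assumptions (the spec derives them
# from the SQLAlchemy engine URL host and str(table)), and credentials are built
# from engine.dialect.credentials_info when that is set.
import pandas as pd
import pandas_gbq

df = pd.DataFrame({"name": ["alice"], "value": [1]})
pandas_gbq.to_gbq(
    df,
    destination_table="my_schema.my_table",  # str(table), i.e. "<schema>.<table>"
    project_id="my-project",                 # engine.url.host
    if_exists="replace",                     # the only to_sql_kwargs key passed through
    # credentials=service_account.Credentials.from_service_account_info(creds),
)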
+ creds = engine.dialect.credentials_info - # add credentials if they are set on the SQLAlchemy Dialect: - creds = kwargs["con"].dialect.credentials_info if creds: - credentials = service_account.Credentials.from_service_account_info(creds) - gbq_kwargs["credentials"] = credentials + to_gbq_kwargs[ + "credentials" + ] = service_account.Credentials.from_service_account_info(creds) - # Only pass through supported kwargs + # Only pass through supported kwargs. supported_kwarg_keys = {"if_exists"} + for key in supported_kwarg_keys: - if key in kwargs: - gbq_kwargs[key] = kwargs[key] - pandas_gbq.to_gbq(df, **gbq_kwargs) + if key in to_sql_kwargs: + to_gbq_kwargs[key] = to_sql_kwargs[key] + + pandas_gbq.to_gbq(df, **to_gbq_kwargs) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 4234ddc63992b..0bbf7159ecbaf 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -17,12 +17,16 @@ import logging import os import re +import tempfile import time from datetime import datetime from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING from urllib import parse +import numpy as np import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq from flask import g from sqlalchemy import Column, text from sqlalchemy.engine.base import Engine @@ -54,6 +58,15 @@ def upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: + """ + Upload the file to S3. + + :param filename: The file to upload + :param upload_prefix: The S3 prefix + :param table: The table that will be created + :returns: The S3 location of the table + """ + # Optional dependency import boto3 # pylint: disable=import-error @@ -152,89 +165,37 @@ def fetch_data( return [] @classmethod - def get_create_table_stmt( # pylint: disable=too-many-arguments - cls, - table: Table, - schema_definition: str, - location: str, - delim: str, - header_line_count: Optional[int], - null_values: Optional[List[str]], - ) -> text: - tblproperties = [] - # available options: - # https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - # TODO(bkyryliuk): figure out what to do with the skip rows field. 
- params: Dict[str, str] = { - "delim": delim, - "location": location, - } - if header_line_count is not None and header_line_count >= 0: - header_line_count += 1 - tblproperties.append("'skip.header.line.count'=:header_line_count") - params["header_line_count"] = str(header_line_count) - if null_values: - # hive only supports 1 value for the null format - tblproperties.append("'serialization.null.format'=:null_value") - params["null_value"] = null_values[0] - - if tblproperties: - tblproperties_stmt = f"tblproperties ({', '.join(tblproperties)})" - sql = f"""CREATE TABLE {str(table)} ( {schema_definition} ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location - {tblproperties_stmt}""" - else: - sql = f"""CREATE TABLE {str(table)} ( {schema_definition} ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location""" - return sql, params - - @classmethod - def create_table_from_csv( # pylint: disable=too-many-arguments, too-many-locals + def df_to_sql( cls, - filename: str, - table: Table, database: "Database", - csv_to_df_kwargs: Dict[str, Any], - df_to_sql_kwargs: Dict[str, Any], + table: Table, + df: pd.DataFrame, + to_sql_kwargs: Dict[str, Any], ) -> None: - """Uploads a csv file and creates a superset datasource in Hive.""" - if_exists = df_to_sql_kwargs["if_exists"] - if if_exists == "append": - raise SupersetException("Append operation not currently supported") + """ + Upload data from a Pandas DataFrame to a database. - def convert_to_hive_type(col_type: str) -> str: - """maps tableschema's types to hive types""" - tableschema_to_hive_types = { - "boolean": "BOOLEAN", - "integer": "BIGINT", - "number": "DOUBLE", - "string": "STRING", - } - return tableschema_to_hive_types.get(col_type, "STRING") + The data is stored via the binary Parquet format which is both less problematic + and more performant than a text file. More specifically storing a table as a + CSV text file has severe limitations including the fact that the Hive CSV SerDe + does not support multiline fields. - upload_prefix = config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]( - database, g.user, table.schema - ) + Note this method does not create metadata for the table. - # Optional dependency - from tableschema import ( # pylint: disable=import-error - Table as TableSchemaTable, - ) + :param database: The database to upload the data to + :param: table The table to upload the data to + :param df: The dataframe with data to be uploaded + :param to_sql_kwargs: The kwargs to be passed to pandas.DataFrame.to_sql` method + """ - hive_table_schema = TableSchemaTable(filename).infer() - column_name_and_type = [] - for column_info in hive_table_schema["fields"]: - column_name_and_type.append( - "`{}` {}".format( - column_info["name"], convert_to_hive_type(column_info["type"]) - ) - ) - schema_definition = ", ".join(column_name_and_type) + engine = cls.get_engine(database) + + if to_sql_kwargs["if_exists"] == "append": + raise SupersetException("Append operation not currently supported") - # ensure table doesn't already exist - if if_exists == "fail": + if to_sql_kwargs["if_exists"] == "fail": + + # Ensure table doesn't already exist. 
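# --- Illustrative sketch of the Parquet round-trip the new Hive upload path relies
# on: pyarrow serialises the DataFrame to a single Parquet file that Hive can read
# directly via STORED AS PARQUET, so multiline string fields survive intact
# (the Hive CSV SerDe cannot represent them). The file path is an assumption.
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq

df = pd.DataFrame({"name": ["alice", "bob"], "bio": ["line one\nline two", "text"]})
pq.write_table(pa.Table.from_pandas(df), where="/tmp/example.parquet")

# Reading the file back shows the embedded newline is preserved.
assert pq.read_table("/tmp/example.parquet").to_pandas().loc[0, "bio"] == "line one\nline two"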
if table.schema: table_exists = not database.get_df( f"SHOW TABLES IN {table.schema} LIKE '{table.table}'" @@ -243,24 +204,47 @@ def convert_to_hive_type(col_type: str) -> str: table_exists = not database.get_df( f"SHOW TABLES LIKE '{table.table}'" ).empty + if table_exists: raise SupersetException("Table already exists") + elif to_sql_kwargs["if_exists"] == "replace": + engine.execute(f"DROP TABLE IF EXISTS {str(table)}") - engine = cls.get_engine(database) + def _get_hive_type(dtype: np.dtype) -> str: + hive_type_by_dtype = { + np.dtype("bool"): "BOOLEAN", + np.dtype("float64"): "DOUBLE", + np.dtype("int64"): "BIGINT", + np.dtype("object"): "STRING", + } - if if_exists == "replace": - engine.execute(f"DROP TABLE IF EXISTS {str(table)}") - location = upload_to_s3(filename, upload_prefix, table) - sql, params = cls.get_create_table_stmt( - table, - schema_definition, - location, - csv_to_df_kwargs["sep"].encode().decode("unicode_escape"), - int(csv_to_df_kwargs.get("header", 0)), - csv_to_df_kwargs.get("na_values"), + return hive_type_by_dtype.get(dtype, "STRING") + + schema_definition = ", ".join( + f"`{name}` {_get_hive_type(dtype)}" for name, dtype in df.dtypes.items() ) - engine = cls.get_engine(database) - engine.execute(text(sql), **params) + + with tempfile.NamedTemporaryFile( + dir=config["UPLOAD_FOLDER"], suffix=".parquet" + ) as file: + pq.write_table(pa.Table.from_pandas(df), where=file.name) + + engine.execute( + text( + f""" + CREATE TABLE {str(table)} ({schema_definition}) + STORED AS PARQUET + LOCATION :location + """ + ), + location=upload_to_s3( + filename=file.name, + upload_prefix=config["CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC"]( + database, g.user, table.schema + ), + table=table, + ), + ) @classmethod def convert_dttm(cls, target_type: str, dttm: datetime) -> Optional[str]: diff --git a/superset/views/database/views.py b/superset/views/database/views.py index 3a68f32588d99..fb670c9f9924e 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -18,6 +18,7 @@ import tempfile from typing import TYPE_CHECKING +import pandas as pd from flask import flash, g, redirect from flask_appbuilder import expose, SimpleFormView from flask_appbuilder.models.sqla.interface import SQLAInterface @@ -156,7 +157,7 @@ def form_post(self, form: CsvToDatabaseForm) -> Response: ).name try: - utils.ensure_path_exists(config["UPLOAD_FOLDER"]) + utils.ensure_path_exists(app.config["UPLOAD_FOLDER"]) upload_stream_write(form.csv_file.data, uploaded_tmp_file_path) con = form.data.get("con") @@ -164,40 +165,37 @@ def form_post(self, form: CsvToDatabaseForm) -> Response: db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) - # More can be found here: - # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html - csv_to_df_kwargs = { - "sep": form.sep.data, - "header": form.header.data if form.header.data else 0, - "index_col": form.index_col.data, - "mangle_dupe_cols": form.mangle_dupe_cols.data, - "skipinitialspace": form.skipinitialspace.data, - "skiprows": form.skiprows.data, - "nrows": form.nrows.data, - "skip_blank_lines": form.skip_blank_lines.data, - "parse_dates": form.parse_dates.data, - "infer_datetime_format": form.infer_datetime_format.data, - "chunksize": 1000, - } - if form.null_values.data: - csv_to_df_kwargs["na_values"] = form.null_values.data - csv_to_df_kwargs["keep_default_na"] = False - - # More can be found here: - # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.to_sql.html - 
df_to_sql_kwargs = { - "name": csv_table.table, - "if_exists": form.if_exists.data, - "index": form.index.data, - "index_label": form.index_label.data, - "chunksize": 1000, - } - database.db_engine_spec.create_table_from_csv( + chunks = pd.read_csv( uploaded_tmp_file_path, - csv_table, + chunksize=1000, + encoding="utf-8", + header=form.header.data if form.header.data else 0, + index_col=form.index_col.data, + infer_datetime_format=form.infer_datetime_format.data, + iterator=True, + keep_default_na=not form.null_values.data, + mangle_dupe_cols=form.mangle_dupe_cols.data, + na_values=form.null_values.data if form.null_values.data else None, + nrows=form.nrows.data, + parse_dates=form.parse_dates.data, + sep=form.sep.data, + skip_blank_lines=form.skip_blank_lines.data, + skipinitialspace=form.skipinitialspace.data, + skiprows=form.skiprows.data, + ) + + df = pd.concat(chunks) + + database.db_engine_spec.df_to_sql( database, - csv_to_df_kwargs, - df_to_sql_kwargs, + csv_table, + df, + to_sql_kwargs={ + "chunksize": 1000, + "if_exists": form.if_exists.data, + "index": form.index.data, + "index_label": form.index_label.data, + }, ) # Connect table to the database that should be used for exploration. @@ -321,35 +319,28 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: db.session.query(models.Database).filter_by(id=con.data.get("id")).one() ) - # some params are not supported by pandas.read_excel (e.g. chunksize). - # More can be found here: - # https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_excel.html - excel_to_df_kwargs = { - "header": form.header.data if form.header.data else 0, - "index_col": form.index_col.data, - "mangle_dupe_cols": form.mangle_dupe_cols.data, - "skiprows": form.skiprows.data, - "nrows": form.nrows.data, - "sheet_name": form.sheet_name.data if form.sheet_name.data else 0, - "parse_dates": form.parse_dates.data, - } - if form.null_values.data: - excel_to_df_kwargs["na_values"] = form.null_values.data - excel_to_df_kwargs["keep_default_na"] = False - - df_to_sql_kwargs = { - "name": excel_table.table, - "if_exists": form.if_exists.data, - "index": form.index.data, - "index_label": form.index_label.data, - "chunksize": 1000, - } - database.db_engine_spec.create_table_from_excel( - uploaded_tmp_file_path, - excel_table, + df = pd.read_excel( + header=form.header.data if form.header.data else 0, + index_col=form.index_col.data, + io=uploaded_tmp_file_path, + keep_default_na=not form.null_values.data, + mangle_dupe_cols=form.mangle_dupe_cols.data, + na_values=form.null_values.data if form.null_values.data else None, + parse_dates=form.parse_dates.data, + skiprows=form.skiprows.data, + sheet_name=form.sheet_name.data if form.sheet_name.data else 0, + ) + + database.db_engine_spec.df_to_sql( database, - excel_to_df_kwargs, - df_to_sql_kwargs, + excel_table, + df, + to_sql_kwargs={ + "chunksize": 1000, + "if_exists": form.if_exists.data, + "index": form.index.data, + "index_label": form.index_label.data, + }, ) # Connect table to the database that should be used for exploration. diff --git a/tests/csv_upload_tests.py b/tests/csv_upload_tests.py index 229a74f17ec15..c12ff44343784 100644 --- a/tests/csv_upload_tests.py +++ b/tests/csv_upload_tests.py @@ -134,13 +134,14 @@ def upload_excel( return get_resp(test_client, "/exceltodatabaseview/form", data=form_data) -def mock_upload_to_s3(f: str, p: str, t: Table) -> str: - """ HDFS is used instead of S3 for the unit tests. 
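# --- Illustrative sketch of the chunked CSV read used in the views.py change above:
# pandas parses the upload in fixed-size chunks and pd.concat stitches them back into
# a single DataFrame before it is handed to df_to_sql. The file path is an assumption.
import pandas as pd

chunks = pd.read_csv(
    "/tmp/upload.csv",  # assumed path; the view reads the uploaded temporary file
    chunksize=1000,
    encoding="utf-8",
    iterator=True,
)
df = pd.concat(chunks)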
+def mock_upload_to_s3(filename: str, upload_prefix: str, table: Table) -> str: + """ + HDFS is used instead of S3 for the unit tests. - :param f: filepath - :param p: unused parameter - :param t: table that will be created - :return: hdfs path to the directory with external table files + :param filename: The file to upload + :param upload_prefix: The S3 prefix + :param table: The table that will be created + :returns: The HDFS path to the directory with external table files """ # only needed for the hive tests import docker @@ -148,11 +149,11 @@ def mock_upload_to_s3(f: str, p: str, t: Table) -> str: client = docker.from_env() container = client.containers.get("namenode") # docker mounted volume that contains csv uploads - src = os.path.join("/tmp/superset_uploads", os.path.basename(f)) + src = os.path.join("/tmp/superset_uploads", os.path.basename(filename)) # hdfs destination for the external tables - dest_dir = os.path.join("/tmp/external/superset_uploads/", str(t)) + dest_dir = os.path.join("/tmp/external/superset_uploads/", str(table)) container.exec_run(f"hdfs dfs -mkdir -p {dest_dir}") - dest = os.path.join(dest_dir, os.path.basename(f)) + dest = os.path.join(dest_dir, os.path.basename(filename)) container.exec_run(f"hdfs dfs -put {src} {dest}") # hive external table expectes a directory for the location return dest_dir @@ -279,23 +280,13 @@ def test_import_csv(setup_csv_upload, create_csv_files): # make sure that john and empty string are replaced with None engine = get_upload_db().get_sqla_engine() data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() - if utils.backend() == "hive": - # Be aware that hive only uses first value from the null values list. - # It is hive database engine limitation. - # TODO(bkyryliuk): preprocess csv file for hive upload to match default engine capabilities. - assert data == [("john", 1, "x"), ("paul", 2, None)] - else: - assert data == [(None, 1, "x"), ("paul", 2, None)] + assert data == [(None, 1, "x"), ("paul", 2, None)] # default null values upload_csv(CSV_FILENAME2, CSV_UPLOAD_TABLE, extra={"if_exists": "replace"}) # make sure that john and empty string are replaced with None data = engine.execute(f"SELECT * from {CSV_UPLOAD_TABLE}").fetchall() - if utils.backend() == "hive": - # By default hive does not convert values to null vs other databases. 
- assert data == [("john", 1, "x"), ("paul", 2, "")] - else: - assert data == [("john", 1, "x"), ("paul", 2, None)] + assert data == [("john", 1, "x"), ("paul", 2, None)] @mock.patch("superset.db_engine_specs.hive.upload_to_s3", mock_upload_to_s3) diff --git a/tests/db_engine_specs/bigquery_tests.py b/tests/db_engine_specs/bigquery_tests.py index d0150076087c5..81a9f064429f8 100644 --- a/tests/db_engine_specs/bigquery_tests.py +++ b/tests/db_engine_specs/bigquery_tests.py @@ -23,6 +23,7 @@ from superset.db_engine_specs.base import BaseEngineSpec from superset.db_engine_specs.bigquery import BigQueryEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.sql_parse import Table from tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -166,21 +167,23 @@ def test_normalize_indexes(self): [{"name": "partition", "column_names": ["dttm"], "unique": False}], ) - def test_df_to_sql(self): + @mock.patch("superset.db_engine_specs.bigquery.BigQueryEngineSpec.get_engine") + def test_df_to_sql(self, mock_get_engine): """ DB Eng Specs (bigquery): Test DataFrame to SQL contract """ # test missing google.oauth2 dependency sys.modules["pandas_gbq"] = mock.MagicMock() df = DataFrame() + database = mock.MagicMock() self.assertRaisesRegexp( Exception, "Could not import libraries", BigQueryEngineSpec.df_to_sql, - df, - con="some_connection", - schema="schema", - name="name", + database=database, + table=Table(table="name", schema="schema"), + df=df, + to_sql_kwargs={}, ) invalid_kwargs = [ @@ -191,15 +194,17 @@ def test_df_to_sql(self): {"name": "some_name", "schema": "some_schema"}, {"con": "some_con", "schema": "some_schema"}, ] - # Test check for missing required kwargs (name, schema, con) + # Test check for missing schema. 
sys.modules["google.oauth2"] = mock.MagicMock() for invalid_kwarg in invalid_kwargs: self.assertRaisesRegexp( Exception, - "name, schema and con need to be defined in kwargs", + "The table schema must be defined", BigQueryEngineSpec.df_to_sql, - df, - **invalid_kwarg, + database=database, + table=Table(table="name"), + df=df, + to_sql_kwargs=invalid_kwarg, ) import pandas_gbq @@ -209,12 +214,15 @@ def test_df_to_sql(self): service_account.Credentials.from_service_account_info = mock.MagicMock( return_value="account_info" ) - connection = mock.Mock() - connection.engine.url.host = "google-host" - connection.dialect.credentials_info = "secrets" + + mock_get_engine.return_value.url.host = "google-host" + mock_get_engine.return_value.dialect.credentials_info = "secrets" BigQueryEngineSpec.df_to_sql( - df, con=connection, schema="schema", name="name", if_exists="extra_key" + database=database, + table=Table(table="name", schema="schema"), + df=df, + to_sql_kwargs={"if_exists": "extra_key"}, ) pandas_gbq.to_gbq.assert_called_with( diff --git a/tests/db_engine_specs/hive_tests.py b/tests/db_engine_specs/hive_tests.py index 71595a0469bf7..b0f7aee7db949 100644 --- a/tests/db_engine_specs/hive_tests.py +++ b/tests/db_engine_specs/hive_tests.py @@ -163,11 +163,10 @@ def test_convert_dttm(): ) -def test_create_table_from_csv_append() -> None: - +def test_df_to_csv() -> None: with pytest.raises(SupersetException): - HiveEngineSpec.create_table_from_csv( - "foo.csv", Table("foobar"), mock.MagicMock(), {}, {"if_exists": "append"} + HiveEngineSpec.df_to_sql( + mock.MagicMock(), Table("foobar"), pd.DataFrame(), {"if_exists": "append"}, ) @@ -176,15 +175,13 @@ def test_create_table_from_csv_append() -> None: {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True}, ) @mock.patch("superset.db_engine_specs.hive.g", spec={}) -@mock.patch("tableschema.Table") -def test_create_table_from_csv_if_exists_fail(mock_table, mock_g): - mock_table.infer.return_value = {} +def test_df_to_sql_if_exists_fail(mock_g): mock_g.user = True mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False with pytest.raises(SupersetException, match="Table already exists"): - HiveEngineSpec.create_table_from_csv( - "foo.csv", Table("foobar"), mock_database, {}, {"if_exists": "fail"} + HiveEngineSpec.df_to_sql( + mock_database, Table("foobar"), pd.DataFrame(), {"if_exists": "fail"} ) @@ -193,18 +190,15 @@ def test_create_table_from_csv_if_exists_fail(mock_table, mock_g): {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True}, ) @mock.patch("superset.db_engine_specs.hive.g", spec={}) -@mock.patch("tableschema.Table") -def test_create_table_from_csv_if_exists_fail_with_schema(mock_table, mock_g): - mock_table.infer.return_value = {} +def test_df_to_sql_if_exists_fail_with_schema(mock_g): mock_g.user = True mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False with pytest.raises(SupersetException, match="Table already exists"): - HiveEngineSpec.create_table_from_csv( - "foo.csv", - Table(table="foobar", schema="schema"), + HiveEngineSpec.df_to_sql( mock_database, - {}, + Table(table="foobar", schema="schema"), + pd.DataFrame(), {"if_exists": "fail"}, ) @@ -214,11 +208,9 @@ def test_create_table_from_csv_if_exists_fail_with_schema(mock_table, mock_g): {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True}, ) @mock.patch("superset.db_engine_specs.hive.g", spec={}) -@mock.patch("tableschema.Table") 
@mock.patch("superset.db_engine_specs.hive.upload_to_s3") -def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table, mock_g): +def test_df_to_sql_if_exists_replace(mock_upload_to_s3, mock_g): mock_upload_to_s3.return_value = "mock-location" - mock_table.infer.return_value = {} mock_g.user = True mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False @@ -226,12 +218,11 @@ def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table, mock_database.get_sqla_engine.return_value.execute = mock_execute table_name = "foobar" - HiveEngineSpec.create_table_from_csv( - "foo.csv", - Table(table=table_name), + HiveEngineSpec.df_to_sql( mock_database, - {"sep": "mock", "header": 1, "na_values": "mock"}, - {"if_exists": "replace"}, + Table(table=table_name), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, ) mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {table_name}") @@ -242,13 +233,9 @@ def test_create_table_from_csv_if_exists_replace(mock_upload_to_s3, mock_table, {**app.config, "CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC": lambda *args: True}, ) @mock.patch("superset.db_engine_specs.hive.g", spec={}) -@mock.patch("tableschema.Table") @mock.patch("superset.db_engine_specs.hive.upload_to_s3") -def test_create_table_from_csv_if_exists_replace_with_schema( - mock_upload_to_s3, mock_table, mock_g -): +def test_df_to_sql_if_exists_replace_with_schema(mock_upload_to_s3, mock_g): mock_upload_to_s3.return_value = "mock-location" - mock_table.infer.return_value = {} mock_g.user = True mock_database = mock.MagicMock() mock_database.get_df.return_value.empty = False @@ -256,84 +243,17 @@ def test_create_table_from_csv_if_exists_replace_with_schema( mock_database.get_sqla_engine.return_value.execute = mock_execute table_name = "foobar" schema = "schema" - HiveEngineSpec.create_table_from_csv( - "foo.csv", - Table(table=table_name, schema=schema), + + HiveEngineSpec.df_to_sql( mock_database, - {"sep": "mock", "header": 1, "na_values": "mock"}, - {"if_exists": "replace"}, + Table(table=table_name, schema=schema), + pd.DataFrame(), + {"if_exists": "replace", "header": 1, "na_values": "mock", "sep": "mock"}, ) mock_execute.assert_any_call(f"DROP TABLE IF EXISTS {schema}.{table_name}") -def test_get_create_table_stmt() -> None: - table = Table("employee") - schema_def = """eid int, name String, salary String, destination String""" - location = "s3a://directory/table" - from unittest import TestCase - - assert HiveEngineSpec.get_create_table_stmt( - table, schema_def, location, ",", 0, [""] - ) == ( - """CREATE TABLE employee ( eid int, name String, salary String, destination String ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location - tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""", - { - "delim": ",", - "location": "s3a://directory/table", - "header_line_count": "1", - "null_value": "", - }, - ) - assert HiveEngineSpec.get_create_table_stmt( - table, schema_def, location, ",", 1, ["1", "2"] - ) == ( - """CREATE TABLE employee ( eid int, name String, salary String, destination String ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location - tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""", - { - "delim": ",", - "location": "s3a://directory/table", - "header_line_count": "2", - "null_value": "1", - }, - ) - assert 
HiveEngineSpec.get_create_table_stmt( - table, schema_def, location, ",", 100, ["NaN"] - ) == ( - """CREATE TABLE employee ( eid int, name String, salary String, destination String ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location - tblproperties ('skip.header.line.count'=:header_line_count, 'serialization.null.format'=:null_value)""", - { - "delim": ",", - "location": "s3a://directory/table", - "header_line_count": "101", - "null_value": "NaN", - }, - ) - assert HiveEngineSpec.get_create_table_stmt( - table, schema_def, location, ",", None, None - ) == ( - """CREATE TABLE employee ( eid int, name String, salary String, destination String ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location""", - {"delim": ",", "location": "s3a://directory/table"}, - ) - assert HiveEngineSpec.get_create_table_stmt( - table, schema_def, location, ",", 100, [] - ) == ( - """CREATE TABLE employee ( eid int, name String, salary String, destination String ) - ROW FORMAT DELIMITED FIELDS TERMINATED BY :delim - STORED AS TEXTFILE LOCATION :location - tblproperties ('skip.header.line.count'=:header_line_count)""", - {"delim": ",", "location": "s3a://directory/table", "header_line_count": "101"}, - ) - - def test_is_readonly(): def is_readonly(sql: str) -> bool: return HiveEngineSpec.is_readonly_query(ParsedQuery(sql)) From 9f9b26e992244e29eab2fe0130778fbddb6fb854 Mon Sep 17 00:00:00 2001 From: John Bodley Date: Sat, 24 Apr 2021 13:34:41 +1200 Subject: [PATCH 2/2] [csv/excel]: Use stream rather than temporary file --- superset/views/database/views.py | 77 +++++++++++++------------------- 1 file changed, 30 insertions(+), 47 deletions(-) diff --git a/superset/views/database/views.py b/superset/views/database/views.py index fb670c9f9924e..e3c3f9283ff92 100644 --- a/superset/views/database/views.py +++ b/superset/views/database/views.py @@ -150,42 +150,34 @@ def form_post(self, form: CsvToDatabaseForm) -> Response: flash(message, "danger") return redirect("/csvtodatabaseview/form") - uploaded_tmp_file_path = tempfile.NamedTemporaryFile( - dir=app.config["UPLOAD_FOLDER"], - suffix=os.path.splitext(form.csv_file.data.filename)[1].lower(), - delete=False, - ).name - try: - utils.ensure_path_exists(app.config["UPLOAD_FOLDER"]) - upload_stream_write(form.csv_file.data, uploaded_tmp_file_path) - - con = form.data.get("con") - database = ( - db.session.query(models.Database).filter_by(id=con.data.get("id")).one() + df = pd.concat( + pd.read_csv( + chunksize=1000, + encoding="utf-8", + filepath_or_buffer=form.csv_file.data, + header=form.header.data if form.header.data else 0, + index_col=form.index_col.data, + infer_datetime_format=form.infer_datetime_format.data, + iterator=True, + keep_default_na=not form.null_values.data, + mangle_dupe_cols=form.mangle_dupe_cols.data, + na_values=form.null_values.data if form.null_values.data else None, + nrows=form.nrows.data, + parse_dates=form.parse_dates.data, + sep=form.sep.data, + skip_blank_lines=form.skip_blank_lines.data, + skipinitialspace=form.skipinitialspace.data, + skiprows=form.skiprows.data, + ) ) - chunks = pd.read_csv( - uploaded_tmp_file_path, - chunksize=1000, - encoding="utf-8", - header=form.header.data if form.header.data else 0, - index_col=form.index_col.data, - infer_datetime_format=form.infer_datetime_format.data, - iterator=True, - keep_default_na=not form.null_values.data, - mangle_dupe_cols=form.mangle_dupe_cols.data, - na_values=form.null_values.data if form.null_values.data 
else None, - nrows=form.nrows.data, - parse_dates=form.parse_dates.data, - sep=form.sep.data, - skip_blank_lines=form.skip_blank_lines.data, - skipinitialspace=form.skipinitialspace.data, - skiprows=form.skiprows.data, + database = ( + db.session.query(models.Database) + .filter_by(id=form.data.get("con").data.get("id")) + .one() ) - df = pd.concat(chunks) - database.db_engine_spec.df_to_sql( database, csv_table, @@ -234,10 +226,6 @@ def form_post(self, form: CsvToDatabaseForm) -> Response: db.session.commit() except Exception as ex: # pylint: disable=broad-except db.session.rollback() - try: - os.remove(uploaded_tmp_file_path) - except OSError: - pass message = _( 'Unable to upload CSV file "%(filename)s" to table ' '"%(table_name)s" in database "%(db_name)s". ' @@ -252,7 +240,6 @@ def form_post(self, form: CsvToDatabaseForm) -> Response: stats_logger.incr("failed_csv_upload") return redirect("/csvtodatabaseview/form") - os.remove(uploaded_tmp_file_path) # Go back to welcome page / splash screen message = _( 'CSV file "%(csv_filename)s" uploaded to table "%(table_name)s" in ' @@ -314,15 +301,10 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: utils.ensure_path_exists(config["UPLOAD_FOLDER"]) upload_stream_write(form.excel_file.data, uploaded_tmp_file_path) - con = form.data.get("con") - database = ( - db.session.query(models.Database).filter_by(id=con.data.get("id")).one() - ) - df = pd.read_excel( header=form.header.data if form.header.data else 0, index_col=form.index_col.data, - io=uploaded_tmp_file_path, + io=form.excel_file.data, keep_default_na=not form.null_values.data, mangle_dupe_cols=form.mangle_dupe_cols.data, na_values=form.null_values.data if form.null_values.data else None, @@ -331,6 +313,12 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: sheet_name=form.sheet_name.data if form.sheet_name.data else 0, ) + database = ( + db.session.query(models.Database) + .filter_by(id=form.data.get("con").data.get("id")) + .one() + ) + database.db_engine_spec.df_to_sql( database, excel_table, @@ -379,10 +367,6 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: db.session.commit() except Exception as ex: # pylint: disable=broad-except db.session.rollback() - try: - os.remove(uploaded_tmp_file_path) - except OSError: - pass message = _( 'Unable to upload Excel file "%(filename)s" to table ' '"%(table_name)s" in database "%(db_name)s". ' @@ -397,7 +381,6 @@ def form_post(self, form: ExcelToDatabaseForm) -> Response: stats_logger.incr("failed_excel_upload") return redirect("/exceltodatabaseview/form") - os.remove(uploaded_tmp_file_path) # Go back to welcome page / splash screen message = _( 'Excel file "%(excel_filename)s" uploaded to table "%(table_name)s" in '