diff --git a/superset/queries/saved_queries/api.py b/superset/queries/saved_queries/api.py index e223a9805edf2..aef2288422cb5 100644 --- a/superset/queries/saved_queries/api.py +++ b/superset/queries/saved_queries/api.py @@ -21,14 +21,16 @@ from typing import Any from zipfile import ZipFile -from flask import g, Response, send_file, request +from flask import g, request, Response, send_file from flask_appbuilder.api import expose, protect, rison, safe from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext -from marshmallow import ValidationError +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod from superset.databases.filters import DatabaseFilter +from superset.extensions import event_logger from superset.models.sql_lab import SavedQuery from superset.queries.saved_queries.commands.bulk_delete import ( BulkDeleteSavedQueryCommand, @@ -36,12 +38,11 @@ from superset.queries.saved_queries.commands.exceptions import ( SavedQueryBulkDeleteFailedError, SavedQueryNotFoundError, - SavedQueryImportError, - SavedQueryImportError, - SavedQueryInvalidError, ) from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand -from superset.queries.saved_queries.commands.importers.dispatcher import ImportSavedQueriesCommand +from superset.queries.saved_queries.commands.importers.dispatcher import ( + ImportSavedQueriesCommand, +) from superset.queries.saved_queries.filters import ( SavedQueryAllTextFilter, SavedQueryFavoriteFilter, @@ -52,9 +53,6 @@ get_export_ids_schema, openapi_spec_methods_override, ) -from superset.commands.exceptions import CommandInvalidError -from superset.commands.importers.v1.utils import get_contents_from_bundle -from superset.extensions import event_logger from superset.views.base_api import BaseSupersetModelRestApi, statsd_metrics logger = logging.getLogger(__name__) @@ -262,6 +260,7 @@ def export(self, **kwargs: Any) -> Response: as_attachment=True, attachment_filename=filename, ) + @expose("/import/", methods=["POST"]) @protect() @safe @@ -271,7 +270,7 @@ def export(self, **kwargs: Any) -> Response: log_to_statsd=False, ) def import_(self) -> Response: - """Import Saved Queries with associated datasets and databases + """Import Saved Queries with associated databases --- post: requestBody: @@ -289,7 +288,7 @@ def import_(self) -> Response: description: JSON map of passwords for each file type: string overwrite: - description: overwrite existing databases? + description: overwrite existing saved queries? type: bool responses: 200: diff --git a/superset/queries/saved_queries/commands/exceptions.py b/superset/queries/saved_queries/commands/exceptions.py index a69deddd69b4a..731857352444d 100644 --- a/superset/queries/saved_queries/commands/exceptions.py +++ b/superset/queries/saved_queries/commands/exceptions.py @@ -17,10 +17,10 @@ from flask_babel import lazy_gettext as _ from superset.commands.exceptions import ( - CommandException, + CommandException, CommandInvalidError, DeleteFailedError, - ImportFailedError + ImportFailedError, ) @@ -31,9 +31,10 @@ class SavedQueryBulkDeleteFailedError(DeleteFailedError): class SavedQueryNotFoundError(CommandException): message = _("Saved query not found.") + class SavedQueryImportError(ImportFailedError): message = _("Import saved query failed for an unknown reason.") + class SavedQueryInvalidError(CommandInvalidError): message = _("Saved query parameters are invalid.") - \ No newline at end of file diff --git a/superset/queries/saved_queries/commands/importers/dispatcher.py b/superset/queries/saved_queries/commands/importers/dispatcher.py index 14d1fcae62622..a53a765e790b6 100644 --- a/superset/queries/saved_queries/commands/importers/dispatcher.py +++ b/superset/queries/saved_queries/commands/importers/dispatcher.py @@ -30,6 +30,8 @@ command_versions = [ v1.ImportSavedQueriesCommand, ] + + class ImportSavedQueriesCommand(BaseCommand): """ Import Saved Queries @@ -54,7 +56,7 @@ def run(self) -> None: return except IncorrectVersionError: logger.debug("File not handled by command, skipping") - except(CommandInvalidError, ValidationError) as exc: + except (CommandInvalidError, ValidationError) as exc: # found right version, but file is invalid logger.exception("Error running import command") raise exc diff --git a/superset/queries/saved_queries/commands/importers/v1/__init__.py b/superset/queries/saved_queries/commands/importers/v1/__init__.py index c2f8a40b46c4a..41c475c41b237 100644 --- a/superset/queries/saved_queries/commands/importers/v1/__init__.py +++ b/superset/queries/saved_queries/commands/importers/v1/__init__.py @@ -20,22 +20,25 @@ from marshmallow import Schema from sqlalchemy.orm import Session -from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError from superset.commands.importers.v1 import ImportModelsCommand from superset.connectors.sqla.models import SqlaTable from superset.databases.commands.importers.v1.utils import import_database from superset.datasets.commands.importers.v1.utils import import_dataset from superset.datasets.schemas import ImportV1DatasetSchema +from superset.queries.saved_queries.commands.exceptions import SavedQueryImportError +from superset.queries.saved_queries.commands.importers.v1.utils import ( + import_saved_query, +) from superset.queries.saved_queries.dao import SavedQueryDAO -from superset.queries.saved_queries.commands.importers.v1.utils import import_saved_query from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema + class ImportSavedQueriesCommand(ImportModelsCommand): """Import Saved Queries""" dao = SavedQueryDAO - model_name= "saved_queries" - prefix ="saved_queries/" + model_name = "saved_queries" + prefix = "queries/" schemas: Dict[str, Schema] = { "datasets/": ImportV1DatasetSchema(), "queries/": ImportV1SavedQuerySchema(), @@ -59,17 +62,11 @@ def _import( database = import_database(session, config, overwrite=False) database_ids[str(database.uuid)] = database.id - # import saved queries with the correct parent ref for file_name, config in configs.items(): - if file_name.startswith("queries/") and config["database_uuid"] in database: - # update datasource id, type, and name - database = database[config["dataset_uuid"]] - config.update( - { - "datasource_id": database.id, - "datasource_name": database.table_name, - } - ) - config["params"].update({"datasource": database.uid}) + if ( + file_name.startswith("queries/") + and config["database_uuid"] in database_ids + ): + config["db_id"] = database_ids[config["database_uuid"]] import_saved_query(session, config, overwrite=overwrite) diff --git a/superset/queries/saved_queries/commands/importers/v1/utils.py b/superset/queries/saved_queries/commands/importers/v1/utils.py index 651a0b5df8532..f2d090bf11e5b 100644 --- a/superset/queries/saved_queries/commands/importers/v1/utils.py +++ b/superset/queries/saved_queries/commands/importers/v1/utils.py @@ -21,10 +21,11 @@ from superset.models.sql_lab import SavedQuery + def import_saved_query( session: Session, config: Dict[str, Any], overwrite: bool = False ) -> SavedQuery: - existing = session.query(SavedQuery).filter_by(uuid= config["uuid"]).first() + existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first() if existing: if not overwrite: return existing diff --git a/superset/queries/saved_queries/schemas.py b/superset/queries/saved_queries/schemas.py index b6386c76e1c50..ca2ef800a67e9 100644 --- a/superset/queries/saved_queries/schemas.py +++ b/superset/queries/saved_queries/schemas.py @@ -14,8 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -from flask_babel import lazy_gettext as _ -from marshmallow import fields, Schema, ValidationError +from marshmallow import fields, Schema from marshmallow.validate import Length openapi_spec_methods_override = { @@ -36,11 +35,12 @@ get_delete_ids_schema = {"type": "array", "items": {"type": "integer"}} get_export_ids_schema = {"type": "array", "items": {"type": "integer"}} + class ImportV1SavedQuerySchema(Schema): - schema = fields.String(allow_none=True, validate= Length(0, 128)) - label = fields.String(allow_none=True, validate= Length(0,256)) - description = fields.String(allow_none = True) - sql = fields.String(required= True) + schema = fields.String(allow_none=True, validate=Length(0, 128)) + label = fields.String(allow_none=True, validate=Length(0, 256)) + description = fields.String(allow_none=True) + sql = fields.String(required=True) uuid = fields.UUID(required=True) version = fields.String(required=True) database_uuid = fields.UUID(required=True) diff --git a/tests/fixtures/importexport.py b/tests/fixtures/importexport.py index 71ffbe1285bf9..89391c3db0d7e 100644 --- a/tests/fixtures/importexport.py +++ b/tests/fixtures/importexport.py @@ -346,7 +346,7 @@ saved_queries_metadata_config: Dict[str, Any] = { "version": "1.0.0", "type": "SavedQuery", - "timestamp": "2021-03-30T20:37:54.791187+00:00" + "timestamp": "2021-03-30T20:37:54.791187+00:00", } database_config: Dict[str, Any] = { "allow_csv_upload": True, @@ -510,5 +510,5 @@ "sql": "-- Note: Unless you save your query, these tabs will NOT persist if you clear\nyour cookies or change browsers.\n\n\nSELECT * from birth_names", "uuid": "05b679b5-8eaf-452c-b874-a7a774cfa4e9", "version": "1.0.0", - "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89" + "database_uuid": "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89", } diff --git a/tests/queries/saved_queries/api_tests.py b/tests/queries/saved_queries/api_tests.py index 1983a06873924..dff845cb56f75 100644 --- a/tests/queries/saved_queries/api_tests.py +++ b/tests/queries/saved_queries/api_tests.py @@ -761,12 +761,14 @@ def create_saved_query_import(self): "saved_query_export/databases/imported_database.yaml", "w" ) as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("saved_query_export/queries/imported_database/public/imported_saved_query.yaml", "w") as fp: + with bundle.open( + "saved_query_export/queries/imported_database/public/imported_saved_query.yaml", + "w", + ) as fp: fp.write(yaml.safe_dump(saved_queries_config).encode()) buf.seek(0) return buf - @pytest.mark.usefixtures("create_saved_queries") def test_import_saved_queries(self): """ Saved Query API: Test import @@ -791,8 +793,8 @@ def test_import_saved_queries(self): assert len(database.tables) == 1 saved_query = ( - db.session - .query(SavedQuery) - .filter_by(uuid=saved_queries_config["uuid"]).one() + db.session.query(SavedQuery) + .filter_by(uuid=saved_queries_config["uuid"]) + .one() ) assert saved_query.database == database diff --git a/tests/queries/saved_queries/commands_tests.py b/tests/queries/saved_queries/commands_tests.py index fa9a422b2d000..b29484f6c704c 100644 --- a/tests/queries/saved_queries/commands_tests.py +++ b/tests/queries/saved_queries/commands_tests.py @@ -21,15 +21,15 @@ import yaml from superset import db, security_manager -from superset.queries.saved_queries.commands.importers.v1 import ( - ImportSavedQueriesCommand -) from superset.commands.exceptions import CommandInvalidError from superset.commands.importers.exceptions import IncorrectVersionError -from superset.models.sql_lab import SavedQuery from superset.models.core import Database +from superset.models.sql_lab import SavedQuery from superset.queries.saved_queries.commands.exceptions import SavedQueryNotFoundError from superset.queries.saved_queries.commands.export import ExportSavedQueriesCommand +from superset.queries.saved_queries.commands.importers.v1 import ( + ImportSavedQueriesCommand, +) from superset.utils.core import get_example_database from tests.base_tests import SupersetTestCase from tests.fixtures.importexport import ( @@ -39,6 +39,7 @@ saved_queries_metadata_config, ) + class TestExportSavedQueriesCommand(SupersetTestCase): def setUp(self): self.example_database = get_example_database() @@ -120,33 +121,26 @@ def test_export_query_command_key_order(self, mock_g): "version", "database_uuid", ] + + class TestImportSavedQueriesCommand(SupersetTestCase): def test_import_v1_saved_queries(self): """Test that we can import a saved query""" contents = { "metadata.yaml": yaml.safe_dump(saved_queries_metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), - "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config) + "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config), } command = ImportSavedQueriesCommand(contents) command.run() - saved_query = db.session.query(SavedQuery).filter_by( - uuid=saved_queries_config["uuid"] - ).one() - assert saved_query.schema == "public" - assert saved_query.sql == ( - """ - -- Note: Unless you save your query, - these tabs will NOT persist if you clear - your cookies or change browsers. - - SELECT * from birth_names - """ + saved_query = ( + db.session.query(SavedQuery) + .filter_by(uuid=saved_queries_config["uuid"]) + .one() ) - assert saved_query.uuid == "05b679b5-8eaf-452c-b874-a7a774cfa4e9" - assert saved_query.database_uuid == "b8a1ccd3-779d-4ab7-8ad8-9ab119d7fe89" + assert saved_query.schema == "public" database = ( db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() @@ -155,12 +149,13 @@ def test_import_v1_saved_queries(self): db.session.delete(saved_query) db.session.delete(database) db.session.commit() + def test_import_v1_saved_queries_multiple(self): """Test that a saved query can be imported multiple times""" contents = { "metadata.yaml": yaml.safe_dump(saved_queries_metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), - "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config) + "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config), } command = ImportSavedQueriesCommand(contents, overwrite=True) command.run() @@ -168,27 +163,28 @@ def test_import_v1_saved_queries_multiple(self): database = ( db.session.query(SavedQuery).filter_by(uuid=database_config["uuid"]).one() ) - saved_query = db.session.query(SavedQuery).filter_by(datasource_id=database.id).all() + saved_query = ( + db.session.query(SavedQuery).filter_by(datasource_id=database.id).all() + ) assert len(saved_query) == 1 db.session.delete(saved_query[0]) db.session.delete(database) db.session.commit() - + def test_import_v1_saved_queries_validation(self): - """Test different validations applied when importing a chart""" + """Test different validations applied when importing a saved query""" # metadata.yaml must be present contents = { - "metadata.yaml": yaml.safe_dump(saved_queries_metadata_config), "databases/imported_database.yaml": yaml.safe_dump(database_config), - "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config) + "queries/imported_query.yaml": yaml.safe_dump(saved_queries_config), } command = ImportSavedQueriesCommand(contents) with pytest.raises(IncorrectVersionError) as excinfo: command.run() assert str(excinfo.value) == "Missing metadata.yaml" - #version should be 1.0.0 + # version should be 1.0.0 contents["metadata.yaml"] = yaml.safe_dump( { "version": "2.0.0", @@ -201,7 +197,7 @@ def test_import_v1_saved_queries_validation(self): command.run() assert str(excinfo.value) == "Must be equal to 1.0.0" - #type should be a SavedQuery + # type should be a SavedQuery contents["metadata.yaml"] = yaml.safe_dump(database_metadata_config) command = ImportSavedQueriesCommand(contents) with pytest.raises(CommandInvalidError) as excinfo: @@ -211,7 +207,7 @@ def test_import_v1_saved_queries_validation(self): "metadata.yaml": {"type": ["Must be equal to SavedQuery."]} } - # must also validate databases + # must also validate databases broken_config = database_config.copy() del broken_config["database_name"] contents["metadata.yaml"] = yaml.safe_dump(saved_queries_metadata_config) @@ -219,10 +215,9 @@ def test_import_v1_saved_queries_validation(self): command = ImportSavedQueriesCommand(contents) with pytest.raises(CommandInvalidError) as excinfo: command.run() - assert str(excinfo.value) = "Error importing saved query." + assert str(excinfo.value) == "Error importing saved query." assert excinfo.value.normalized_messages() == { "databases/imported_database.yaml": { "database_name": ["Missing data for required field."], } } - \ No newline at end of file