From a3854440624aed4a4ff7486499b489e246fb4c50 Mon Sep 17 00:00:00 2001 From: Timi Fasubaa Date: Mon, 26 Mar 2018 19:35:09 -0700 Subject: [PATCH] move get_or_create_main_db to utils --- superset/cli.py | 4 ++-- superset/data/__init__.py | 24 ++++++++++++------------ superset/security.py | 20 -------------------- superset/utils.py | 20 ++++++++++++++++++++ tests/base_tests.py | 3 ++- 5 files changed, 36 insertions(+), 35 deletions(-) diff --git a/superset/cli.py b/superset/cli.py index 33c835fd51349..46e9d81714800 100755 --- a/superset/cli.py +++ b/superset/cli.py @@ -16,7 +16,7 @@ from pathlib2 import Path import yaml -from superset import app, db, dict_import_export_util, sm, utils +from superset import app, data, db, dict_import_export_util, sm, utils config = app.config celery_app = utils.get_celery_app(config) @@ -28,6 +28,7 @@ @manager.command def init(): """Inits the Superset application""" + utils.get_or_create_main_db() sm.sync_role_definitions() @@ -108,7 +109,6 @@ def version(verbose): help='Load additional test data') def load_examples(load_test_data): """Loads a set of Slices and Dashboards and a supporting dataset """ - from superset import data print('Loading examples into {}'.format(db)) data.load_css_templates() diff --git a/superset/data/__init__.py b/superset/data/__init__.py index ffcee0f638fa2..5ce07463a017f 100644 --- a/superset/data/__init__.py +++ b/superset/data/__init__.py @@ -70,7 +70,7 @@ def load_energy(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Energy consumption" - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -178,7 +178,7 @@ def load_world_bank_health_n_pop(): tbl = TBL(table_name=tbl_name) tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md')) tbl.main_dttm_col = 'year' - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() tbl.filter_select_enabled = True db.session.merge(tbl) db.session.commit() @@ -582,7 +582,7 @@ def load_birth_names(): if not obj: obj = TBL(table_name='birth_names') obj.main_dttm_col = 'ds' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() obj.filter_select_enabled = True db.session.merge(obj) db.session.commit() @@ -869,7 +869,7 @@ def load_unicode_test_data(): if not obj: obj = TBL(table_name='unicode_test') obj.main_dttm_col = 'dttm' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -947,7 +947,7 @@ def load_random_time_series_data(): if not obj: obj = TBL(table_name='random_time_series') obj.main_dttm_col = 'ds' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -1010,7 +1010,7 @@ def load_country_map_data(): if not obj: obj = TBL(table_name='birth_france_by_region') obj.main_dttm_col = 'dttm' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -1085,7 +1085,7 @@ def load_long_lat_data(): if not obj: obj = TBL(table_name='long_lat') obj.main_dttm_col = 'datetime' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() db.session.merge(obj) db.session.commit() obj.fetch_metadata() @@ -1146,7 +1146,7 @@ def load_multiformat_time_series_data(): if not obj: obj = TBL(table_name='multiformat_time_series') obj.main_dttm_col = 'ds' - obj.database = sm.get_or_create_main_db() + obj.database = utils.get_or_create_main_db() dttm_and_expr_dict = { 'ds': [None, None], 'ds2': [None, None], @@ -1768,7 +1768,7 @@ def load_flights(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Random set of flights in the US" - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -1799,7 +1799,7 @@ def load_paris_iris_geojson(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Map of Paris" - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -1829,7 +1829,7 @@ def load_sf_population_polygons(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "Population density of San Francisco" - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() @@ -1859,7 +1859,7 @@ def load_bart_lines(): if not tbl: tbl = TBL(table_name=tbl_name) tbl.description = "BART lines" - tbl.database = sm.get_or_create_main_db() + tbl.database = utils.get_or_create_main_db() db.session.merge(tbl) db.session.commit() tbl.fetch_metadata() diff --git a/superset/security.py b/superset/security.py index a84911965b8a4..8ea91b941561e 100644 --- a/superset/security.py +++ b/superset/security.py @@ -290,7 +290,6 @@ def sync_role_definitions(self): from superset import conf logging.info('Syncing role definition') - self.get_or_create_main_db() self.create_custom_permissions() # Creating default roles @@ -309,25 +308,6 @@ def sync_role_definitions(self): self.get_session.commit() self.clean_perms() - def get_or_create_main_db(self): - from superset import conf, db - from superset.models import core as models - - logging.info('Creating database reference') - dbobj = ( - db.session.query(models.Database) - .filter_by(database_name='main') - .first() - ) - if not dbobj: - dbobj = models.Database(database_name='main') - dbobj.set_sqlalchemy_uri(conf.get('SQLALCHEMY_DATABASE_URI')) - dbobj.expose_in_sqllab = True - dbobj.allow_run_sync = True - db.session.add(dbobj) - db.session.commit() - return dbobj - def set_role(self, role_name, pvm_check): logging.info('Syncing {} perms'.format(role_name)) sesh = self.get_session diff --git a/superset/utils.py b/superset/utils.py index 4c8dee46d7fd5..8e91e8dc80406 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -822,3 +822,23 @@ def user_label(user): return user.first_name + ' ' + user.last_name else: return user.username + + +def get_or_create_main_db(): + from superset import conf, db + from superset.models import core as models + + logging.info('Creating database reference') + dbobj = ( + db.session.query(models.Database) + .filter_by(database_name='main') + .first() + ) + if not dbobj: + dbobj = models.Database(database_name='main') + dbobj.set_sqlalchemy_uri(conf.get('SQLALCHEMY_DATABASE_URI')) + dbobj.expose_in_sqllab = True + dbobj.allow_run_sync = True + db.session.add(dbobj) + db.session.commit() + return dbobj diff --git a/tests/base_tests.py b/tests/base_tests.py index 29acc79776b31..aac5b42d143e9 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -12,7 +12,7 @@ from flask_appbuilder.security.sqla import models as ab_models -from superset import app, appbuilder, cli, db, sm +from superset import app, appbuilder, cli, db, sm, utils from superset.connectors.druid.models import DruidCluster, DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.models import core as models @@ -46,6 +46,7 @@ def __init__(self, *args, **kwargs): gamma_sqllab_role = sm.add_role('gamma_sqllab') for perm in sm.find_role('Gamma').permissions: sm.add_permission_role(gamma_sqllab_role, perm) + utils.get_or_create_main_db() db_perm = self.get_main_database(sm.get_session).perm sm.merge_perm('database_access', db_perm) db_pvm = sm.find_permission_view_menu(