Skip to content

Commit

Permalink
move get_or_create_main_db to utils
Browse files Browse the repository at this point in the history
  • Loading branch information
timifasubaa committed Mar 27, 2018
1 parent 7c32934 commit a385444
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 35 deletions.
4 changes: 2 additions & 2 deletions superset/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from pathlib2 import Path
import yaml

from superset import app, db, dict_import_export_util, sm, utils
from superset import app, data, db, dict_import_export_util, sm, utils

config = app.config
celery_app = utils.get_celery_app(config)
Expand All @@ -28,6 +28,7 @@
@manager.command
def init():
"""Inits the Superset application"""
utils.get_or_create_main_db()
sm.sync_role_definitions()


Expand Down Expand Up @@ -108,7 +109,6 @@ def version(verbose):
help='Load additional test data')
def load_examples(load_test_data):
"""Loads a set of Slices and Dashboards and a supporting dataset """
from superset import data
print('Loading examples into {}'.format(db))

data.load_css_templates()
Expand Down
24 changes: 12 additions & 12 deletions superset/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def load_energy():
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Energy consumption"
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
Expand Down Expand Up @@ -178,7 +178,7 @@ def load_world_bank_health_n_pop():
tbl = TBL(table_name=tbl_name)
tbl.description = utils.readfile(os.path.join(DATA_FOLDER, 'countries.md'))
tbl.main_dttm_col = 'year'
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
Expand Down Expand Up @@ -582,7 +582,7 @@ def load_birth_names():
if not obj:
obj = TBL(table_name='birth_names')
obj.main_dttm_col = 'ds'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
Expand Down Expand Up @@ -869,7 +869,7 @@ def load_unicode_test_data():
if not obj:
obj = TBL(table_name='unicode_test')
obj.main_dttm_col = 'dttm'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
Expand Down Expand Up @@ -947,7 +947,7 @@ def load_random_time_series_data():
if not obj:
obj = TBL(table_name='random_time_series')
obj.main_dttm_col = 'ds'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
Expand Down Expand Up @@ -1010,7 +1010,7 @@ def load_country_map_data():
if not obj:
obj = TBL(table_name='birth_france_by_region')
obj.main_dttm_col = 'dttm'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
Expand Down Expand Up @@ -1085,7 +1085,7 @@ def load_long_lat_data():
if not obj:
obj = TBL(table_name='long_lat')
obj.main_dttm_col = 'datetime'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
Expand Down Expand Up @@ -1146,7 +1146,7 @@ def load_multiformat_time_series_data():
if not obj:
obj = TBL(table_name='multiformat_time_series')
obj.main_dttm_col = 'ds'
obj.database = sm.get_or_create_main_db()
obj.database = utils.get_or_create_main_db()
dttm_and_expr_dict = {
'ds': [None, None],
'ds2': [None, None],
Expand Down Expand Up @@ -1768,7 +1768,7 @@ def load_flights():
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Random set of flights in the US"
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
Expand Down Expand Up @@ -1799,7 +1799,7 @@ def load_paris_iris_geojson():
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Map of Paris"
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
Expand Down Expand Up @@ -1829,7 +1829,7 @@ def load_sf_population_polygons():
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "Population density of San Francisco"
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
Expand Down Expand Up @@ -1859,7 +1859,7 @@ def load_bart_lines():
if not tbl:
tbl = TBL(table_name=tbl_name)
tbl.description = "BART lines"
tbl.database = sm.get_or_create_main_db()
tbl.database = utils.get_or_create_main_db()
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
20 changes: 0 additions & 20 deletions superset/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,6 @@ def sync_role_definitions(self):
from superset import conf
logging.info('Syncing role definition')

self.get_or_create_main_db()
self.create_custom_permissions()

# Creating default roles
Expand All @@ -309,25 +308,6 @@ def sync_role_definitions(self):
self.get_session.commit()
self.clean_perms()

def get_or_create_main_db(self):
from superset import conf, db
from superset.models import core as models

logging.info('Creating database reference')
dbobj = (
db.session.query(models.Database)
.filter_by(database_name='main')
.first()
)
if not dbobj:
dbobj = models.Database(database_name='main')
dbobj.set_sqlalchemy_uri(conf.get('SQLALCHEMY_DATABASE_URI'))
dbobj.expose_in_sqllab = True
dbobj.allow_run_sync = True
db.session.add(dbobj)
db.session.commit()
return dbobj

def set_role(self, role_name, pvm_check):
logging.info('Syncing {} perms'.format(role_name))
sesh = self.get_session
Expand Down
20 changes: 20 additions & 0 deletions superset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -822,3 +822,23 @@ def user_label(user):
return user.first_name + ' ' + user.last_name
else:
return user.username


def get_or_create_main_db():
from superset import conf, db
from superset.models import core as models

logging.info('Creating database reference')
dbobj = (
db.session.query(models.Database)
.filter_by(database_name='main')
.first()
)
if not dbobj:
dbobj = models.Database(database_name='main')
dbobj.set_sqlalchemy_uri(conf.get('SQLALCHEMY_DATABASE_URI'))
dbobj.expose_in_sqllab = True
dbobj.allow_run_sync = True
db.session.add(dbobj)
db.session.commit()
return dbobj
3 changes: 2 additions & 1 deletion tests/base_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from flask_appbuilder.security.sqla import models as ab_models

from superset import app, appbuilder, cli, db, sm
from superset import app, appbuilder, cli, db, sm, utils
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
Expand Down Expand Up @@ -46,6 +46,7 @@ def __init__(self, *args, **kwargs):
gamma_sqllab_role = sm.add_role('gamma_sqllab')
for perm in sm.find_role('Gamma').permissions:
sm.add_permission_role(gamma_sqllab_role, perm)
utils.get_or_create_main_db()
db_perm = self.get_main_database(sm.get_session).perm
sm.merge_perm('database_access', db_perm)
db_pvm = sm.find_permission_view_menu(
Expand Down

0 comments on commit a385444

Please sign in to comment.