From 1700a807e9a8ebc4cfb2749293308be733e42473 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Wed, 2 Nov 2016 13:22:07 -0700 Subject: [PATCH] [sqllab] templating refactor (#1504) * Add support for jinja templates in WHERE/HAVING clauses * Generalizing * bugfix --- caravel/jinja_context.py | 74 ++++++++++++++++++---------------------- caravel/models.py | 20 +++++++---- caravel/sql_lab.py | 7 ++-- dev-reqs.txt | 1 + tests/core_tests.py | 4 ++- 5 files changed, 55 insertions(+), 51 deletions(-) diff --git a/caravel/jinja_context.py b/caravel/jinja_context.py index f9eae6444e753..95212c9b28b00 100644 --- a/caravel/jinja_context.py +++ b/caravel/jinja_context.py @@ -18,9 +18,18 @@ from caravel.utils import CaravelTemplateException config = app.config +BASE_CONTEXT = { + 'datetime': datetime, + 'random': random, + 'relativedelta': relativedelta, + 'time': time, + 'timedelta': timedelta, + 'uuid': uuid, +} +BASE_CONTEXT.update(config.get('JINJA_CONTEXT_ADDONS', {})) -class BaseContext(object): +class BaseTemplateProcessor(object): """Base class for database-specific jinja context @@ -37,17 +46,31 @@ class BaseContext(object): """ engine = None - def __init__(self, database, query): + def __init__(self, database=None, query=None, table=None): self.database = database self.query = query self.schema = None if query and query.schema: self.schema = query.schema - elif database: - self.schema = database.schema + elif table: + self.schema = table.schema + self.context = {} + self.context.update(BASE_CONTEXT) + if self.engine: + self.context[self.engine] = self + + def process_template(self, sql): + """Processes a sql template + + >>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" + >>> process_template(sql) + "SELECT '2017-01-01T00:00:00'" + """ + template = jinja2.Template(sql) + return template.render(self.context) -class PrestoContext(BaseContext): +class PrestoTemplateProcessor(BaseTemplateProcessor): """Presto Jinja context The methods described here are namespaced under ``presto`` in the @@ -170,43 +193,14 @@ def latest_sub_partition(self, table_name, **kwargs): return df.to_dict()[field_to_return][0] -db_contexes = {} +template_processors = {} keys = tuple(globals().keys()) for k in keys: o = globals()[k] - if o and inspect.isclass(o) and issubclass(o, BaseContext): - db_contexes[o.engine] = o - - -def get_context(engine_name=None): - context = { - 'datetime': datetime, - 'random': random, - 'relativedelta': relativedelta, - 'time': time, - 'timedelta': timedelta, - 'uuid': uuid, - } - db_context = db_contexes.get(engine_name) - if engine_name and db_context: - context[engine_name] = db_context - return context - - -def process_template(sql, database=None, query=None): - """Processes a sql template - - >>> sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" - >>> process_template(sql) - "SELECT '2017-01-01T00:00:00'" - """ + if o and inspect.isclass(o) and issubclass(o, BaseTemplateProcessor): + template_processors[o.engine] = o - context = get_context(database.backend if database else None) - template = jinja2.Template(sql) - backend = database.backend if database else None - # instantiating only the context for the active database - if context and backend in context: - context[backend] = context[backend](database, query) - context.update(config.get('JINJA_CONTEXT_ADDONS', {})) - return template.render(context) +def get_template_processor(database, table=None, query=None): + TP = template_processors.get(database.backend, BaseTemplateProcessor) + return TP(database=database, table=table, query=query) diff --git a/caravel/models.py b/caravel/models.py index 573de6352a2b7..fe2950181eb14 100644 --- a/caravel/models.py +++ b/caravel/models.py @@ -56,7 +56,7 @@ from caravel import app, db, db_engine_specs, get_session, utils, sm from caravel.source_registry import SourceRegistry from caravel.viz import viz_types -from caravel.jinja_context import process_template +from caravel.jinja_context import get_template_processor from caravel.utils import ( flasher, MetricPermException, DimSelector, wrap_clause_in_parens ) @@ -960,6 +960,9 @@ def query( # sqla extras=None, columns=None): """Querying any sqla table from this common interface""" + template_processor = get_template_processor( + table=self, database=self.database) + # For backward compatibility if granularity not in self.dttm_cols: granularity = self.main_dttm_col @@ -1088,12 +1091,15 @@ def visit_column(element, compiler, **kw): if op == 'not in': cond = ~cond where_clause_and.append(cond) - if extras and 'where' in extras: - where = wrap_clause_in_parens(process_template(extras['where'], self.database)) - where_clause_and += [where] - if extras and 'having' in extras: - having = wrap_clause_in_parens(process_template(extras['having'], self.database)) - having_clause_and += [having] + if extras: + where = extras.get('where') + if where: + where_clause_and += [wrap_clause_in_parens( + template_processor.process_template(where))] + having = extras.get('having') + if having: + having_clause_and += [wrap_clause_in_parens( + template_processor.process_template(having))] if granularity: qry = qry.where(and_(*(time_filter + where_clause_and))) else: diff --git a/caravel/sql_lab.py b/caravel/sql_lab.py index eb9a5cf0f0200..99b1370f41d77 100644 --- a/caravel/sql_lab.py +++ b/caravel/sql_lab.py @@ -13,8 +13,7 @@ from caravel import ( app, db, models, utils, dataframe, results_backend) from caravel.db_engine_specs import LimitMethod -from caravel.jinja_context import process_template - +from caravel.jinja_context import get_template_processor QueryStatus = models.QueryStatus celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) @@ -101,7 +100,9 @@ def handle_error(msg): query.limit_used = True engine = database.get_sqla_engine(schema=query.schema) try: - executed_sql = process_template(executed_sql, database, query) + template_processor = get_template_processor( + database=database, query=query) + executed_sql = template_processor.process_template(executed_sql) except Exception as e: logging.exception(e) msg = "Template rendering failed: " + utils.error_msg_from_exception(e) diff --git a/dev-reqs.txt b/dev-reqs.txt index 6f692ec871c23..9938d591360d7 100644 --- a/dev-reqs.txt +++ b/dev-reqs.txt @@ -1,5 +1,6 @@ codeclimate-test-reporter coveralls +flake8 mock mysqlclient nose diff --git a/tests/core_tests.py b/tests/core_tests.py index 1d58dde3df3f0..111f0d351f154 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -358,8 +358,10 @@ def test_extra_table_metadata(self): 'ab_permission_view/panoramix/'.format(**locals())) def test_process_template(self): + maindb = self.get_main_database(db.session) sql = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" - rendered = jinja_context.process_template(sql) + tp = jinja_context.get_template_processor(database=maindb) + rendered = tp.process_template(sql) self.assertEqual("SELECT '2017-01-01T00:00:00'", rendered) def test_templated_sql_json(self):