-
-
Notifications
You must be signed in to change notification settings - Fork 362
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1022 from centerofci/function-api-denuked
Introduce function API, without affecting existing APIs
- Loading branch information
Showing
20 changed files
with
728 additions
and
1 deletion.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,228 @@ | ||
""" | ||
This namespace defines the DBFunction abstract class and its subclasses. These subclasses | ||
represent functions that have identifiers, display names and hints, and their instances | ||
hold parameters. Each DBFunction subclass defines how its instance can be converted into an | ||
SQLAlchemy expression. | ||
Hints hold information about what kind of input the function might expect and what output | ||
can be expected from it. This is used to provide interface information without constraining its | ||
user. | ||
These classes might be used, for example, to define a filter for an SQL query, or to | ||
access hints on what composition of functions and parameters should be valid. | ||
""" | ||
|
||
from abc import ABC, abstractmethod | ||
|
||
from sqlalchemy import column, not_, and_, or_, func, literal | ||
|
||
from db.functions import hints | ||
|
||
|
||
class DBFunction(ABC):
    """
    Abstract base for functions that can be converted into SQLAlchemy expressions.

    Subclasses must define `id` and `name`, may define `hints` and `depends_on`, and must
    implement `to_sa_expression`. Instances only hold the `parameters` they were given.
    """

    # Identifier of the function; subclasses must override (checked in __init__).
    id = None
    # Display name of the function; subclasses must override (checked in __init__).
    name = None
    # Optional hints describing the function's parameters and/or return value.
    hints = None

    # Optionally lists the SQL functions this DBFunction depends on.
    # Will be checked against SQL functions defined on a database to tell if it
    # supports this DBFunction. Either None or a tuple of SQL function name
    # strings.
    depends_on = None

    def __init__(self, parameters):
        """
        Store `parameters` after validating the subclass' class-level attributes.

        Raises:
            ValueError: if the subclass did not define `id` or `name`, or if `depends_on`
                is neither None nor a tuple.
        """
        if self.id is None:
            raise ValueError('DBFunction subclasses must define an ID.')
        if self.name is None:
            raise ValueError('DBFunction subclasses must define a name.')
        if self.depends_on is not None and not isinstance(self.depends_on, tuple):
            raise ValueError(
                "DBFunction subclasses' depends_on attribute must either be None"
                " or a tuple of SQL function names."
            )
        self.parameters = parameters

    @property
    def referenced_columns(self):
        """Walks the expression tree, collecting referenced columns.
        Useful when checking if all referenced columns are present in the queried relation."""
        columns = set()
        for parameter in self.parameters:
            if isinstance(parameter, ColumnReference):
                columns.add(parameter.column)
            elif isinstance(parameter, DBFunction):
                # Recurse into nested functions; non-DBFunction parameters are ignored.
                columns.update(parameter.referenced_columns)
        return columns

    @staticmethod
    @abstractmethod
    def to_sa_expression(*parameters):
        """
        Convert already-converted parameters into an SQLAlchemy expression.

        Declared variadic because implementations are invoked positionally with however
        many parameters the instance holds; subclasses may narrow the signature.
        """
        return None
|
||
|
||
class Literal(DBFunction):
    """DBFunction that wraps a single primitive value using SQLAlchemy's `literal`."""

    id = 'literal'
    name = 'Literal'
    hints = (
        hints.parameter_count(1),
        hints.parameter(1, hints.literal),
    )

    @staticmethod
    def to_sa_expression(primitive):
        """Return `primitive` wrapped by SQLAlchemy's `literal`."""
        return literal(primitive)
|
||
|
||
class ColumnReference(DBFunction):
    """DBFunction whose only parameter identifies a column."""

    id = 'column_reference'
    name = 'Column Reference'
    hints = (
        hints.parameter_count(1),
        hints.parameter(1, hints.column),
    )

    @property
    def column(self):
        """The referenced column, i.e. the first (and only expected) parameter."""
        first_parameter = self.parameters[0]
        return first_parameter

    @staticmethod
    def to_sa_expression(column_name):
        """Return an SQLAlchemy `column` clause for `column_name`."""
        return column(column_name)
|
||
|
||
class List(DBFunction):
    """DBFunction that gathers its (already converted) parameters into a Python list."""

    id = 'list'
    name = 'List'

    @staticmethod
    def to_sa_expression(*items):
        """Return the given items as a new list."""
        return [*items]
|
||
|
||
class Empty(DBFunction):
    """Boolean DBFunction testing whether its single parameter is NULL."""

    id = 'empty'
    name = 'Empty'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(1),
    )

    @staticmethod
    def to_sa_expression(value):
        """Return the expression `value.is_(None)`; `value` must support `is_`."""
        return value.is_(None)
|
||
|
||
class Not(DBFunction):
    """Boolean DBFunction negating its single parameter."""

    id = 'not'
    name = 'Not'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(1),
    )

    @staticmethod
    def to_sa_expression(value):
        """Return `value` negated with SQLAlchemy's `not_`."""
        return not_(value)
|
||
|
||
class Equal(DBFunction):
    """Boolean DBFunction comparing its two parameters for equality."""

    id = 'equal'
    name = 'Equal'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 == value2`."""
        return value1 == value2
|
||
|
||
class Greater(DBFunction):
    """Boolean DBFunction testing whether the first parameter is greater than the second."""

    id = 'greater'
    name = 'Greater'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.comparable),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 > value2`."""
        return value1 > value2
|
||
|
||
class Lesser(DBFunction):
    """Boolean DBFunction testing whether the first parameter is less than the second."""

    id = 'lesser'
    name = 'Lesser'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.comparable),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return the result of `value1 < value2`."""
        return value1 < value2
|
||
|
||
class In(DBFunction):
    """Boolean DBFunction testing membership of the first parameter in the second."""

    id = 'in'
    name = 'In'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.parameter(2, hints.array),
    )

    @staticmethod
    def to_sa_expression(value1, value2):
        """Return `value1.in_(value2)`; `value1` must support `in_`."""
        return value1.in_(value2)
|
||
|
||
class And(DBFunction):
    """Boolean DBFunction combining any number of parameters with SQLAlchemy's `and_`."""

    id = 'and'
    name = 'And'
    hints = (
        hints.returns(hints.boolean),
    )

    @staticmethod
    def to_sa_expression(*values):
        """Return the conjunction of all `values`."""
        return and_(*values)
|
||
|
||
class Or(DBFunction):
    """Boolean DBFunction combining any number of parameters with SQLAlchemy's `or_`."""

    id = 'or'
    name = 'Or'
    hints = (
        hints.returns(hints.boolean),
    )

    @staticmethod
    def to_sa_expression(*values):
        """Return the disjunction of all `values`."""
        return or_(*values)
|
||
|
||
class StartsWith(DBFunction):
    """Boolean DBFunction testing whether a string-like value starts with a given prefix."""

    id = 'starts_with'
    name = 'Starts With'
    hints = (
        hints.returns(hints.boolean),
        hints.parameter_count(2),
        hints.all_parameters(hints.string_like),
    )

    @staticmethod
    def to_sa_expression(string, prefix):
        """
        Return a LIKE expression matching values of `string` that begin with `prefix`.

        `string` must be an SQLAlchemy column-like expression. `prefix` may be a plain
        Python string or an SQL expression (for example the result of converting a
        nested DBFunction).
        """
        # Formatting the prefix into an f-string would stringify SQL expressions in
        # Python rather than use their value. `startswith` builds
        # `string LIKE prefix || '%'` in SQL instead, which handles both plain strings
        # and expressions. As before, `%` and `_` inside the prefix act as wildcards.
        return string.startswith(prefix)
|
||
|
||
class ToLowercase(DBFunction):
    """DBFunction applying the SQL `lower` function to its single string-like parameter."""

    id = 'to_lowercase'
    name = 'To Lowercase'
    hints = (
        hints.parameter_count(1),
        hints.all_parameters(hints.string_like),
    )

    @staticmethod
    def to_sa_expression(string):
        """Return `func.lower(string)`."""
        return func.lower(string)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
class BadDBFunctionFormat(Exception):
    """Base class of the DBFunction-related exceptions declared in this module."""
|
||
|
||
class UnknownDBFunctionId(BadDBFunctionFormat):
    """
    Signals a DBFunction id that is not recognized.

    NOTE(review): the raise sites are outside this module; confirm exact usage there.
    """
|
||
|
||
class ReferencedColumnsDontExist(BadDBFunctionFormat):
    """Signals that a DBFunction references columns missing from the relevant relation."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from frozendict import frozendict | ||
|
||
|
||
def _make_hint(id, **rest):
    """Build an immutable hint mapping holding `id` plus any extra keyword entries."""
    # "id" is inserted first so the key order matches the literal {"id": id, **rest}.
    payload = {"id": id}
    payload.update(rest)
    return frozendict(payload)
|
||
|
||
def parameter_count(count):
    """Build a "parameter_count" hint carrying `count`."""
    hint_id = "parameter_count"
    return _make_hint(hint_id, count=count)
|
||
|
||
def parameter(index, *hints):
    """Build a "parameter" hint attaching the given hints (as a tuple) to the parameter at `index`."""
    hint_id = "parameter"
    return _make_hint(hint_id, index=index, hints=hints)
|
||
|
||
def all_parameters(*hints):
    """Build an "all_parameters" hint carrying the given hints as a tuple."""
    hint_id = "all_parameters"
    return _make_hint(hint_id, hints=hints)
|
||
|
||
def returns(*hints):
    """Build a "returns" hint carrying the given hints as a tuple."""
    hint_id = "returns"
    return _make_hint(hint_id, hints=hints)
|
||
|
||
# Parameterless hints: each is a hint mapping that only carries its own id.

boolean = _make_hint("boolean")

comparable = _make_hint("comparable")

column = _make_hint("column")

array = _make_hint("array")

string_like = _make_hint("string_like")

uri = _make_hint("uri")

literal = _make_hint("literal")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
""" | ||
Exports the known_db_functions variable, which describes what `DBFunction`s the library is aware | ||
of. Note that a `DBFunction` might be in this collection, but not be supported by a given | ||
database. | ||
Contains a private collection (`_db_functions_in_other_modules`) of `DBFunction` subclasses | ||
declared outside the base module. | ||
These variables were broken off into a discrete module to avoid circular imports. | ||
""" | ||
|
||
import inspect | ||
|
||
import db.functions.base | ||
|
||
from db.functions.base import DBFunction | ||
|
||
from db.types import uri | ||
|
||
|
||
def _get_module_members_that_satisfy(module, predicate):
    """
    Return a tuple of the members of `module` for which `predicate(member)` is truthy.

    Members come back in the order produced by `inspect.getmembers` (sorted by name).
    In this context, it (together with the appropriate predicate) is used to automatically
    collect all DBFunction subclasses found as top-level members of a module.
    """
    # `inspect.getmembers` filters by predicate itself, so no manual filtering is needed.
    matching_members = inspect.getmembers(module, predicate)
    return tuple(member for _, member in matching_members)
|
||
|
||
def _is_concrete_db_function_subclass(member):
    """Tell whether `member` is a class deriving from DBFunction, excluding DBFunction itself."""
    # `issubclass` requires a class, so non-classes are rejected first.
    if not inspect.isclass(member):
        return False
    if member == DBFunction:
        return False
    return issubclass(member, DBFunction)
|
||
|
||
# DBFunction subclasses collected automatically from the top-level members of the base module.
_db_functions_in_base_module = _get_module_members_that_satisfy(
    db.functions.base,
    _is_concrete_db_function_subclass,
)

# DBFunction subclasses declared outside the base module; these are listed by hand.
_db_functions_in_other_modules = (
    uri.ExtractURIAuthority,
)

# Every DBFunction the library is aware of.
known_db_functions = _db_functions_in_base_module + _db_functions_in_other_modules
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
from db.functions.base import DBFunction | ||
from db.functions.exceptions import ReferencedColumnsDontExist | ||
from db.functions.operations.deserialize import get_db_function_from_ma_function_spec | ||
|
||
|
||
def apply_ma_function_spec_as_filter(relation, ma_function_spec):
    """Deserialize `ma_function_spec` into a DBFunction and apply it to `relation` as a filter."""
    return apply_db_function_as_filter(
        relation,
        get_db_function_from_ma_function_spec(ma_function_spec),
    )
|
||
|
||
def apply_db_function_as_filter(relation, db_function):
    """
    Return `relation` filtered by the SQLAlchemy expression built from `db_function`.

    Raises ReferencedColumnsDontExist (via the column check) when `db_function` references
    columns that `relation` does not select.
    """
    # Validate before converting, so missing columns are reported first.
    _assert_that_all_referenced_columns_exist(relation, db_function)
    filter_expression = _db_function_to_sa_expression(db_function)
    return relation.filter(filter_expression)
|
||
|
||
def _assert_that_all_referenced_columns_exist(relation, db_function):
    """
    Raise ReferencedColumnsDontExist if `db_function` references any column whose name is
    not among the selected columns of `relation`; otherwise return None.
    """
    available = _get_columns_that_exist(relation)
    missing = db_function.referenced_columns.difference(available)
    if missing:
        raise ReferencedColumnsDontExist(
            f"These referenced columns don't exist on the relevant relation: {missing}"
        )
|
||
|
||
def _get_columns_that_exist(relation):
    """Return the set of names of the columns in `relation.selected_columns`."""
    return {column.name for column in relation.selected_columns}
|
||
|
||
def _db_function_to_sa_expression(db_function):
    """
    Recursively convert a DBFunction tree into an SQLAlchemy expression.

    Each parameter (at any nesting depth) is expected to be either a DBFunction instance
    or a literal primitive. Anything that is not a DBFunction is returned untouched.
    """
    # Base case: primitives (non-DBFunction values) pass through as-is.
    if not isinstance(db_function, DBFunction):
        return db_function
    # Convert children first, then let the function's own class combine them.
    converted = [
        _db_function_to_sa_expression(child)
        for child in db_function.parameters
    ]
    return type(db_function).to_sa_expression(*converted)
Oops, something went wrong.