Skip to content

Commit

Permalink
Added a .q function for creating a query and linked the select functi…
Browse files Browse the repository at this point in the history
…on to that.
  • Loading branch information
mattilyra committed Nov 9, 2015
1 parent 35d93ef commit a56352c
Showing 1 changed file with 35 additions and 29 deletions.
64 changes: 35 additions & 29 deletions naklar/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,38 @@ def populate_from_disk(cls, files, load_func=None, extra_params=None):
session.close()


def q(*columns, **filters):
cols, filts = [], []
if columns:
for col in columns:
if isinstance(col, six.string_types):
cols.append(getattr(E, col))
elif isinstance(col, InstrumentedAttribute):
cols.append(col)
elif isinstance(col, BinaryExpression):
filts.append(col)

if not cols: # in case all *args were BinaryExpression (filters)
cols = [E]

session = Session(bind=_engine)
q_ = session.query(*cols)

if filters or filts:
for k, v in six.iteritems(filters):
if isinstance(v, BinaryExpression):
filts.append(v)
elif hasattr(v, 'split'):
filts.append(getattr(E, k) == v)
elif hasattr(v, '__getitem__') or hasattr(v, '__iter__'):
filts.append(getattr(E, k).in_(v))
else:
filts.append(getattr(E, k) == v)
q_ = q_.filter(*filts)

return q_, session


def select(*columns, **filters):
"""Get rows from the Experiment table associated with session.
Expand Down Expand Up @@ -436,35 +468,9 @@ def select(*columns, **filters):
>> initialise('.')
>> rows = select('model_type', 'results_file', k=[1, 2, 3, 5, 8])
"""
cols, filts = [], []
if columns:
for col in columns:
if isinstance(col, six.string_types):
cols.append(getattr(E, col))
elif isinstance(col, InstrumentedAttribute):
cols.append(col)
elif isinstance(col, BinaryExpression):
filts.append(col)

if not cols: # in case all *args were BinaryExpression (filters)
cols = [E]

session = Session(bind=_engine)
q = session.query(*cols)

if filters or filts:
for k, v in six.iteritems(filters):
if isinstance(v, BinaryExpression):
filts.append(v)
elif hasattr(v, 'split'):
filts.append(getattr(E, k) == v)
elif hasattr(v, '__getitem__') or hasattr(v, '__iter__'):
filts.append(getattr(E, k).in_(v))
else:
filts.append(getattr(E, k) == v)
q = q.filter(*filts)
rows = q.all()
session.close()
q_, ses = q(*columns, **filters)
rows = q_.all()
ses.close()
return rows


Expand Down

0 comments on commit a56352c

Please sign in to comment.