diff --git a/labm8/BUILD b/labm8/BUILD
index 3e0ebe7c..267efe19 100644
--- a/labm8/BUILD
+++ b/labm8/BUILD
@@ -547,6 +547,28 @@ py_test(
     ],
 )
 
+py_library(
+    name = "pdutil",
+    srcs = ["pdutil.py"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":app",
+        ":sqlutil",
+        "//third_party/py/pandas",
+    ],
+)
+
+py_test(
+    name = "pdutil_test",
+    srcs = ["pdutil_test.py"],
+    deps = [
+        ":pdutil",
+        ":sqlutil",
+        ":test",
+        "//third_party/py/sqlalchemy",
+    ],
+)
+
 py_library(
     name = "ppar",
     srcs = ["ppar.py"],
@@ -627,7 +649,6 @@ py_library(
         ":pbutil",
         ":text",
         "//third_party/py/absl",
-        "//third_party/py/pandas",
         "//third_party/py/sqlalchemy",
     ],
 )
diff --git a/labm8/pdutil.py b/labm8/pdutil.py
new file mode 100644
index 00000000..f163ec1b
--- /dev/null
+++ b/labm8/pdutil.py
@@ -0,0 +1,61 @@
+# Copyright 2014-2019 Chris Cummins .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Utility code for working with pandas."""
+import pandas as pd
+import typing
+from absl import flags as absl_flags
+
+from labm8 import sqlutil
+
+FLAGS = absl_flags.FLAGS
+
+
+def QueryToDataFrame(session: sqlutil.Session,
+                     query: sqlutil.Query) -> pd.DataFrame:
+  """Read query results to a Pandas DataFrame.
+
+  Args:
+    session: A database session.
+    query: The query to run.
+
+  Returns:
+    A Pandas DataFrame.
+  """
+  return pd.read_sql(query.statement, session.bind)
+
+
+def ModelToDataFrame(
+    session: sqlutil.Session,
+    model,
+    columns: typing.Optional[typing.List[str]] = None,
+    query_identity=lambda q: q,
+) -> pd.DataFrame:
+  """Construct and run a query that reads a model's fields to a DataFrame.
+
+  Args:
+    session: A database session.
+    model: A database mapped object.
+    columns: A list of column names, where each element is a column mapped to
+      the model. If not provided, all column names are used.
+    query_identity: A function which takes the produced query and returns a
+      query. Use this to implement filtering of the query results.
+
+  Returns:
+    A Pandas DataFrame with one column for each field.
+  """
+  columns = columns or sqlutil.ColumnNames(model)
+  query = session.query(*[getattr(model, column) for column in columns])
+  df = QueryToDataFrame(session, query_identity(query))
+  df.columns = columns
+  return df
diff --git a/labm8/pdutil_test.py b/labm8/pdutil_test.py
new file mode 100644
index 00000000..1d2a91e4
--- /dev/null
+++ b/labm8/pdutil_test.py
@@ -0,0 +1,67 @@
+# Copyright 2014-2019 Chris Cummins .
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for //labm8:pdutil."""
+import sqlalchemy as sql
+from sqlalchemy.ext import declarative
+
+from labm8 import pdutil
+from labm8 import sqlutil
+
+
+def test_QueryToDataFrame_column_names():
+  """Test that expected column names are set."""
+  base = declarative.declarative_base()
+
+  class Table(base):
+    __tablename__ = 'test'
+    col_a = sql.Column(sql.Integer, primary_key=True)
+    col_b = sql.Column(sql.Integer)
+
+  db = sqlutil.Database('sqlite://', base)
+  with db.Session() as s:
+    df = pdutil.QueryToDataFrame(s, s.query(Table.col_a, Table.col_b))
+
+  assert list(df.columns.values) == ['col_a', 'col_b']
+
+
+def test_ModelToDataFrame_column_names():
+  """Test that all model column names are used when none are given."""
+  base = declarative.declarative_base()
+
+  class Table(base):
+    __tablename__ = 'test'
+    col_a = sql.Column(sql.Integer, primary_key=True)
+    col_b = sql.Column(sql.Integer)
+
+  db = sqlutil.Database('sqlite://', base)
+  with db.Session() as s:
+    df = pdutil.ModelToDataFrame(s, Table)
+
+  assert list(df.columns.values) == ['col_a', 'col_b']
+
+
+def test_ModelToDataFrame_explicit_column_names():
+  """Test that only the explicitly requested columns are returned."""
+  base = declarative.declarative_base()
+
+  class Table(base):
+    __tablename__ = 'test'
+    col_a = sql.Column(sql.Integer, primary_key=True)
+    col_b = sql.Column(sql.Integer)
+
+  db = sqlutil.Database('sqlite://', base)
+  with db.Session() as s:
+    df = pdutil.ModelToDataFrame(s, Table, ['col_b'])
+
+  assert list(df.columns.values) == ['col_b']
diff --git a/labm8/sqlutil.py b/labm8/sqlutil.py
index b175aa6c..fb3e94d4 100644
--- a/labm8/sqlutil.py
+++ b/labm8/sqlutil.py
@@ -12,14 +12,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
"""Utility code for working with sqlalchemy.""" +import time + import collections import contextlib import pathlib -import time -import typing - -import pandas as pd import sqlalchemy as sql +import typing from absl import flags as absl_flags from sqlalchemy import func from sqlalchemy import orm @@ -468,45 +467,6 @@ def __repr__(self) -> str: return self.url -def QueryToDataFrame(session: Session, query: Query) -> pd.DataFrame: - """Read query results to a Pandas DataFrame. - - Args: - session: A database session. - query: The query to run. - - Returns: - A Pandas DataFrame. - """ - return pd.read_sql(query.statement, session.bind) - - -def ModelToDataFrame( - session: Session, - model, - columns: typing.Optional[typing.List[str]] = None, - query_identity=lambda q: q, -): - """Construct and execute a query reads an object's fields to a dataframe. - - Args: - session: A database session. - model: A database mapped object. - columns: A list of column names, where each element is a column mapped to - the model. If not provided, all column names are used. - query_identity: A function which takes the produced query and returns a - query. Use this to implement filtering of the query results. - - Returns: - A Pandas DataFrame with one column for each field. - """ - columns = columns or ColumnNames(model) - query = session.query(*[getattr(model, column) for column in columns]) - df = QueryToDataFrame(session, query_identity(query)) - df.columns = columns - return df - - class TablenameFromClassNameMixin(object): """A class mixin which derives __tablename__ from the class name. diff --git a/labm8/sqlutil_test.py b/labm8/sqlutil_test.py index 618d547c..d965ab7a 100644 --- a/labm8/sqlutil_test.py +++ b/labm8/sqlutil_test.py @@ -13,10 +13,9 @@ # limitations under the License. 
"""Unit tests for //labm8:sqlutil.""" import pathlib -import typing - import pytest import sqlalchemy as sql +import typing from sqlalchemy.ext import declarative from labm8 import sqlutil @@ -114,54 +113,6 @@ class Table(base, sqlutil.TablenameFromClassNameMixin): assert sqlutil.ColumnNames(instance) == ['col_a', 'col_b'] -def test_QueryToDataFrame_column_names(): - """Test that expected column names are set.""" - base = declarative.declarative_base() - - class Table(base): - __tablename__ = 'test' - col_a = sql.Column(sql.Integer, primary_key=True) - col_b = sql.Column(sql.Integer) - - db = sqlutil.Database('sqlite://', base) - with db.Session() as s: - df = sqlutil.QueryToDataFrame(s, s.query(Table.col_a, Table.col_b)) - - assert list(df.columns.values) == ['col_a', 'col_b'] - - -def test_ModelToDataFrame_column_names(): - """Test that expected column names are set.""" - base = declarative.declarative_base() - - class Table(base): - __tablename__ = 'test' - col_a = sql.Column(sql.Integer, primary_key=True) - col_b = sql.Column(sql.Integer) - - db = sqlutil.Database('sqlite://', base) - with db.Session() as s: - df = sqlutil.ModelToDataFrame(s, Table) - - assert list(df.columns.values) == ['col_a', 'col_b'] - - -def test_QueryToDataFrame_explicit_column_names(): - """Test that expected column names are set.""" - base = declarative.declarative_base() - - class Table(base): - __tablename__ = 'test' - col_a = sql.Column(sql.Integer, primary_key=True) - col_b = sql.Column(sql.Integer) - - db = sqlutil.Database('sqlite://', base) - with db.Session() as s: - df = sqlutil.ModelToDataFrame(s, Table, ['col_b']) - - assert list(df.columns.values) == ['col_b'] - - def test_AllColumnNames_invalid_object(): """TypeError raised when called on an invalid object."""