diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py
index a231fad8b..b166e071c 100644
--- a/great_tables/_gt_data.py
+++ b/great_tables/_gt_data.py
@@ -19,6 +19,7 @@
     copy_data,
     create_empty_frame,
     get_column_names,
+    _get_column_dtype,
     n_rows,
     to_list,
     validate_frame,
@@ -175,7 +176,11 @@ def render_formats(self, data_tbl: TblData, formats: list[FormatInfo], context:
                 # TODO: I think that this is very inefficient with polars, so
                 # we could either accumulate results and set them per column, or
                 # could always use a pandas DataFrame inside Body?
-                _set_cell(self.body, row, col, result)
+                new_body = _set_cell(self.body, row, col, result)
+                if new_body is not None:
+                    # Some backends do not support in-place operations and instead return a new DataFrame
+                    # TODO: Consolidate the behaviour of _set_cell
+                    self.body = new_body
 
         return self
@@ -335,7 +340,7 @@ def align_from_data(self, data: TblData) -> Self:
         # a Pandas DataFrame or a Polars DataFrame
         col_classes = []
         for col in get_column_names(data):
-            dtype = data[col].dtype
+            dtype = _get_column_dtype(data, col)
 
             if dtype == "object":
                 # Check whether all values in 'object' columns are strings that
diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py
index f963f5e25..3bc412c22 100644
--- a/great_tables/_tbl_data.py
+++ b/great_tables/_tbl_data.py
@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-import warnings
 import re
-
+import warnings
 from functools import singledispatch
 from typing import TYPE_CHECKING, Any, Callable, Optional, Union
@@ -10,34 +9,38 @@
 from ._databackend import AbstractBackend
 
-
 # Define databackend types ----
 # These are resolved lazily (e.g. on isinstance checks) when run dynamically,
 # or imported directly during type checking.
 if TYPE_CHECKING:
+    import numpy as np
     import pandas as pd
     import polars as pl
-    import numpy as np
+    import pyarrow as pa
 
     # the class behind selectors
     from polars.selectors import _selector_proxy_
 
     PdDataFrame = pd.DataFrame
     PlDataFrame = pl.DataFrame
+    PyArrowTable = pa.Table
+
     PlSelectExpr = _selector_proxy_
     PlExpr = pl.Expr
 
     PdSeries = pd.Series
     PlSeries = pl.Series
+    PyArrowArray = pa.Array
+    PyArrowChunkedArray = pa.ChunkedArray
 
     PdNA = pd.NA
     PlNull = pl.Null
     NpNan = np.nan
 
-    DataFrameLike = Union[PdDataFrame, PlDataFrame]
-    SeriesLike = Union[PdSeries, PlSeries]
+    DataFrameLike = Union[PdDataFrame, PlDataFrame, PyArrowTable]
+    SeriesLike = Union[PdSeries, PlSeries, PyArrowArray, PyArrowChunkedArray]
 
     TblData = DataFrameLike
 
 else:
@@ -53,6 +56,9 @@ class PdDataFrame(AbstractBackend):
     class PlDataFrame(AbstractBackend):
         _backends = [("polars", "DataFrame")]
 
+    class PyArrowTable(AbstractBackend):
+        _backends = [("pyarrow", "Table")]
+
     class PlSelectExpr(AbstractBackend):
         _backends = [("polars.selectors", "_selector_proxy_")]
@@ -65,6 +71,12 @@ class PdSeries(AbstractBackend):
     class PlSeries(AbstractBackend):
         _backends = [("polars", "Series")]
 
+    class PyArrowArray(AbstractBackend):
+        _backends = [("pyarrow", "Array")]
+
+    class PyArrowChunkedArray(AbstractBackend):
+        _backends = [("pyarrow", "ChunkedArray")]
+
     class PdNA(AbstractBackend):
         _backends = [("pandas", "NA")]
@@ -84,8 +96,11 @@ class SeriesLike(ABC):
 
     DataFrameLike.register(PdDataFrame)
     DataFrameLike.register(PlDataFrame)
+    DataFrameLike.register(PyArrowTable)
 
     SeriesLike.register(PdSeries)
     SeriesLike.register(PlSeries)
+    SeriesLike.register(PyArrowArray)
+    SeriesLike.register(PyArrowChunkedArray)
 
     TblData = DataFrameLike
@@ -140,6 +155,13 @@ def _(data: PlDataFrame):
     return data.clone()
 
 
+@copy_data.register(PyArrowTable)
+def _(data: PyArrowTable):
+    import pyarrow as pa
+
+    return pa.table(data)
+
+
 # get_column_names ----
 @singledispatch
 def get_column_names(data: DataFrameLike) -> list[str]:
@@ -157,6 +179,11 @@ def _(data: PlDataFrame):
     return data.columns
 
 
+@get_column_names.register(PyArrowTable)
+def _(data: PyArrowTable):
+    return data.column_names
+
+
 # n_rows ----
@@ -172,6 +199,11 @@ def _(data: Any) -> int:
     return len(data)
 
 
+@n_rows.register(PyArrowTable)
+def _(data: PyArrowTable) -> int:
+    return data.num_rows
+
+
 # _get_cell ----
@@ -197,6 +229,11 @@ def _(data: Any, row: int, col: str) -> Any:
     return data.iloc[row, col_ii]
 
 
+@_get_cell.register(PyArrowTable)
+def _(data: PyArrowTable, row: int, column: str) -> Any:
+    return data.column(column)[row].as_py()
+
+
 # _set_cell ----
@@ -218,15 +255,32 @@ def _(data, row: int, column: str, value: Any) -> None:
     data[row, column] = value
 
 
+@_set_cell.register(PyArrowTable)
+def _(data: PyArrowTable, row: int, column: str, value: Any) -> PyArrowTable:
+    import pyarrow as pa
+
+    colindex = data.column_names.index(column)
+    col = data.column(column)
+    pylist = col.to_pylist()
+    pylist[row] = value
+    data = data.set_column(colindex, column, pa.array(pylist))
+    return data
+
+
 # _get_column_dtype ----
 @singledispatch
-def _get_column_dtype(data: DataFrameLike, column: str) -> str:
+def _get_column_dtype(data: DataFrameLike, column: str) -> Any:
     """Get the data type for a single column in the input data table"""
     return data[column].dtype
 
 
+@_get_column_dtype.register(PyArrowTable)
+def _(data: PyArrowTable, column: str) -> Any:
+    return data.column(column).type
+
+
 # reorder ----
@@ -249,6 +303,11 @@ def _(data: PlDataFrame, rows: list[int], columns: list[str]) -> PlDataFrame:
     return data[rows, columns]
 
 
+@reorder.register
+def _(data: PyArrowTable, rows: list[int], columns: list[str]) -> PyArrowTable:
+    return data.select(columns).take(rows)
+
+
 # group_splits ----
 @singledispatch
 def group_splits(data: DataFrameLike, group_key: str) -> dict[Any, list[int]]:
@@ -277,6 +336,20 @@ def _(data: PlDataFrame, group_key: str) -> dict[Any, list[int]]:
     return res
 
 
+@group_splits.register
+def _(data: PyArrowTable, group_key: str) -> dict[Any, list[int]]:
+    import pyarrow.compute as pc
+
+    group_col = data.column(group_key)
+    encoded = group_col.dictionary_encode().combine_chunks()
+
+    d = {}
+    for idx, group_key in enumerate(encoded.dictionary):
+        mask = pc.equal(encoded.indices, idx)
+        d[group_key.as_py()] = pc.indices_nonzero(mask).to_pylist()
+    return d
+
+
 # eval_select ----
 SelectExpr: TypeAlias = Union[
@@ -324,12 +397,12 @@ def _(
 def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool = True) -> _NamePos:
     # TODO: how to annotate type of a polars selector?
     # Seems to be polars.selectors._selector_proxy_.
-    from ._utils import OrderedSet
     import polars as pl
     import polars.selectors as cs
-    from polars import Expr
+    from ._utils import OrderedSet
+
     pl_version = _re_version(pl.__version__)
     expand_opts = {"strict": False} if pl_version >= (0, 20, 30) else {}
@@ -370,9 +443,25 @@ def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool
     return [(col, col_pos[col]) for col in final_columns]
 
 
+@eval_select.register
+def _(
+    data: PyArrowTable, expr: Union[list[str], _selector_proxy_], strict: bool = True
+) -> _NamePos:
+    if isinstance(expr, (str, int)):
+        expr = [expr]
+
+    if isinstance(expr, list):
+        return _eval_select_from_list(data.column_names, expr)
+    elif callable(expr):
+        col_pos = {k: ii for ii, k in enumerate(data.column_names)}
+        return [(col, col_pos[col]) for col in data.column_names if expr(col)]
+
+    raise NotImplementedError(f"Unsupported selection expr: {expr}")
+
+
 def _validate_selector_list(selectors: list, strict=True):
-    from polars.selectors import is_selector
     from polars import Expr
+    from polars.selectors import is_selector
 
     for ii, sel in enumerate(selectors):
         if isinstance(sel, Expr):
@@ -430,6 +519,13 @@ def _(df: PlDataFrame):
     return df.clear().cast(pl.Utf8).clear(len(df))
 
 
+@create_empty_frame.register
+def _(df: PyArrowTable):
+    import pyarrow as pa
+
+    return pa.table({col: pa.nulls(df.num_rows, type=pa.string()) for col in df.column_names})
+
+
 @singledispatch
 def copy_frame(df: DataFrameLike) -> DataFrameLike:
     """Return a copy of the input DataFrame"""
@@ -446,6 +542,13 @@ def _(df: PlDataFrame):
     return df.clone()
 
 
+@copy_frame.register
+def _(df: PyArrowTable):
+    import pyarrow as pa
+
+    return pa.table({col: pa.array(df.column(col)) for col in df.column_names})
+
+
 # cast_frame_to_string ----
@@ -475,6 +578,13 @@ def _(df: PlDataFrame):
     )
 
 
+@cast_frame_to_string.register
+def _(df: PyArrowTable):
+    import pyarrow as pa
+
+    return pa.table({col: pa.array(df.column(col).cast(pa.string())) for col in df.column_names})
+
+
 # replace_null_frame ----
@@ -497,6 +607,19 @@ def _(df: PlDataFrame, replacement: PlDataFrame):
     return df.select(exprs)
 
 
+@replace_null_frame.register
+def _(df: PyArrowTable, replacement: PyArrowTable):
+    import pyarrow as pa
+    import pyarrow.compute as pc
+
+    return pa.table(
+        {
+            col: pc.if_else(pc.is_null(df.column(col)), replacement.column(col), df.column(col))
+            for col in df.column_names
+        }
+    )
+
+
 @singledispatch
 def to_list(ser: SeriesLike) -> list[Any]:
     raise NotImplementedError(f"Unsupported type: {type(ser)}")
@@ -512,6 +635,16 @@ def _(ser: PlSeries) -> list[Any]:
     return ser.to_list()
 
 
+@to_list.register
+def _(ser: PyArrowArray) -> list[Any]:
+    return ser.to_pylist()
+
+
+@to_list.register
+def _(ser: PyArrowChunkedArray) -> list[Any]:
+    return ser.to_pylist()
+
+
 # is_series ----
@@ -530,6 +663,16 @@ def _(ser: PlSeries) -> bool:
     return True
 
 
+@is_series.register
+def _(ser: PyArrowArray) -> bool:
+    return True
+
+
+@is_series.register
+def _(ser: PyArrowChunkedArray) -> bool:
+    return True
+
+
 # mutate ----
@@ -573,6 +716,21 @@ def _(df: PlDataFrame, expr: PlExpr) -> list[Any]:
     return res.to_list()
 
 
+@eval_transform.register
+def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowArray]) -> list[Any]:
+    res = expr(df)
+
+    if not isinstance(res, PyArrowArray):
+        raise ValueError(f"Result must be an Arrow Array. Received {type(res)}")
+    elif not len(res) == len(df):
+        raise ValueError(
+            "Result must be the same length as the input data. Observed different lengths."
+ f"\n\nInput data: {df.num_rows}.\nResult: {len(res)}." + ) + + return res.to_pylist() + + @singledispatch def is_na(df: DataFrameLike, x: Any) -> bool: raise NotImplementedError(f"Unsupported type: {type(df)}") @@ -588,13 +746,21 @@ def _(df: PdDataFrame, x: Any) -> bool: @is_na.register(Agnostic) @is_na.register def _(df: PlDataFrame, x: Any) -> bool: - import polars as pl - from math import isnan + import polars as pl + return isinstance(x, (pl.Null, type(None))) or (isinstance(x, float) and isnan(x)) +@is_na.register +def _(df: PyArrowTable, x: Any) -> bool: + import pyarrow as pa + + arr = pa.array([x]) + return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0] + + @singledispatch def validate_frame(df: DataFrameLike) -> DataFrameLike: """Raises an error if a DataFrame is not supported by Great Tables. @@ -646,6 +812,16 @@ def _(df: PlDataFrame) -> PlDataFrame: return df +@validate_frame.register +def _(df: PyArrowTable) -> PyArrowTable: + warnings.warn("PyArrow Table support is currently experimental.") + + if len(set(df.column_names)) != len(df.column_names): + raise ValueError("Column names must be unique.") + + return df + + # to_frame ---- @@ -679,3 +855,17 @@ def _(ser: PdSeries, name: Optional[str] = None) -> PdDataFrame: @to_frame.register def _(ser: PlSeries, name: Optional[str] = None) -> PlDataFrame: return ser.to_frame(name) + + +@to_frame.register +def _(ser: PyArrowArray, name: Optional[str] = None) -> PyArrowTable: + import pyarrow as pa + + return pa.table({name: ser}) + + +@to_frame.register +def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: + import pyarrow as pa + + return pa.table({name: ser}) diff --git a/tests/__snapshots__/test_tbl_data.ambr b/tests/__snapshots__/test_tbl_data.ambr index 6cebc27a3..6aae98754 100644 --- a/tests/__snapshots__/test_tbl_data.ambr +++ b/tests/__snapshots__/test_tbl_data.ambr @@ -1,4 +1,67 @@ # serializer version: 1 +# name: test_frame_rendering[arrow] + ''' + + + $1.00 + a + 4 + + + $2.00 + b + 5 + + + $3.00 + c + 6 + + + ''' +# --- +# name: test_frame_rendering[pandas] + ''' + + + $1.00 + a + 4 + + + $2.00 + b + 5 + + + $3.00 + c + 6 + + + ''' +# --- +# name: test_frame_rendering[polars] + ''' + + + $1.00 + a + 4 + + + $2.00 + b + 5 + + + $3.00 + c + 6 + + + ''' +# --- # name: test_validate_frame_non_str_cols_warning ''' pandas DataFrame contains non-string column names. Coercing to strings. 
   Here are the first few non-string columns:
diff --git a/tests/test_substitutions.py b/tests/test_substitutions.py
index c45e95dc4..92b42c215 100644
--- a/tests/test_substitutions.py
+++ b/tests/test_substitutions.py
@@ -3,6 +3,7 @@
 import numpy as np
 import pandas as pd
 import polars as pl
+import pyarrow as pa
 import polars.testing
 import pytest
 from great_tables import GT
@@ -10,7 +11,11 @@
 from great_tables._substitution import SubMissing, SubZero
 from great_tables._tbl_data import DataFrameLike, to_list
 
-params_frames = [pytest.param(pd.DataFrame, id="pandas"), pytest.param(pl.DataFrame, id="polars")]
+params_frames = [
+    pytest.param(pd.DataFrame, id="pandas"),
+    pytest.param(pl.DataFrame, id="polars"),
+    pytest.param(pa.table, id="arrow"),
+]
 
 
 @pytest.fixture(params=params_frames, scope="function")
@@ -20,7 +25,7 @@ def df(request) -> DataFrameLike:
 
 @pytest.fixture(params=params_frames, scope="function")
 def df_empty(request) -> DataFrameLike:
-    return request.param()
+    return request.param({})
 
 
 def assert_frame_equal(src: DataFrameLike, target: DataFrameLike):
diff --git a/tests/test_tbl_data.py b/tests/test_tbl_data.py
index dfc038fdd..335c06fb3 100644
--- a/tests/test_tbl_data.py
+++ b/tests/test_tbl_data.py
@@ -1,8 +1,11 @@
 import math
 import pandas as pd
 import polars as pl
+import pyarrow as pa
 import polars.testing
 import pytest
+from great_tables import GT
+from great_tables._utils_render_html import create_body_component_h
 from great_tables._tbl_data import (
     DataFrameLike,
     SeriesLike,
@@ -18,11 +21,22 @@
     is_series,
     reorder,
     to_frame,
+    to_list,
     validate_frame,
+    copy_frame,
 )
 
-params_frames = [pytest.param(pd.DataFrame, id="pandas"), pytest.param(pl.DataFrame, id="polars")]
-params_series = [pytest.param(pd.Series, id="pandas"), pytest.param(pl.Series, id="polars")]
+params_frames = [
+    pytest.param(pd.DataFrame, id="pandas"),
+    pytest.param(pl.DataFrame, id="polars"),
+    pytest.param(pa.table, id="arrow"),
+]
+params_series = [
+    pytest.param(pd.Series, id="pandas"),
+    pytest.param(pl.Series, id="polars"),
+    pytest.param(pa.array, id="arrow"),
+    pytest.param(lambda a: pa.chunked_array([a]), id="arrow-chunked"),
+]
 
 
 @pytest.fixture(params=params_frames, scope="function")
@@ -40,6 +54,8 @@ def assert_frame_equal(src, target):
         pd.testing.assert_frame_equal(src, target)
     elif isinstance(src, pl.DataFrame):
         pl.testing.assert_frame_equal(src, target)
+    elif isinstance(src, pa.Table):
+        assert src.equals(target)
     else:
         raise NotImplementedError(f"Unsupported data type: {type(src)}")
@@ -50,7 +66,8 @@ def test_get_column_names(df: DataFrameLike):
 
 
 def test_get_column_dtypes(df: DataFrameLike):
-    assert _get_column_dtype(df, "col1") == df["col1"].dtype
+    col1 = df["col1"]
+    assert _get_column_dtype(df, "col1") == getattr(col1, "dtype", getattr(col1, "type", None))
 
 
 def test_get_cell(df: DataFrameLike):
@@ -58,14 +75,27 @@
 def test_set_cell(df: DataFrameLike):
-    expected = df.__class__({"col1": [1, 2, 3], "col2": ["a", "x", "c"], "col3": [4.0, 5.0, 6.0]})
-    _set_cell(df, 1, "col2", "x")
-    assert_frame_equal(df, expected)
+    expected_data = {"col1": [1, 2, 3], "col2": ["a", "x", "c"], "col3": [4.0, 5.0, 6.0]}
+    if isinstance(df, pa.Table):
+        expected = pa.table(expected_data)
+    else:
+        expected = df.__class__(expected_data)
+
+    new_df = _set_cell(df, 1, "col2", "x")
+    if new_df is None:
+        # Some implementations do in-place modifications
+        new_df = df
+    assert_frame_equal(new_df, expected)
 
 
 def test_reorder(df: DataFrameLike):
     res = reorder(df, [0, 2], ["col2"])
["col2"]) - dst = df.__class__({"col2": ["a", "c"]}) + + expected_data = {"col2": ["a", "c"]} + if isinstance(df, pa.Table): + dst = pa.table(expected_data) + else: + dst = df.__class__(expected_data) if isinstance(dst, pd.DataFrame): dst.index = pd.Index([0, 2]) @@ -79,6 +109,21 @@ def test_eval_select_with_list(df: DataFrameLike, expr): assert sel == [("col2", 1), ("col1", 0)] +def test_eval_select_with_callable(df: DataFrameLike): + def expr(col): + return col == "col2" + + if isinstance(df, pl.DataFrame): + # Polars does not support callable expressions + with pytest.raises(TypeError) as exc_info: + eval_select(df, expr) + assert "Unsupported selection expr type:" in str(exc_info.value.args[0]) + return + + sel = eval_select(df, expr) + assert sel == [("col2", 1)] + + @pytest.mark.parametrize( "expr", [ @@ -132,7 +177,7 @@ def test_eval_selector_polars_list_raises(): assert "entry 1 is type: " in str(exc_info.value.args[0]) -@pytest.mark.parametrize("Frame", [pd.DataFrame, pl.DataFrame]) +@pytest.mark.parametrize("Frame", [pd.DataFrame, pl.DataFrame, pa.table]) def test_group_splits_pd(Frame): df = Frame({"g": ["b", "a", "b", "c"]}) @@ -175,8 +220,21 @@ def test_create_empty_frame(df: DataFrameLike): if isinstance(res, pd.DataFrame): dst = pd.DataFrame({"col1": col, "col2": col, "col3": col}, dtype="string") - else: + elif isinstance(res, pl.DataFrame): dst = pl.DataFrame({"col1": col, "col2": col, "col3": col}).cast(pl.Utf8) + elif isinstance(res, pa.Table): + dst = pa.table( + {"col1": col, "col2": col, "col3": col}, + schema=pa.schema( + ( + pa.field("col1", pa.string()), + pa.field("col2", pa.string()), + pa.field("col3", pa.string()), + ) + ), + ) + else: + raise ValueError(f"Unsupported data type: {type(res)}") assert_frame_equal(res, dst) @@ -227,6 +285,8 @@ def test_to_frame(ser: SeriesLike): assert_frame_equal(df, pl.DataFrame({"x": [1.0, 2.0, None]})) elif isinstance(ser, pd.Series): assert_frame_equal(df, pd.DataFrame({"x": [1.0, 2.0, None]})) + elif isinstance(ser, (pa.Array, pa.ChunkedArray)): + assert_frame_equal(df, pa.table({"x": [1.0, 2.0, None]})) else: raise AssertionError(f"Unexpected series type: {type(ser)}") @@ -239,6 +299,12 @@ def test_is_series_false(): assert not is_series(1) +def test_to_list(ser: SeriesLike): + pylist = to_list(ser) + assert len(pylist) == 3 + assert pylist[:2] == [1.0, 2.0] + + def test_cast_frame_to_string_polars_list_col(): df = pl.DataFrame({"x": [[1, 2], [3]], "y": [1, None], "z": [{"a": 1}, {"a": 2}]}) new_df = cast_frame_to_string(df) @@ -246,3 +312,14 @@ def test_cast_frame_to_string_polars_list_col(): assert new_df["x"].dtype.is_(pl.String) assert new_df["y"].dtype.is_(pl.String) assert new_df["z"].dtype.is_(pl.String) + + +def test_frame_rendering(df: DataFrameLike, snapshot): + gt = GT(df).fmt_number(columns="col3", decimals=0).fmt_currency(columns="col1") + assert create_body_component_h(gt._build_data("html")) == snapshot + + +def test_copy_frame(df: DataFrameLike): + copy_df = copy_frame(df) + assert id(copy_df) != id(df) + assert_frame_equal(copy_df, df)