diff --git a/great_tables/_gt_data.py b/great_tables/_gt_data.py index a231fad8b..b166e071c 100644 --- a/great_tables/_gt_data.py +++ b/great_tables/_gt_data.py @@ -19,6 +19,7 @@ copy_data, create_empty_frame, get_column_names, + _get_column_dtype, n_rows, to_list, validate_frame, @@ -175,7 +176,11 @@ def render_formats(self, data_tbl: TblData, formats: list[FormatInfo], context: # TODO: I think that this is very inefficient with polars, so # we could either accumulate results and set them per column, or # could always use a pandas DataFrame inside Body? - _set_cell(self.body, row, col, result) + new_body = _set_cell(self.body, row, col, result) + if new_body is not None: + # Some backends do not support inplace operations, but return a new dataframe + # TODO: Consolidate the behaviour of _set_cell + self.body = new_body return self @@ -335,7 +340,7 @@ def align_from_data(self, data: TblData) -> Self: # a Pandas DataFrame or a Polars DataFrame col_classes = [] for col in get_column_names(data): - dtype = data[col].dtype + dtype = _get_column_dtype(data, col) if dtype == "object": # Check whether all values in 'object' columns are strings that diff --git a/great_tables/_tbl_data.py b/great_tables/_tbl_data.py index f963f5e25..3bc412c22 100644 --- a/great_tables/_tbl_data.py +++ b/great_tables/_tbl_data.py @@ -1,8 +1,7 @@ from __future__ import annotations -import warnings import re - +import warnings from functools import singledispatch from typing import TYPE_CHECKING, Any, Callable, Optional, Union @@ -10,34 +9,38 @@ from ._databackend import AbstractBackend - # Define databackend types ---- # These are resolved lazily (e.g. on isinstance checks) when run dynamically, # or imported directly during type checking. if TYPE_CHECKING: + import numpy as np import pandas as pd import polars as pl - import numpy as np + import pyarrow as pa # the class behind selectors from polars.selectors import _selector_proxy_ PdDataFrame = pd.DataFrame PlDataFrame = pl.DataFrame + PyArrowTable = pa.Table + PlSelectExpr = _selector_proxy_ PlExpr = pl.Expr PdSeries = pd.Series PlSeries = pl.Series + PyArrowArray = pa.Array + PyArrowChunkedArray = pa.ChunkedArray PdNA = pd.NA PlNull = pl.Null NpNan = np.nan - DataFrameLike = Union[PdDataFrame, PlDataFrame] - SeriesLike = Union[PdSeries, PlSeries] + DataFrameLike = Union[PdDataFrame, PlDataFrame, PyArrowTable] + SeriesLike = Union[PdSeries, PlSeries, PyArrowArray, PyArrowChunkedArray] TblData = DataFrameLike else: @@ -53,6 +56,9 @@ class PdDataFrame(AbstractBackend): class PlDataFrame(AbstractBackend): _backends = [("polars", "DataFrame")] + class PyArrowTable(AbstractBackend): + _backends = [("pyarrow", "Table")] + class PlSelectExpr(AbstractBackend): _backends = [("polars.selectors", "_selector_proxy_")] @@ -65,6 +71,12 @@ class PdSeries(AbstractBackend): class PlSeries(AbstractBackend): _backends = [("polars", "Series")] + class PyArrowArray(AbstractBackend): + _backends = [("pyarrow", "Array")] + + class PyArrowChunkedArray(AbstractBackend): + _backends = [("pyarrow", "ChunkedArray")] + class PdNA(AbstractBackend): _backends = [("pandas", "NA")] @@ -84,8 +96,11 @@ class SeriesLike(ABC): DataFrameLike.register(PdDataFrame) DataFrameLike.register(PlDataFrame) + DataFrameLike.register(PyArrowTable) SeriesLike.register(PdSeries) SeriesLike.register(PlSeries) + SeriesLike.register(PyArrowArray) + SeriesLike.register(PyArrowChunkedArray) TblData = DataFrameLike @@ -140,6 +155,13 @@ def _(data: PlDataFrame): return data.clone() +@copy_data.register(PyArrowTable) +def _(data: PyArrowTable): + import pyarrow as pa + + return pa.table(data) + + # get_column_names ---- @singledispatch def get_column_names(data: DataFrameLike) -> list[str]: @@ -157,6 +179,11 @@ def _(data: PlDataFrame): return data.columns +@get_column_names.register(PyArrowTable) +def _(data: PyArrowTable): + return data.column_names + + # n_rows ---- @@ -172,6 +199,11 @@ def _(data: Any) -> int: return len(data) +@n_rows.register(PyArrowTable) +def _(data: PyArrowTable) -> int: + return data.num_rows + + # _get_cell ---- @@ -197,6 +229,11 @@ def _(data: Any, row: int, col: str) -> Any: return data.iloc[row, col_ii] +@_get_cell.register(PyArrowTable) +def _(data: PyArrowTable, row: int, column: str) -> Any: + return data.column(column)[row].as_py() + + # _set_cell ---- @@ -218,15 +255,32 @@ def _(data, row: int, column: str, value: Any) -> None: data[row, column] = value +@_set_cell.register(PyArrowTable) +def _(data: PyArrowTable, row: int, column: str, value: Any) -> PyArrowTable: + import pyarrow as pa + + colindex = data.column_names.index(column) + col = data.column(column) + pylist = col.to_pylist() + pylist[row] = value + data = data.set_column(colindex, column, pa.array(pylist)) + return data + + # _get_column_dtype ---- @singledispatch -def _get_column_dtype(data: DataFrameLike, column: str) -> str: +def _get_column_dtype(data: DataFrameLike, column: str) -> Any: """Get the data type for a single column in the input data table""" return data[column].dtype +@_get_column_dtype.register(PyArrowTable) +def _(data: PyArrowTable, column: str) -> Any: + return data.column(column).type + + # reorder ---- @@ -249,6 +303,11 @@ def _(data: PlDataFrame, rows: list[int], columns: list[str]) -> PlDataFrame: return data[rows, columns] +@reorder.register +def _(data: PyArrowTable, rows: list[int], columns: list[str]) -> PyArrowTable: + return data.select(columns).take(rows) + + # group_splits ---- @singledispatch def group_splits(data: DataFrameLike, group_key: str) -> dict[Any, list[int]]: @@ -277,6 +336,20 @@ def _(data: PlDataFrame, group_key: str) -> dict[Any, list[int]]: return res +@group_splits.register +def _(data: PyArrowTable, group_key: str) -> dict[Any, list[int]]: + import pyarrow.compute as pc + + group_col = data.column(group_key) + encoded = group_col.dictionary_encode().combine_chunks() + + d = {} + for idx, group_key in enumerate(encoded.dictionary): + mask = pc.equal(encoded.indices, idx) + d[group_key.as_py()] = pc.indices_nonzero(mask).to_pylist() + return d + + # eval_select ---- SelectExpr: TypeAlias = Union[ @@ -324,12 +397,12 @@ def _( def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool = True) -> _NamePos: # TODO: how to annotate type of a polars selector? # Seems to be polars.selectors._selector_proxy_. - from ._utils import OrderedSet import polars as pl import polars.selectors as cs - from polars import Expr + from ._utils import OrderedSet + pl_version = _re_version(pl.__version__) expand_opts = {"strict": False} if pl_version >= (0, 20, 30) else {} @@ -370,9 +443,25 @@ def _(data: PlDataFrame, expr: Union[list[str], _selector_proxy_], strict: bool return [(col, col_pos[col]) for col in final_columns] +@eval_select.register +def _( + data: PyArrowTable, expr: Union[list[str], _selector_proxy_], strict: bool = True +) -> _NamePos: + if isinstance(expr, (str, int)): + expr = [expr] + + if isinstance(expr, list): + return _eval_select_from_list(data.column_names, expr) + elif callable(expr): + col_pos = {k: ii for ii, k in enumerate(data.column_names)} + return [(col, col_pos[col]) for col in data.column_names if expr(col)] + + raise NotImplementedError(f"Unsupported selection expr: {expr}") + + def _validate_selector_list(selectors: list, strict=True): - from polars.selectors import is_selector from polars import Expr + from polars.selectors import is_selector for ii, sel in enumerate(selectors): if isinstance(sel, Expr): @@ -430,6 +519,13 @@ def _(df: PlDataFrame): return df.clear().cast(pl.Utf8).clear(len(df)) +@create_empty_frame.register +def _(df: PyArrowTable): + import pyarrow as pa + + return pa.table({col: pa.nulls(df.num_rows, type=pa.string()) for col in df.column_names}) + + @singledispatch def copy_frame(df: DataFrameLike) -> DataFrameLike: """Return a copy of the input DataFrame""" @@ -446,6 +542,13 @@ def _(df: PlDataFrame): return df.clone() +@copy_frame.register +def _(df: PyArrowTable): + import pyarrow as pa + + return pa.table({col: pa.array(df.column(col)) for col in df.column_names}) + + # cast_frame_to_string ---- @@ -475,6 +578,13 @@ def _(df: PlDataFrame): ) +@cast_frame_to_string.register +def _(df: PyArrowTable): + import pyarrow as pa + + return pa.table({col: pa.array(df.column(col).cast(pa.string())) for col in df.column_names}) + + # replace_null_frame ---- @@ -497,6 +607,19 @@ def _(df: PlDataFrame, replacement: PlDataFrame): return df.select(exprs) +@replace_null_frame.register +def _(df: PyArrowTable, replacement: PyArrowTable): + import pyarrow as pa + import pyarrow.compute as pc + + return pa.table( + { + col: pc.if_else(pc.is_null(df.column(col)), replacement.column(col), df.column(col)) + for col in df.column_names + } + ) + + @singledispatch def to_list(ser: SeriesLike) -> list[Any]: raise NotImplementedError(f"Unsupported type: {type(ser)}") @@ -512,6 +635,16 @@ def _(ser: PlSeries) -> list[Any]: return ser.to_list() +@to_list.register +def _(ser: PyArrowArray) -> list[Any]: + return ser.to_pylist() + + +@to_list.register +def _(ser: PyArrowChunkedArray) -> list[Any]: + return ser.to_pylist() + + # is_series ---- @@ -530,6 +663,16 @@ def _(ser: PlSeries) -> bool: return True +@is_series.register +def _(ser: PyArrowArray) -> bool: + return True + + +@is_series.register +def _(ser: PyArrowChunkedArray) -> bool: + return True + + # mutate ---- @@ -573,6 +716,21 @@ def _(df: PlDataFrame, expr: PlExpr) -> list[Any]: return res.to_list() +@eval_transform.register +def _(df: PyArrowTable, expr: Callable[[PyArrowTable], PyArrowArray]) -> list[Any]: + res = expr(df) + + if not isinstance(res, PyArrowArray): + raise ValueError(f"Result must be an Arrow Array. Received {type(res)}") + elif not len(res) == len(df): + raise ValueError( + f"Result must be same length as input data. Observed different lengths." + f"\n\nInput data: {df.num_rows}.\nResult: {len(res)}." + ) + + return res.to_pylist() + + @singledispatch def is_na(df: DataFrameLike, x: Any) -> bool: raise NotImplementedError(f"Unsupported type: {type(df)}") @@ -588,13 +746,21 @@ def _(df: PdDataFrame, x: Any) -> bool: @is_na.register(Agnostic) @is_na.register def _(df: PlDataFrame, x: Any) -> bool: - import polars as pl - from math import isnan + import polars as pl + return isinstance(x, (pl.Null, type(None))) or (isinstance(x, float) and isnan(x)) +@is_na.register +def _(df: PyArrowTable, x: Any) -> bool: + import pyarrow as pa + + arr = pa.array([x]) + return arr.is_null().to_pylist()[0] or arr.is_nan().to_pylist()[0] + + @singledispatch def validate_frame(df: DataFrameLike) -> DataFrameLike: """Raises an error if a DataFrame is not supported by Great Tables. @@ -646,6 +812,16 @@ def _(df: PlDataFrame) -> PlDataFrame: return df +@validate_frame.register +def _(df: PyArrowTable) -> PyArrowTable: + warnings.warn("PyArrow Table support is currently experimental.") + + if len(set(df.column_names)) != len(df.column_names): + raise ValueError("Column names must be unique.") + + return df + + # to_frame ---- @@ -679,3 +855,17 @@ def _(ser: PdSeries, name: Optional[str] = None) -> PdDataFrame: @to_frame.register def _(ser: PlSeries, name: Optional[str] = None) -> PlDataFrame: return ser.to_frame(name) + + +@to_frame.register +def _(ser: PyArrowArray, name: Optional[str] = None) -> PyArrowTable: + import pyarrow as pa + + return pa.table({name: ser}) + + +@to_frame.register +def _(ser: PyArrowChunkedArray, name: Optional[str] = None) -> PyArrowTable: + import pyarrow as pa + + return pa.table({name: ser}) diff --git a/tests/__snapshots__/test_tbl_data.ambr b/tests/__snapshots__/test_tbl_data.ambr index 6cebc27a3..6aae98754 100644 --- a/tests/__snapshots__/test_tbl_data.ambr +++ b/tests/__snapshots__/test_tbl_data.ambr @@ -1,4 +1,67 @@ # serializer version: 1 +# name: test_frame_rendering[arrow] + ''' +
+