Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RFC, feat: LazyFrame.collect kwargs #1734

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from narwhals._dask.group_by import DaskLazyGroupBy
from narwhals._dask.namespace import DaskNamespace
from narwhals._dask.typing import IntoDaskExpr
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals.dtypes import DType
from narwhals.utils import Version

Expand Down Expand Up @@ -79,12 +80,12 @@ def with_columns(self, *exprs: DaskExpr, **named_exprs: DaskExpr) -> Self:
df = df.assign(**new_series)
return self._from_native_frame(df)

def collect(self) -> Any:
def collect(self: Self, **kwargs: Any) -> PandasLikeDataFrame:
import pandas as pd

from narwhals._pandas_like.dataframe import PandasLikeDataFrame

result = self._native_frame.compute()
result = self._native_frame.compute(**kwargs)
return PandasLikeDataFrame(
result,
implementation=Implementation.PANDAS,
Expand Down
75 changes: 63 additions & 12 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable
from typing import Literal
from typing import Sequence
from typing import overload

from narwhals._duckdb.utils import native_to_narwhals_dtype
from narwhals._duckdb.utils import parse_exprs_and_named_exprs
Expand All @@ -27,10 +28,13 @@
import pyarrow as pa
from typing_extensions import Self

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.group_by import DuckDBGroupBy
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.series import DuckDBInterchangeSeries
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals.dtypes import DType


Expand Down Expand Up @@ -76,20 +80,67 @@ def __getitem__(self, item: str) -> DuckDBInterchangeSeries:
self._native_frame.select(item), version=self._version
)

def collect(self) -> Any:
try:
import pyarrow as pa # ignore-banned-import
except ModuleNotFoundError as exc: # pragma: no cover
msg = "PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDcollect `LazyFrame` backed by DuckDB"
raise ModuleNotFoundError(msg) from exc
@overload
def collect(self, return_type: Literal["pyarrow"] = "pyarrow") -> ArrowDataFrame: ...

from narwhals._arrow.dataframe import ArrowDataFrame
@overload
def collect(self, return_type: Literal["pandas"]) -> PandasLikeDataFrame: ...

return ArrowDataFrame(
native_dataframe=self._native_frame.arrow(),
backend_version=parse_version(pa.__version__),
version=self._version,
)
@overload
def collect(self, return_type: Literal["polars"]) -> PolarsDataFrame: ...

def collect(
self,
return_type: Literal["pyarrow", "pandas", "polars"] = "pyarrow",
) -> ArrowDataFrame | PandasLikeDataFrame | PolarsDataFrame:
if return_type == "pyarrow":
try:
import pyarrow as pa # ignore-banned-import
except ModuleNotFoundError as exc: # pragma: no cover
msg = (
"PyArrow>=11.0.0 is required to collect `LazyFrame` backed by DuckDB"
)
raise ModuleNotFoundError(msg) from exc

from narwhals._arrow.dataframe import ArrowDataFrame

return ArrowDataFrame(
native_dataframe=self._native_frame.arrow(),
backend_version=parse_version(pa.__version__),
version=self._version,
)

elif return_type == "pandas":
import pandas as pd # ignore-banned-import

from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals.utils import Implementation

return PandasLikeDataFrame(
native_dataframe=self._native_frame.df(),
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
version=self._version,
)

elif return_type == "polars":
import polars as pl # ignore-banned-import

from narwhals._polars.dataframe import PolarsDataFrame
from narwhals.utils import Implementation

return PolarsDataFrame(
df=self._native_frame.pl(),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrelated but.. should we change in PolarsDataFrame:

- df: pl.DataFrame,
+ native_dataframe: pl.DataFrame,

backend_version=parse_version(pl.__version__),
version=self._version,
)

else:
msg = (
"Only the following `return_type`'s are supported: pyarrow, pandas and "
f"polars. Found '{return_type}'."
)
raise ValueError(msg)

def head(self, n: int) -> Self:
return self._from_native_frame(self._native_frame.limit(n))
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,11 +425,11 @@ def collect_schema(self: Self) -> dict[str, DType]:
for name, dtype in self._native_frame.collect_schema().items()
}

def collect(self: Self) -> PolarsDataFrame:
def collect(self: Self, **kwargs: Any) -> PolarsDataFrame:
import polars as pl

try:
result = self._native_frame.collect()
result = self._native_frame.collect(**kwargs)
except pl.exceptions.ColumnNotFoundError as e:
raise ColumnNotFoundError(str(e)) from e

Expand Down
127 changes: 106 additions & 21 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -3610,16 +3610,40 @@ def __getitem__(self, item: str | slice) -> NoReturn:
msg = "Slicing is not supported on LazyFrame"
raise TypeError(msg)

def collect(self) -> DataFrame[Any]:
def collect(
self: Self,
*,
polars_kwargs: dict[str, Any] | None = None,
dask_kwargs: dict[str, Any] | None = None,
duckdb_kwargs: dict[str, str] | None = None,
Comment on lines +3616 to +3618
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could all be TypedDicts 👀

) -> DataFrame[Any]:
r"""Materialize this LazyFrame into a DataFrame.

Because each underlying lazy backend accepts different arguments when
materializing a lazyframe into a dataframe, these arguments are passed
separately, through a dedicated keyword argument for each backend.

Arguments:
polars_kwargs: [polars.LazyFrame.collect](https://docs.pola.rs/api/python/dev/reference/lazyframe/api/polars.LazyFrame.collect.html)
arguments. Used only if the `LazyFrame` is backed by a `polars.LazyFrame`.
If not provided, it uses the polars default values.
dask_kwargs: [dask.dataframe.DataFrame.compute](https://docs.dask.org/en/stable/generated/dask.dataframe.DataFrame.compute.html)
arguments. Used only if the `LazyFrame` is backed by a `dask.dataframe.DataFrame`.
If not provided, it uses the dask default values.
duckdb_kwargs: Specifies which eager backend a DuckDBPyRelation-backed
LazyFrame is materialized into. Choose among `pyarrow`, `pandas` or
`polars` by passing `duckdb_kwargs={"return_type": "<eager_backend>"}`.
If not provided, the result is materialized as `pyarrow`.

Returns:
DataFrame

Examples:
>>> import narwhals as nw
>>> import polars as pl
>>> import dask.dataframe as dd
>>> import narwhals as nw
>>> from narwhals.typing import IntoDataFrame, IntoFrame
>>>
>>> data = {
... "a": ["a", "b", "a", "b", "b", "c"],
... "b": [1, 2, 3, 4, 5, 6],
Expand All @@ -3628,28 +3652,14 @@ def collect(self) -> DataFrame[Any]:
>>> lf_pl = pl.LazyFrame(data)
>>> lf_dask = dd.from_dict(data, npartitions=2)

>>> lf = nw.from_native(lf_pl)
>>> lf # doctest:+ELLIPSIS
>>> nw.from_native(lf_pl) # doctest:+ELLIPSIS
┌─────────────────────────────┐
|     Narwhals LazyFrame      |
|-----------------------------|
|<LazyFrame at ...
└─────────────────────────────┘
>>> df = lf.group_by("a").agg(nw.all().sum()).collect()
>>> df.to_native().sort("a")
shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ a   ┆ 4   ┆ 10  │
│ b   ┆ 11  ┆ 10  │
│ c   ┆ 6   ┆ 1   │
└─────┴─────┴─────┘

>>> lf = nw.from_native(lf_dask)
>>> lf
>>> nw.from_native(lf_dask)
┌───────────────────────────────────┐
| Narwhals LazyFrame |
|-----------------------------------|
Expand All @@ -3662,15 +3672,90 @@ def collect(self) -> DataFrame[Any]:
|Dask Name: frompandas, 1 expression|
|Expr=df |
└───────────────────────────────────┘
>>> df = lf.group_by("a").agg(nw.col("b", "c").sum()).collect()
>>> df.to_native()

Let's define a dataframe-agnostic function that performs a group-by
computation and finally collects the result into a DataFrame:

>>> def agnostic_group_by_and_collect(lf_native: IntoFrame) -> IntoDataFrame:
... lf = nw.from_native(lf_native)
... return (
... lf.group_by("a")
... .agg(nw.col("b", "c").sum())
... .sort("a")
... .collect()
... .to_native()
... )

We can then pass any supported library such as Polars or Dask
to `agnostic_group_by_and_collect`:

>>> agnostic_group_by_and_collect(lf_pl)
shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ a   ┆ 4   ┆ 10  │
│ b   ┆ 11  ┆ 10  │
│ c   ┆ 6   ┆ 1   │
└─────┴─────┴─────┘

>>> agnostic_group_by_and_collect(lf_dask)
a b c
0 a 4 10
1 b 11 10
2 c 6 1

Now, let's suppose that we want to run lazily, yet without
query optimization (e.g. for debugging purposes). As this is achieved
differently in polars and dask, we can keep a unified workflow by
specifying the native kwargs for each backend:

>>> def agnostic_collect_no_opt(lf_native: IntoFrame) -> IntoDataFrame:
... lf = nw.from_native(lf_native)
... return (
... lf.group_by("a")
... .agg(nw.col("b", "c").sum())
... .sort("a")
... .collect(
... polars_kwargs={"no_optimization": True},
... dask_kwargs={"optimize_graph": False},
... )
... .to_native()
... )

>>> agnostic_collect_no_opt(lf_pl)
shape: (3, 3)
┌─────┬─────┬─────┐
│ a   ┆ b   ┆ c   │
│ --- ┆ --- ┆ --- │
│ str ┆ i64 ┆ i64 │
╞═════╪═════╪═════╡
│ a   ┆ 4   ┆ 10  │
│ b   ┆ 11  ┆ 10  │
│ c   ┆ 6   ┆ 1   │
└─────┴─────┴─────┘

>>> agnostic_collect_no_opt(lf_dask)
a b c
0 a 4 10
1 b 11 10
2 c 6 1
"""
from narwhals.utils import Implementation

if self.implementation is Implementation.POLARS and polars_kwargs is not None:
kwargs = polars_kwargs
elif self.implementation is Implementation.DASK and dask_kwargs is not None:
kwargs = dask_kwargs
elif self.implementation is Implementation.DUCKDB and duckdb_kwargs is not None:
kwargs = duckdb_kwargs
else:
kwargs = {}

return self._dataframe(
self._compliant_frame.collect(),
self._compliant_frame.collect(**kwargs),
level="full",
)

Expand Down
Loading
Loading