Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Track is multi output #1855

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions narwhals/_arrow/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,3 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]:
)
else:
return self._to_expr() & other

def __invert__(self: Self) -> ArrowSelector:
    """Return `all()` minus this selector, i.e. this selector's complement."""
    namespace = ArrowSelectorNamespace(
        backend_version=self._backend_version, version=self._version
    )
    everything = namespace.all()
    return everything - self
8 changes: 0 additions & 8 deletions narwhals/_dask/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,11 +177,3 @@ def call(df: DaskLazyFrame) -> list[Any]:
)
else:
return self._to_expr() & other

def __invert__(self: Self) -> DaskSelector:
    """Return `all()` minus this selector, i.e. this selector's complement."""
    namespace = DaskSelectorNamespace(
        backend_version=self._backend_version, version=self._version
    )
    everything = namespace.all()
    return everything - self
97 changes: 84 additions & 13 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
from typing import TypedDict
from typing import TypeVar
from typing import Union
from typing import cast
Expand All @@ -15,6 +16,7 @@
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import InvalidIntoExprError
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import MultiOutputExprError
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -44,6 +46,13 @@
T = TypeVar("T")


class ExprMetadata(TypedDict):
    """Boolean flags describing an expression, tracked so that operations can combine them."""

    # True if the expression's result depends on row order.
    is_order_dependent: bool
    # True if the expression changes length.
    changes_length: bool
    # True if the expression aggregates; when aggregates and non-aggregates
    # are mixed in one operation, broadcasting happens.
    aggregates: bool
    # True if the expression produces multiple outputs, e.g. `nw.col('a', 'b')`.
    is_multi_output: bool


def evaluate_into_expr(
df: CompliantDataFrame | CompliantLazyFrame,
into_expr: IntoCompliantExpr[CompliantSeriesT_co],
Expand Down Expand Up @@ -338,11 +347,46 @@ def extract_compliant(
return other


def arg_aggregates(arg: IntoExpr | Any) -> bool:
    """Return whether a single operand counts as an aggregation.

    Arguments:
        arg: An expression, a Series, a column name, or a scalar.

    Returns:
        For an `Expr`, its `aggregates` metadata flag; for a `Series`, whether
        its length is 1; `False` for a string (column name); `True` for
        anything else (treated as a scalar).
    """
    from narwhals.expr import Expr
    from narwhals.series import Series

    if isinstance(arg, Expr):
        return arg._metadata["aggregates"]
    if isinstance(arg, Series):
        return arg.len() == 1
    # A string is a column name, e.g. 'a', and gets treated as `nw.col('a')`,
    # which doesn't aggregate. Everything else left at this point is a scalar.
    return not isinstance(arg, str)


def arg_is_order_dependent(arg: IntoExpr | Any) -> bool:
from narwhals.expr import Expr

if isinstance(arg, Expr):
return arg._metadata["is_order_dependent"]
if isinstance(arg, str):
# Column name, e.g. 'a', gets treated as `nw.col('a')`,
Comment on lines +371 to +372
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sigh...this isn't even always true. darn. we may need to parse column names as expressions a bit earlier

# which isn't order-dependent.
return False
# Scalar or Series
# Series are an eager-only concept, so the order-dependent
# restrictions don't apply to them anyway.
return False


def operation_is_order_dependent(*args: IntoExpr | Any) -> bool:
# If an arg is an Expr, we look at `_is_order_dependent`. If it isn't,
# it means that it was a scalar (e.g. nw.col('a') + 1) or a column name,
# neither of which is order-dependent, so we default to `False`.
return any(getattr(x, "_is_order_dependent", False) for x in args)
# If any arg is order-dependent, the whole expression is.
return any(arg_is_order_dependent(x) for x in args)


def operation_aggregates(*args: IntoExpr | Any) -> bool:
# If there's a mix of aggregates and non-aggregates, broadcasting
# will happen. The whole operation aggregates if all arguments aggregate.
return all(arg_aggregates(x) for x in args)


def operation_changes_length(*args: IntoExpr | Any) -> bool:
Expand All @@ -364,8 +408,17 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool:
"""
from narwhals.expr import Expr

n_exprs = len([x for x in args if isinstance(x, Expr)])
changes_length = any(isinstance(x, Expr) and x._changes_length for x in args)
n_exprs = 0
changes_length = False
for arg in args:
if isinstance(arg, Expr):
n_exprs += 1
if arg._metadata["changes_length"]:
changes_length = True
elif isinstance(arg, str):
n_exprs += 1
# Note: Series are an eager-only concept anyway and so the length-changing
# restrictions don't apply to them anyway.
if n_exprs > 1 and changes_length:
msg = (
"Found multiple expressions at least one of which changes length.\n"
Expand All @@ -376,10 +429,28 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool:
return changes_length


def operation_aggregates(*args: IntoExpr | Any) -> bool:
    """Return whether an operation over `args` is itself an aggregation.

    Each operand's `_aggregates` attribute is consulted. Operands without
    that attribute are scalars (e.g. the `1` in `nw.col('a').sum() + 1`),
    which are already length-1 and therefore count as aggregating.
    """
    for operand in args:
        if not getattr(operand, "_aggregates", True):
            # One non-aggregating operand means broadcasting takes place,
            # so the overall result is not an aggregate.
            return False
    return True
def operation_is_multi_output(*args: IntoExpr | Any) -> bool:
    """Return whether an operation over `args` produces multiple outputs.

    Only the first operand is allowed to be a multi-output expression, and
    the operation is multi-output exactly when that first operand is.

    Arguments:
        *args: Operands of the operation, left-hand side first. Operands
            that are not `Expr` instances are never multi-output.

    Returns:
        `True` if the first operand is a multi-output `Expr`, else `False`.
        `False` when called with no operands.

    Raises:
        MultiOutputExprError: If any operand after the first is a
            multi-output expression.
    """
    if not args:
        # Nothing to inspect. Without this guard `args[0]` would raise
        # IndexError, whereas the other `operation_*` helpers accept an
        # empty argument list.
        return False

    from narwhals.expr import Expr

    def _is_multi_output(arg: Any) -> bool:
        return isinstance(arg, Expr) and arg._metadata["is_multi_output"]

    first, *rest = args
    if any(_is_multi_output(x) for x in rest):
        msg = (
            "Multi-output expressions cannot appear in the right-hand-side of\n"
            "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n"
            "allowed, but not `nw.col('a') + nw.col('b', 'c')`."
        )
        raise MultiOutputExprError(msg)
    return _is_multi_output(first)


def combine_metadata(
    *args: IntoExpr | Any, is_multi_output: bool | None = None
) -> ExprMetadata:
    """Build the `ExprMetadata` for an expression formed by combining `args`.

    Arguments:
        *args: Operands of the operation.
        is_multi_output: If not `None`, used as-is for the `is_multi_output`
            flag. If `None`, the flag is derived from `args` with
            `operation_is_multi_output`.

    Returns:
        Metadata whose flags are computed by the matching `operation_*`
        helpers.
    """
    # The helpers are called in the same order as the metadata fields.
    order_dependent = operation_is_order_dependent(*args)
    length_changing = operation_changes_length(*args)
    aggregating = operation_aggregates(*args)
    if is_multi_output is None:
        multi_output = operation_is_multi_output(*args)
    else:
        multi_output = is_multi_output
    return ExprMetadata(
        is_order_dependent=order_dependent,
        changes_length=length_changing,
        aggregates=aggregating,
        is_multi_output=multi_output,
    )
10 changes: 0 additions & 10 deletions narwhals/_pandas_like/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,3 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
)
else:
return self._to_expr() & other

def __invert__(self: Self) -> PandasSelector:
    """Return `all()` minus this selector, i.e. this selector's complement."""
    namespace = PandasSelectorNamespace(
        implementation=self._implementation,
        backend_version=self._backend_version,
        version=self._version,
    )
    everything = namespace.all()
    return everything - self
8 changes: 5 additions & 3 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import OrderDependentExprError
from narwhals.exceptions import ShapeError
from narwhals.expr import Expr
from narwhals.schema import Schema
from narwhals.translate import to_native
from narwhals.utils import find_stacklevel
Expand Down Expand Up @@ -154,7 +155,8 @@ def filter(
) -> Self:
flat_predicates = flatten(predicates)
if any(
getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False)
isinstance(x, Expr)
and (x._metadata["aggregates"] or x._metadata["changes_length"])
for x in flat_predicates
):
msg = "Expressions which aggregate or change length cannot be passed to `filter`."
Expand Down Expand Up @@ -3667,7 +3669,7 @@ def _extract_compliant(self: Self, arg: Any) -> Any:
msg = "Binary operations between Series and LazyFrame are not supported."
raise TypeError(msg)
if isinstance(arg, Expr):
if arg._is_order_dependent:
if arg._metadata["is_order_dependent"]:
msg = (
"Order-dependent expressions are not supported for use in LazyFrame.\n\n"
"Hints:\n"
Expand All @@ -3678,7 +3680,7 @@ def _extract_compliant(self: Self, arg: Any) -> Any:
" they will be supported."
)
raise OrderDependentExprError(msg)
if arg._changes_length:
if arg._metadata["changes_length"]:
msg = (
"Length-changing expressions are not supported for use in LazyFrame, unless\n"
"followed by an aggregation.\n\n"
Expand Down
8 changes: 8 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ def __init__(self: Self, message: str) -> None:
super().__init__(self.message)


class MultiOutputExprError(ValueError):
    """Exception raised when trying to combine expressions where one has multiple outputs."""

    def __init__(self, message: str) -> None:
        # Pass the text to ValueError and also keep it on `.message`.
        super().__init__(message)
        self.message = message


class UnsupportedDTypeError(ValueError):
"""Exception raised when trying to convert to a DType which is not supported by the given backend."""

Expand Down
Loading
Loading