diff --git a/narwhals/_arrow/selectors.py b/narwhals/_arrow/selectors.py index 36feb5d56..84ab1120b 100644 --- a/narwhals/_arrow/selectors.py +++ b/narwhals/_arrow/selectors.py @@ -169,11 +169,3 @@ def call(df: ArrowDataFrame) -> list[ArrowSeries]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> ArrowSelector: - return ( - ArrowSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) diff --git a/narwhals/_dask/selectors.py b/narwhals/_dask/selectors.py index 9e6cc6302..e084f8dbe 100644 --- a/narwhals/_dask/selectors.py +++ b/narwhals/_dask/selectors.py @@ -177,11 +177,3 @@ def call(df: DaskLazyFrame) -> list[Any]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> DaskSelector: - return ( - DaskSelectorNamespace( - backend_version=self._backend_version, version=self._version - ).all() - - self - ) diff --git a/narwhals/_expression_parsing.py b/narwhals/_expression_parsing.py index a11db68eb..e8c2387a5 100644 --- a/narwhals/_expression_parsing.py +++ b/narwhals/_expression_parsing.py @@ -7,6 +7,7 @@ from typing import TYPE_CHECKING from typing import Any from typing import Sequence +from typing import TypedDict from typing import TypeVar from typing import Union from typing import cast @@ -15,6 +16,7 @@ from narwhals.dependencies import is_numpy_array from narwhals.exceptions import InvalidIntoExprError from narwhals.exceptions import LengthChangingExprError +from narwhals.exceptions import MultiOutputExprError from narwhals.utils import Implementation if TYPE_CHECKING: @@ -44,6 +46,13 @@ T = TypeVar("T") +class ExprMetadata(TypedDict): + is_order_dependent: bool + changes_length: bool + aggregates: bool + is_multi_output: bool + + def evaluate_into_expr( df: CompliantDataFrame | CompliantLazyFrame, into_expr: IntoCompliantExpr[CompliantSeriesT_co], @@ -338,11 +347,46 @@ def extract_compliant( return other +def arg_aggregates(arg: IntoExpr | Any) -> bool: + from 
narwhals.expr import Expr + from narwhals.series import Series + + if isinstance(arg, Expr): + return arg._metadata["aggregates"] + if isinstance(arg, Series): + return arg.len() == 1 + if isinstance(arg, str): # noqa: SIM103 + # Column name, e.g. 'a', gets treated as `nw.col('a')`, + # which doesn't aggregate. + return False + # Scalar + return True + + +def arg_is_order_dependent(arg: IntoExpr | Any) -> bool: + from narwhals.expr import Expr + + if isinstance(arg, Expr): + return arg._metadata["is_order_dependent"] + if isinstance(arg, str): + # Column name, e.g. 'a', gets treated as `nw.col('a')`, + # which isn't order-dependent. + return False + # Scalar or Series + # Series are an eager-only concept, so the order-dependent + # restrictions don't apply to them. + return False + + def operation_is_order_dependent(*args: IntoExpr | Any) -> bool: - # If an arg is an Expr, we look at `_is_order_dependent`. If it isn't, - # it means that it was a scalar (e.g. nw.col('a') + 1) or a column name, - # neither of which is order-dependent, so we default to `False`. - return any(getattr(x, "_is_order_dependent", False) for x in args) + # If any arg is order-dependent, the whole expression is. + return any(arg_is_order_dependent(x) for x in args) + + +def operation_aggregates(*args: IntoExpr | Any) -> bool: + # If there's a mix of aggregates and non-aggregates, broadcasting + # will happen. The whole operation aggregates if all arguments aggregate. 
+ return all(arg_aggregates(x) for x in args) def operation_changes_length(*args: IntoExpr | Any) -> bool: @@ -364,8 +408,17 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: """ from narwhals.expr import Expr - n_exprs = len([x for x in args if isinstance(x, Expr)]) - changes_length = any(isinstance(x, Expr) and x._changes_length for x in args) + n_exprs = 0 + changes_length = False + for arg in args: + if isinstance(arg, Expr): + n_exprs += 1 + if arg._metadata["changes_length"]: + changes_length = True + elif isinstance(arg, str): + n_exprs += 1 + # Note: Series are an eager-only concept anyway and so the length-changing + # restrictions don't apply to them anyway. if n_exprs > 1 and changes_length: msg = ( "Found multiple expressions at least one of which changes length.\n" @@ -376,10 +429,28 @@ def operation_changes_length(*args: IntoExpr | Any) -> bool: return changes_length -def operation_aggregates(*args: IntoExpr | Any) -> bool: - # If an arg is an Expr, we look at `_aggregates`. If it isn't, - # it means that it was a scalar (e.g. nw.col('a').sum() + 1), - # which is already length-1, so we default to `True`. If any - # expression does not aggregate, then broadcasting will take - # place and the result will not be an aggregate. - return all(getattr(x, "_aggregates", True) for x in args) +def operation_is_multi_output(*args: IntoExpr | Any) -> bool: + # Only the first expression is allowed to produce multiple outputs. + from narwhals.expr import Expr + + if any(isinstance(x, Expr) and x._metadata["is_multi_output"] for x in args[1:]): + msg = ( + "Multi-output expressions cannot appear in the right-hand-side of\n" + "any operation. For example, `nw.col('a', 'b') + nw.col('c')` is \n" + "allowed, but not `nw.col('a') + nw.col('b', 'c')`." 
+ ) + raise MultiOutputExprError(msg) + return isinstance(args[0], Expr) and args[0]._metadata["is_multi_output"] + + +def combine_metadata( + *args: IntoExpr | Any, is_multi_output: bool | None = None +) -> ExprMetadata: + return ExprMetadata( + is_order_dependent=operation_is_order_dependent(*args), + changes_length=operation_changes_length(*args), + aggregates=operation_aggregates(*args), + is_multi_output=is_multi_output + if is_multi_output is not None + else operation_is_multi_output(*args), + ) diff --git a/narwhals/_pandas_like/selectors.py b/narwhals/_pandas_like/selectors.py index b3518283f..4238a6474 100644 --- a/narwhals/_pandas_like/selectors.py +++ b/narwhals/_pandas_like/selectors.py @@ -178,13 +178,3 @@ def call(df: PandasLikeDataFrame) -> list[PandasLikeSeries]: ) else: return self._to_expr() & other - - def __invert__(self: Self) -> PandasSelector: - return ( - PandasSelectorNamespace( - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, - ).all() - - self - ) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index d9657a61a..ac3fe16fb 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -19,6 +19,7 @@ from narwhals.exceptions import LengthChangingExprError from narwhals.exceptions import OrderDependentExprError from narwhals.exceptions import ShapeError +from narwhals.expr import Expr from narwhals.schema import Schema from narwhals.translate import to_native from narwhals.utils import find_stacklevel @@ -154,7 +155,8 @@ def filter( ) -> Self: flat_predicates = flatten(predicates) if any( - getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False) + isinstance(x, Expr) + and (x._metadata["aggregates"] or x._metadata["changes_length"]) for x in flat_predicates ): msg = "Expressions which aggregate or change length cannot be passed to `filter`." 
@@ -3667,7 +3669,7 @@ def _extract_compliant(self: Self, arg: Any) -> Any: msg = "Binary operations between Series and LazyFrame are not supported." raise TypeError(msg) if isinstance(arg, Expr): - if arg._is_order_dependent: + if arg._metadata["is_order_dependent"]: msg = ( "Order-dependent expressions are not supported for use in LazyFrame.\n\n" "Hints:\n" @@ -3678,7 +3680,7 @@ def _extract_compliant(self: Self, arg: Any) -> Any: " they will be supported." ) raise OrderDependentExprError(msg) - if arg._changes_length: + if arg._metadata["changes_length"]: msg = ( "Length-changing expressions are not supported for use in LazyFrame, unless\n" "followed by an aggregation.\n\n" diff --git a/narwhals/exceptions.py b/narwhals/exceptions.py index 817cca4e1..3ba149cec 100644 --- a/narwhals/exceptions.py +++ b/narwhals/exceptions.py @@ -104,6 +104,14 @@ def __init__(self: Self, message: str) -> None: super().__init__(self.message) +class MultiOutputExprError(ValueError): + """Exception raised when trying to combine expressions where one has multiple outputs.""" + + def __init__(self, message: str) -> None: + self.message = message + super().__init__(self.message) + + class UnsupportedDTypeError(ValueError): """Exception raised when trying to convert to a DType which is not supported by the given backend.""" diff --git a/narwhals/expr.py b/narwhals/expr.py index 1ef137900..81874440a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -8,9 +8,9 @@ from typing import Mapping from typing import Sequence +from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import combine_metadata from narwhals._expression_parsing import extract_compliant -from narwhals._expression_parsing import operation_aggregates -from narwhals._expression_parsing import operation_changes_length from narwhals._expression_parsing import operation_is_order_dependent from narwhals.dtypes import _validate_dtype from narwhals.expr_cat import ExprCatNamespace @@ -40,34 
+40,26 @@ class Expr: def __init__( - self: Self, - to_compliant_expr: Callable[[Any], Any], - is_order_dependent: bool, # noqa: FBT001 - changes_length: bool, # noqa: FBT001 - aggregates: bool, # noqa: FBT001 + self: Self, to_compliant_expr: Callable[[Any], Any], metadata: ExprMetadata ) -> None: # callable from CompliantNamespace to CompliantExpr self._to_compliant_expr = to_compliant_expr - self._is_order_dependent = is_order_dependent - self._changes_length = changes_length - self._aggregates = aggregates + self._metadata = metadata def __repr__(self: Self) -> str: return ( "Narwhals Expr\n" - f"is_order_dependent: {self._is_order_dependent}\n" - f"changes_length: {self._changes_length}\n" - f"aggregates: {self._aggregates}" + f"is_order_dependent: {self._metadata['is_order_dependent']}\n" + f"changes_length: {self._metadata['changes_length']}\n" + f"aggregates: {self._metadata['aggregates']}\n" + f"is_multi_output: {self._metadata['is_multi_output']}" ) def _taxicab_norm(self: Self) -> Self: # This is just used to test out the stable api feature in a realistic-ish way. # It's not intended to be used. return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs().sum(), - self._is_order_dependent, - self._changes_length, - self._aggregates, + lambda plx: self._to_compliant_expr(plx).abs().sum(), self._metadata ) # --- convert --- @@ -124,11 +116,11 @@ def alias(self: Self, name: str) -> Self: c: [[14,15]] """ + if self._metadata["is_multi_output"]: + msg = "Cannot alias multi-output expression. 
Use `.name.suffix`, `.name.map`" + raise ValueError(msg) return self.__class__( - lambda plx: self._to_compliant_expr(plx).alias(name), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).alias(name), self._metadata ) def pipe( @@ -254,10 +246,7 @@ def cast(self: Self, dtype: DType | type[DType]) -> Self: """ _validate_dtype(dtype) return self.__class__( - lambda plx: self._to_compliant_expr(plx).cast(dtype), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).cast(dtype), self._metadata ) # --- binary --- @@ -266,9 +255,7 @@ def __eq__(self: Self, other: object) -> Self: # type: ignore[override] lambda plx: self._to_compliant_expr(plx).__eq__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __ne__(self: Self, other: object) -> Self: # type: ignore[override] @@ -276,9 +263,7 @@ def __ne__(self: Self, other: object) -> Self: # type: ignore[override] lambda plx: self._to_compliant_expr(plx).__ne__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __and__(self: Self, other: Any) -> Self: @@ -286,9 +271,7 @@ def __and__(self: Self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__and__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rand__(self: Self, other: Any) -> Self: @@ 
-297,21 +280,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __or__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__or__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __ror__(self: Self, other: Any) -> Self: @@ -320,21 +296,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __add__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__add__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __radd__(self: Self, other: Any) -> Self: @@ -343,21 +312,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __sub__(self: Self, other: Any) -> Self: return 
self.__class__( lambda plx: self._to_compliant_expr(plx).__sub__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rsub__(self: Self, other: Any) -> Self: @@ -366,21 +328,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __truediv__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__truediv__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rtruediv__(self: Self, other: Any) -> Self: @@ -389,21 +344,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __mul__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__mul__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rmul__(self: Self, other: Any) -> Self: @@ -412,21 +360,14 @@ def func(plx: CompliantNamespace[Any]) -> 
CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __le__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__le__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __lt__(self: Self, other: Any) -> Self: @@ -434,9 +375,7 @@ def __lt__(self: Self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__lt__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __gt__(self: Self, other: Any) -> Self: @@ -444,9 +383,7 @@ def __gt__(self: Self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__gt__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __ge__(self: Self, other: Any) -> Self: @@ -454,9 +391,7 @@ def __ge__(self: Self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).__ge__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __pow__(self: Self, other: Any) -> Self: @@ -464,9 +399,7 @@ def __pow__(self: Self, other: Any) -> Self: lambda plx: 
self._to_compliant_expr(plx).__pow__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rpow__(self: Self, other: Any) -> Self: @@ -475,21 +408,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __floordiv__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__floordiv__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rfloordiv__(self: Self, other: Any) -> Self: @@ -498,21 +424,14 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) def __mod__(self: Self, other: Any) -> Self: return self.__class__( lambda plx: self._to_compliant_expr(plx).__mod__( extract_compliant(plx, other) ), - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), + combine_metadata(self, other), ) def __rmod__(self: Self, other: Any) -> Self: @@ -521,20 +440,12 @@ def func(plx: CompliantNamespace[Any]) -> CompliantExpr[Any]: 
extract_compliant(plx, self) ) - return self.__class__( - func, - is_order_dependent=operation_is_order_dependent(self, other), - changes_length=operation_changes_length(self, other), - aggregates=operation_aggregates(self, other), - ) + return self.__class__(func, combine_metadata(self, other)) # --- unary --- def __invert__(self: Self) -> Self: return self.__class__( - lambda plx: self._to_compliant_expr(plx).__invert__(), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).__invert__(), self._metadata ) def any(self: Self) -> Self: @@ -588,9 +499,7 @@ def any(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).any(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def all(self: Self) -> Self: @@ -644,9 +553,7 @@ def all(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).all(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def ewm_mean( @@ -749,9 +656,7 @@ def ewm_mean( min_periods=min_periods, ignore_nulls=ignore_nulls, ), - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata, ) def mean(self: Self) -> Self: @@ -805,9 +710,7 @@ def mean(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).mean(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def median(self: Self) -> Self: @@ -864,9 +767,7 @@ def median(self: Self) -> Self: """ return self.__class__( lambda plx: 
self._to_compliant_expr(plx).median(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def std(self: Self, *, ddof: int = 1) -> Self: @@ -923,9 +824,7 @@ def std(self: Self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).std(ddof=ddof), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def var(self: Self, *, ddof: int = 1) -> Self: @@ -983,9 +882,7 @@ def var(self: Self, *, ddof: int = 1) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).var(ddof=ddof), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def map_batches( @@ -1062,9 +959,12 @@ def map_batches( function=function, return_dtype=return_dtype ), # safest assumptions - is_order_dependent=True, - changes_length=True, - aggregates=False, + ExprMetadata( + is_order_dependent=True, + changes_length=True, + aggregates=False, + is_multi_output=True, + ), ) def skew(self: Self) -> Self: @@ -1118,9 +1018,7 @@ def skew(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).skew(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def sum(self: Self) -> Expr: @@ -1172,9 +1070,7 @@ def sum(self: Self) -> Expr: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).sum(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def min(self: Self) -> Self: @@ -1228,9 +1124,7 @@ def min(self: Self) -> 
Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).min(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def max(self: Self) -> Self: @@ -1284,9 +1178,7 @@ def max(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).max(), - is_order_dependent=self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "aggregates": True, "changes_length": False}), ) def arg_min(self: Self) -> Self: @@ -1342,9 +1234,12 @@ def arg_min(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_min(), - is_order_dependent=True, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=True, + changes_length=False, + aggregates=True, + is_multi_output=self._metadata["is_multi_output"], + ), ) def arg_max(self: Self) -> Self: @@ -1400,9 +1295,12 @@ def arg_max(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_max(), - is_order_dependent=True, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=True, + changes_length=False, + aggregates=True, + is_multi_output=self._metadata["is_multi_output"], + ), ) def count(self: Self) -> Self: @@ -1456,9 +1354,7 @@ def count(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).count(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "changes_length": False, "aggregates": True}), ) def n_unique(self: Self) -> Self: @@ -1510,9 +1406,7 @@ def n_unique(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).n_unique(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "changes_length": False, "aggregates": True}), ) def 
unique(self: Self) -> Self: @@ -1566,9 +1460,7 @@ def unique(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).unique(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, "changes_length": True}), ) def abs(self: Self) -> Self: @@ -1623,10 +1515,7 @@ def abs(self: Self) -> Self: b: [[3,4]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).abs(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).abs(), self._metadata ) def cum_sum(self: Self, *, reverse: bool = False) -> Self: @@ -1689,9 +1578,7 @@ def cum_sum(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_sum(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def diff(self: Self) -> Self: @@ -1760,9 +1647,7 @@ def diff(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).diff(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def shift(self: Self, n: int) -> Self: @@ -1834,9 +1719,7 @@ def shift(self: Self, n: int) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).shift(n), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def replace_strict( @@ -1929,9 +1812,7 @@ def replace_strict( lambda plx: self._to_compliant_expr(plx).replace_strict( old, new, return_dtype=return_dtype ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata, ) 
def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -1962,9 +1843,7 @@ def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> S lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) # --- transform --- @@ -2039,11 +1918,7 @@ def is_between( extract_compliant(plx, upper_bound), closed, ), - is_order_dependent=operation_is_order_dependent( - self, lower_bound, upper_bound - ), - changes_length=self._changes_length, - aggregates=self._aggregates, + combine_metadata(self, lower_bound, upper_bound), ) def is_in(self: Self, other: Any) -> Self: @@ -2109,9 +1984,7 @@ def is_in(self: Self, other: Any) -> Self: lambda plx: self._to_compliant_expr(plx).is_in( extract_compliant(plx, other) ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + combine_metadata(self, other), ) else: msg = "Narwhals `is_in` doesn't accept expressions as an argument, as opposed to Polars. You should provide an iterable instead." 
@@ -2181,9 +2054,13 @@ def filter(self: Self, *predicates: Any) -> Self: lambda plx: self._to_compliant_expr(plx).filter( *[extract_compliant(plx, pred) for pred in flat_predicates], ), - is_order_dependent=operation_is_order_dependent(*flat_predicates), - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + { + **combine_metadata(self, *flat_predicates), + "is_order_dependent": operation_is_order_dependent(*flat_predicates), + "changes_length": True, + } + ), ) def is_null(self: Self) -> Self: @@ -2263,10 +2140,7 @@ def is_null(self: Self) -> Self: b_is_null: [[false,false,true,false,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_null(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).is_null(), self._metadata ) def is_nan(self: Self) -> Self: @@ -2333,10 +2207,7 @@ def is_nan(self: Self) -> Self: divided_is_nan: [[true,null,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_nan(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).is_nan(), self._metadata ) def arg_true(self: Self) -> Self: @@ -2353,9 +2224,9 @@ def arg_true(self: Self) -> Self: issue_deprecation_warning(msg, _version="1.23.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def fill_null( @@ -2499,9 +2370,7 @@ def fill_null( lambda plx: self._to_compliant_expr(plx).fill_null( value=value, strategy=strategy, limit=limit ), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + self._metadata, ) # --- partial reduction --- @@ -2564,9 +2433,7 @@ def drop_nulls(self: Self) -> Self: """ return 
self.__class__( lambda plx: self._to_compliant_expr(plx).drop_nulls(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, "changes_length": True}), ) def sample( @@ -2607,9 +2474,7 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, "changes_length": True}), ) def over(self: Self, *keys: str | Iterable[str]) -> Self: @@ -2700,10 +2565,7 @@ def over(self: Self, *keys: str | Iterable[str]) -> Self: └─────┴─────┴─────┘ """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).over(flatten(keys)), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).over(flatten(keys)), self._metadata ) def is_duplicated(self: Self) -> Self: @@ -2762,10 +2624,7 @@ def is_duplicated(self: Self) -> Self: b: [[true,true,false,false]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_duplicated(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).is_duplicated(), self._metadata ) def is_unique(self: Self) -> Self: @@ -2824,10 +2683,7 @@ def is_unique(self: Self) -> Self: b: [[false,false,true,true]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).is_unique(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).is_unique(), self._metadata ) def null_count(self: Self) -> Self: @@ -2885,10 +2741,7 @@ def null_count(self: Self) -> Self: b: [[2]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).null_count(), - self._is_order_dependent, - changes_length=self._changes_length, - 
aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).null_count(), self._metadata ) def is_first_distinct(self: Self) -> Self: @@ -2948,9 +2801,7 @@ def is_first_distinct(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_first_distinct(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def is_last_distinct(self: Self) -> Self: @@ -3010,9 +2861,7 @@ def is_last_distinct(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).is_last_distinct(), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def quantile( @@ -3083,9 +2932,7 @@ def quantile( """ return self.__class__( lambda plx: self._to_compliant_expr(plx).quantile(quantile, interpolation), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "changes_length": False, "aggregates": True}), ) def head(self: Self, n: int = 10) -> Self: @@ -3113,9 +2960,9 @@ def head(self: Self, n: int = 10) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def tail(self: Self, n: int = 10) -> Self: @@ -3143,9 +2990,9 @@ def tail(self: Self, n: int = 10) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def round(self: Self, decimals: int = 
0) -> Self: @@ -3212,10 +3059,7 @@ def round(self: Self, decimals: int = 0) -> Self: a: [[1.1,2.6,3.9]] """ return self.__class__( - lambda plx: self._to_compliant_expr(plx).round(decimals), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).round(decimals), self._metadata ) def len(self: Self) -> Self: @@ -3275,9 +3119,7 @@ def len(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).len(), - self._is_order_dependent, - changes_length=False, - aggregates=True, + ExprMetadata({**self._metadata, "changes_length": False, "aggregates": True}), ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -3306,9 +3148,9 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: issue_deprecation_warning(msg, _version="1.22.0") return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) # need to allow numeric typing @@ -3456,11 +3298,7 @@ def clip( extract_compliant(plx, lower_bound), extract_compliant(plx, upper_bound), ), - is_order_dependent=operation_is_order_dependent( - self, lower_bound, upper_bound - ), - changes_length=self._changes_length, - aggregates=self._aggregates, + ExprMetadata({**combine_metadata(self, lower_bound, upper_bound)}), ) def mode(self: Self) -> Self: @@ -3517,9 +3355,7 @@ def mode(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).mode(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, "changes_length": True}), ) def is_finite(self: Self) -> Self: @@ -3581,10 +3417,7 @@ def is_finite(self: Self) -> Self: a: [[false,false,true,null]] """ return self.__class__( - lambda plx: 
self._to_compliant_expr(plx).is_finite(), - self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, + lambda plx: self._to_compliant_expr(plx).is_finite(), self._metadata ) def cum_count(self: Self, *, reverse: bool = False) -> Self: @@ -3652,9 +3485,7 @@ def cum_count(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_count(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def cum_min(self: Self, *, reverse: bool = False) -> Self: @@ -3722,9 +3553,7 @@ def cum_min(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_min(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def cum_max(self: Self, *, reverse: bool = False) -> Self: @@ -3792,9 +3621,7 @@ def cum_max(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_max(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def cum_prod(self: Self, *, reverse: bool = False) -> Self: @@ -3862,9 +3689,7 @@ def cum_prod(self: Self, *, reverse: bool = False) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).cum_prod(reverse=reverse), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def rolling_sum( @@ -3959,9 +3784,7 @@ def rolling_sum( min_periods=min_periods, center=center, ), - is_order_dependent=True, - 
changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def rolling_mean( @@ -4056,9 +3879,7 @@ def rolling_mean( min_periods=min_periods, center=center, ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def rolling_var( @@ -4153,9 +3974,7 @@ def rolling_var( lambda plx: self._to_compliant_expr(plx).rolling_var( window_size=window_size, min_periods=min_periods, center=center, ddof=ddof ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def rolling_std( @@ -4253,9 +4072,7 @@ def rolling_std( center=center, ddof=ddof, ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def rank( @@ -4353,9 +4170,7 @@ def rank( lambda plx: self._to_compliant_expr(plx).rank( method=method, descending=descending ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) @property diff --git a/narwhals/expr_cat.py b/narwhals/expr_cat.py index 16dbb3929..3eb09c91d 100644 --- a/narwhals/expr_cat.py +++ b/narwhals/expr_cat.py @@ -4,6 +4,8 @@ from typing import Generic from typing import TypeVar +from narwhals._expression_parsing import ExprMetadata + if TYPE_CHECKING: from typing_extensions import Self @@ -63,7 +65,5 @@ def get_categories(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).cat.get_categories(), - self._expr._is_order_dependent, - changes_length=True, - aggregates=self._expr._aggregates, + ExprMetadata({**self._expr._metadata, "changes_length": True}), # type: 
ignore[typeddict-item] ) diff --git a/narwhals/expr_dt.py b/narwhals/expr_dt.py index 6ea1fbbdd..582017dd5 100644 --- a/narwhals/expr_dt.py +++ b/narwhals/expr_dt.py @@ -71,10 +71,7 @@ def date(self: Self) -> ExprT: a: [[2012-01-07,2023-03-10]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.date(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.date(), self._expr._metadata ) def year(self: Self) -> ExprT: @@ -142,10 +139,7 @@ def year(self: Self) -> ExprT: year: [[1978,2024,2065]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.year(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.year(), self._expr._metadata ) def month(self: Self) -> ExprT: @@ -214,9 +208,7 @@ def month(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.month(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def day(self: Self) -> ExprT: @@ -284,10 +276,7 @@ def day(self: Self) -> ExprT: day: [[1,13,1]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.day(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: self._expr._to_compliant_expr(plx).dt.day(), self._expr._metadata ) def hour(self: Self) -> ExprT: @@ -355,10 +344,7 @@ def hour(self: Self) -> ExprT: hour: [[1,5,10]] """ return self._expr.__class__( - lambda plx: self._expr._to_compliant_expr(plx).dt.hour(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + lambda plx: 
self._expr._to_compliant_expr(plx).dt.hour(), self._expr._metadata ) def minute(self: Self) -> ExprT: @@ -427,9 +413,7 @@ def minute(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.minute(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def second(self: Self) -> ExprT: @@ -496,9 +480,7 @@ def second(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.second(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def millisecond(self: Self) -> ExprT: @@ -565,9 +547,7 @@ def millisecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.millisecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def microsecond(self: Self) -> ExprT: @@ -634,9 +614,7 @@ def microsecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.microsecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def nanosecond(self: Self) -> ExprT: @@ -703,9 +681,7 @@ def nanosecond(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.nanosecond(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def ordinal_day(self: Self) -> ExprT: @@ -764,9 +740,7 @@ def ordinal_day(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.ordinal_day(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - 
aggregates=self._expr._aggregates, + self._expr._metadata, ) def weekday(self: Self) -> ExprT: @@ -823,9 +797,7 @@ def weekday(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.weekday(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_minutes(self: Self) -> ExprT: @@ -889,9 +861,7 @@ def total_minutes(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_minutes(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_seconds(self: Self) -> ExprT: @@ -955,9 +925,7 @@ def total_seconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_seconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_milliseconds(self: Self) -> ExprT: @@ -1026,9 +994,7 @@ def total_milliseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_milliseconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_microseconds(self: Self) -> ExprT: @@ -1097,9 +1063,7 @@ def total_microseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_microseconds(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def total_nanoseconds(self: Self) -> ExprT: @@ -1155,9 +1119,7 @@ def total_nanoseconds(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.total_nanoseconds(), 
- self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_string(self: Self, format: str) -> ExprT: # noqa: A002 @@ -1256,9 +1218,7 @@ def to_string(self: Self, format: str) -> ExprT: # noqa: A002 """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.to_string(format), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: @@ -1325,9 +1285,7 @@ def replace_time_zone(self: Self, time_zone: str | None) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).dt.replace_time_zone( time_zone ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def convert_time_zone(self: Self, time_zone: str) -> ExprT: @@ -1400,9 +1358,7 @@ def convert_time_zone(self: Self, time_zone: str) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).dt.convert_time_zone( time_zone ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: @@ -1476,7 +1432,5 @@ def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ExprT: raise ValueError(msg) return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).dt.timestamp(time_unit), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_list.py b/narwhals/expr_list.py index 0532db5fe..fc6a1227b 100644 --- a/narwhals/expr_list.py +++ b/narwhals/expr_list.py @@ -74,7 +74,5 @@ def len(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: 
self._expr._to_compliant_expr(plx).list.len(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_name.py b/narwhals/expr_name.py index 706f9427d..975eed7d2 100644 --- a/narwhals/expr_name.py +++ b/narwhals/expr_name.py @@ -60,9 +60,7 @@ def keep(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.keep(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def map(self: Self, function: Callable[[str], str]) -> ExprT: @@ -112,9 +110,7 @@ def map(self: Self, function: Callable[[str], str]) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.map(function), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def prefix(self: Self, prefix: str) -> ExprT: @@ -163,9 +159,7 @@ def prefix(self: Self, prefix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.prefix(prefix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def suffix(self: Self, suffix: str) -> ExprT: @@ -214,9 +208,7 @@ def suffix(self: Self, suffix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.suffix(suffix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_lowercase(self: Self) -> ExprT: @@ -262,9 +254,7 @@ def to_lowercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_lowercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - 
aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_uppercase(self: Self) -> ExprT: @@ -310,7 +300,5 @@ def to_uppercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).name.to_uppercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/expr_str.py b/narwhals/expr_str.py index 67de3b131..8bcf3a6d7 100644 --- a/narwhals/expr_str.py +++ b/narwhals/expr_str.py @@ -77,9 +77,7 @@ def len_chars(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.len_chars(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace( @@ -146,9 +144,7 @@ def replace( lambda plx: self._expr._to_compliant_expr(plx).str.replace( pattern, value, literal=literal, n=n ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def replace_all( @@ -214,9 +210,7 @@ def replace_all( lambda plx: self._expr._to_compliant_expr(plx).str.replace_all( pattern, value, literal=literal ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def strip_chars(self: Self, characters: str | None = None) -> ExprT: @@ -265,9 +259,7 @@ def strip_chars(self: Self, characters: str | None = None) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.strip_chars(characters), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def starts_with(self: Self, prefix: str) -> ExprT: @@ -330,9 +322,7 @@ def starts_with(self: Self, prefix: str) -> ExprT: """ return self._expr.__class__( lambda plx: 
self._expr._to_compliant_expr(plx).str.starts_with(prefix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def ends_with(self: Self, suffix: str) -> ExprT: @@ -395,9 +385,7 @@ def ends_with(self: Self, suffix: str) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.ends_with(suffix), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: @@ -476,9 +464,7 @@ def contains(self: Self, pattern: str, *, literal: bool = False) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.contains( pattern, literal=literal ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def slice(self: Self, offset: int, length: int | None = None) -> ExprT: @@ -581,9 +567,7 @@ def slice(self: Self, offset: int, length: int | None = None) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.slice( offset=offset, length=length ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def head(self: Self, n: int = 5) -> ExprT: @@ -651,9 +635,7 @@ def head(self: Self, n: int = 5) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.slice(0, n), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def tail(self: Self, n: int = 5) -> ExprT: @@ -723,9 +705,7 @@ def tail(self: Self, n: int = 5) -> ExprT: lambda plx: self._expr._to_compliant_expr(plx).str.slice( offset=-n, length=None ), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - 
aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 @@ -795,9 +775,7 @@ def to_datetime(self: Self, format: str | None = None) -> ExprT: # noqa: A002 """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_datetime(format=format), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_uppercase(self: Self) -> ExprT: @@ -862,9 +840,7 @@ def to_uppercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_uppercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) def to_lowercase(self: Self) -> ExprT: @@ -924,7 +900,5 @@ def to_lowercase(self: Self) -> ExprT: """ return self._expr.__class__( lambda plx: self._expr._to_compliant_expr(plx).str.to_lowercase(), - self._expr._is_order_dependent, - changes_length=self._expr._changes_length, - aggregates=self._expr._aggregates, + self._expr._metadata, ) diff --git a/narwhals/functions.py b/narwhals/functions.py index 090f4b870..9e4f28064 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -12,10 +12,9 @@ from typing import Union from typing import overload +from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import combine_metadata from narwhals._expression_parsing import extract_compliant -from narwhals._expression_parsing import operation_aggregates -from narwhals._expression_parsing import operation_changes_length -from narwhals._expression_parsing import operation_is_order_dependent from narwhals._pandas_like.utils import broadcast_align_and_extract_native from narwhals.dataframe import DataFrame from narwhals.dataframe import LazyFrame @@ -1393,11 +1392,20 @@ def col(*names: str | Iterable[str]) -> Expr: ---- 
a: [[3,8]] """ + flat_names = flatten(names) def func(plx: Any) -> Any: - return plx.col(*flatten(names)) + return plx.col(*flat_names) - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=len(flat_names) > 1, + ), + ) def nth(*indices: int | Sequence[int]) -> Expr: @@ -1455,11 +1463,20 @@ def nth(*indices: int | Sequence[int]) -> Expr: ---- a: [[2,4]] """ + flat_indices = flatten(indices) def func(plx: Any) -> Any: - return plx.nth(*flatten(indices)) + return plx.nth(*flat_indices) - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=False) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=len(flat_indices) > 1, + ), + ) # Add underscore so it doesn't conflict with builtin `all` @@ -1518,9 +1535,12 @@ def all_() -> Expr: """ return Expr( lambda plx: plx.all(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) @@ -1574,7 +1594,15 @@ def len_() -> Expr: def func(plx: Any) -> Any: return plx.len() - return Expr(func, is_order_dependent=False, changes_length=False, aggregates=True) + return Expr( + func, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=False, + ), + ) def sum(*columns: str) -> Expr: @@ -1632,9 +1660,12 @@ def sum(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).sum(), - is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1693,9 +1724,12 @@ def mean(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).mean(), 
- is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1756,9 +1790,12 @@ def median(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).median(), - is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1817,9 +1854,12 @@ def min(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).min(), - is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1878,9 +1918,12 @@ def max(*columns: str) -> Expr: """ return Expr( lambda plx: plx.col(*columns).max(), - is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=len(columns) > 1, + ), ) @@ -1947,9 +1990,7 @@ def sum_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.sum_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2019,9 +2060,7 @@ def min_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.min_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, 
is_multi_output=False), ) @@ -2091,9 +2130,7 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.max_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2104,8 +2141,9 @@ def __init__(self: Self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: msg = "At least one predicate needs to be provided to `narwhals.when`." raise TypeError(msg) if any( - getattr(x, "_aggregates", False) or getattr(x, "_changes_length", False) + x._metadata["aggregates"] or x._metadata["changes_length"] for x in self._predicates + if isinstance(x, Expr) ): msg = "Expressions which aggregate or change length cannot be passed to `filter`." raise ShapeError(msg) @@ -2118,9 +2156,7 @@ def then(self: Self, value: IntoExpr | Any) -> Then: lambda plx: plx.when(*self._extract_predicates(plx)).then( extract_compliant(plx, value) ), - is_order_dependent=operation_is_order_dependent(*self._predicates, value), - changes_length=operation_changes_length(*self._predicates, value), - aggregates=operation_aggregates(*self._predicates, value), + combine_metadata(*self._predicates, value), ) @@ -2130,9 +2166,7 @@ def otherwise(self: Self, value: IntoExpr | Any) -> Expr: lambda plx: self._to_compliant_expr(plx).otherwise( extract_compliant(plx, value) ), - is_order_dependent=operation_is_order_dependent(self, value), - changes_length=operation_changes_length(self, value), - aggregates=operation_aggregates(self, value), + combine_metadata(self, value), ) @@ -2285,9 +2319,7 @@ def all_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.all_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - 
is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2360,9 +2392,12 @@ def lit(value: Any, dtype: DType | type[DType] | None = None) -> Expr: return Expr( lambda plx: plx.lit(value, dtype), - is_order_dependent=False, - changes_length=False, - aggregates=True, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=True, + is_multi_output=False, + ), ) @@ -2440,9 +2475,7 @@ def any_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.any_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2512,9 +2545,7 @@ def mean_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: flat_exprs = flatten(exprs) return Expr( lambda plx: plx.mean_horizontal(*[extract_compliant(plx, v) for v in flat_exprs]), - is_order_dependent=operation_is_order_dependent(*flat_exprs), - changes_length=operation_changes_length(*flat_exprs), - aggregates=operation_aggregates(*flat_exprs), + combine_metadata(*exprs, is_multi_output=False), ) @@ -2604,7 +2635,5 @@ def concat_str( separator=separator, ignore_nulls=ignore_nulls, ), - is_order_dependent=operation_is_order_dependent(*flat_exprs, *more_exprs), - changes_length=operation_changes_length(*flat_exprs, *more_exprs), - aggregates=operation_aggregates(*flat_exprs, *more_exprs), + combine_metadata(*flat_exprs, *more_exprs), ) diff --git a/narwhals/group_by.py b/narwhals/group_by.py index 11ad5798d..255b82590 100644 --- a/narwhals/group_by.py +++ b/narwhals/group_by.py @@ -11,6 +11,8 @@ from narwhals.dataframe import DataFrame from narwhals.dataframe import 
LazyFrame from narwhals.exceptions import InvalidOperationError +from narwhals.expr import Expr +from narwhals.utils import flatten from narwhals.utils import tupleify if TYPE_CHECKING: @@ -112,8 +114,11 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(getattr(x, "_aggregates", True) for x in aggs) and all( - getattr(x, "_aggregates", True) for x in named_aggs.values() + flat_aggs = flatten(aggs) + if not all( + isinstance(x, Expr) and x._metadata["aggregates"] for x in flat_aggs + ) and all( + isinstance(x, Expr) and x._metadata["aggregates"] for x in named_aggs.values() ): msg = ( "Found expression which does not aggregate.\n\n" @@ -122,7 +127,7 @@ def agg( "but `df.group_by('a').agg(nw.col('b'))` is not." ) raise InvalidOperationError(msg) - aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + aggs, named_aggs = self._df._flatten_and_extract(*flat_aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), ) @@ -208,8 +213,11 @@ def agg( │ c ┆ 3 ┆ 1 │ └─────┴─────┴─────┘ """ - if not all(getattr(x, "_aggregates", True) for x in aggs) and all( - getattr(x, "_aggregates", True) for x in named_aggs.values() + flat_aggs = flatten(aggs) + if not all( + isinstance(x, Expr) and x._metadata["aggregates"] for x in flat_aggs + ) and all( + isinstance(x, Expr) and x._metadata["aggregates"] for x in named_aggs.values() ): msg = ( "Found expression which does not aggregate.\n\n" @@ -218,7 +226,7 @@ def agg( "but `df.group_by('a').agg(nw.col('b'))` is not." 
) raise InvalidOperationError(msg) - aggs, named_aggs = self._df._flatten_and_extract(*aggs, **named_aggs) + aggs, named_aggs = self._df._flatten_and_extract(*flat_aggs, **named_aggs) return self._df._from_compliant_dataframe( # type: ignore[return-value] self._grouped.agg(*aggs, **named_aggs), ) diff --git a/narwhals/selectors.py b/narwhals/selectors.py index dabf6f83f..3524bfd43 100644 --- a/narwhals/selectors.py +++ b/narwhals/selectors.py @@ -4,6 +4,8 @@ from typing import Any from typing import NoReturn +from narwhals._expression_parsing import ExprMetadata +from narwhals._expression_parsing import extract_compliant from narwhals.expr import Expr from narwhals.utils import flatten @@ -13,12 +15,7 @@ class Selector(Expr): def _to_expr(self: Self) -> Expr: - return Expr( - to_compliant_expr=self._to_compliant_expr, - is_order_dependent=self._is_order_dependent, - changes_length=self._changes_length, - aggregates=self._aggregates, - ) + return Expr(self._to_compliant_expr, self._metadata) def __add__(self: Self, other: Any) -> Expr: # type: ignore[override] if isinstance(other, Selector): @@ -35,8 +32,44 @@ def __rand__(self: Self, other: Any) -> NoReturn: def __ror__(self: Self, other: Any) -> NoReturn: raise NotImplementedError + def __and__(self, other: Selector | Any) -> Selector: + return Selector( + lambda plx: self._to_compliant_expr(plx) & extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + + def __or__(self, other: Selector | Any) -> Selector: + return Selector( + lambda plx: self._to_compliant_expr(plx) | extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + + def __sub__(self, other: Selector | Any) -> Selector: + return Selector( + lambda plx: self._to_compliant_expr(plx) - extract_compliant(plx, other), + ExprMetadata( + is_order_dependent=False, + 
changes_length=False, + aggregates=False, + is_multi_output=True, + ), + ) + + def __invert__(self: Self) -> Selector: + return all() - self + -def by_dtype(*dtypes: Any) -> Expr: +def by_dtype(*dtypes: Any) -> Selector: """Select columns based on their dtype. Arguments: @@ -81,13 +114,16 @@ def by_dtype(*dtypes: Any) -> Expr: """ return Selector( lambda plx: plx.selectors.by_dtype(flatten(dtypes)), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) -def numeric() -> Expr: +def numeric() -> Selector: """Select numeric columns. Returns: @@ -129,13 +165,16 @@ def numeric() -> Expr: """ return Selector( lambda plx: plx.selectors.numeric(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) -def boolean() -> Expr: +def boolean() -> Selector: """Select boolean columns. Returns: @@ -177,13 +216,16 @@ def boolean() -> Expr: """ return Selector( lambda plx: plx.selectors.boolean(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) -def string() -> Expr: +def string() -> Selector: """Select string columns. Returns: @@ -225,13 +267,16 @@ def string() -> Expr: """ return Selector( lambda plx: plx.selectors.string(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) -def categorical() -> Expr: +def categorical() -> Selector: """Select categorical columns. 
Returns: @@ -273,13 +318,16 @@ def categorical() -> Expr: """ return Selector( lambda plx: plx.selectors.categorical(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) -def all() -> Expr: +def all() -> Selector: """Select all columns. Returns: @@ -321,9 +369,12 @@ def all() -> Expr: """ return Selector( lambda plx: plx.selectors.all(), - is_order_dependent=False, - changes_length=False, - aggregates=False, + ExprMetadata( + is_order_dependent=False, + changes_length=False, + aggregates=False, + is_multi_output=True, + ), ) diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index df443230c..b952bb0cd 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -15,6 +15,7 @@ from narwhals import dependencies from narwhals import exceptions from narwhals import selectors +from narwhals._expression_parsing import ExprMetadata from narwhals.dataframe import DataFrame as NwDataFrame from narwhals.dataframe import LazyFrame as NwLazyFrame from narwhals.dependencies import get_polars @@ -887,9 +888,9 @@ def head(self: Self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).head(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def tail(self: Self, n: int = 10) -> Self: @@ -903,9 +904,9 @@ def tail(self: Self, n: int = 10) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).tail(n), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def gather_every(self: Self, n: int, offset: int = 0) -> Self: @@ -920,9 +921,9 @@ def gather_every(self: Self, n: int, offset: int = 0) 
-> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).gather_every(n=n, offset=offset), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def unique(self: Self, *, maintain_order: bool | None = None) -> Self: @@ -944,9 +945,7 @@ def unique(self: Self, *, maintain_order: bool | None = None) -> Self: warn(message=msg, category=UserWarning, stacklevel=find_stacklevel()) return self.__class__( lambda plx: self._to_compliant_expr(plx).unique(), - self._is_order_dependent, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata({**self._metadata, "changes_length": True}), ) def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> Self: @@ -963,9 +962,7 @@ def sort(self: Self, *, descending: bool = False, nulls_last: bool = False) -> S lambda plx: self._to_compliant_expr(plx).sort( descending=descending, nulls_last=nulls_last ), - is_order_dependent=True, - changes_length=self._changes_length, - aggregates=self._aggregates, + metadata=ExprMetadata({**self._metadata, "is_order_dependent": True}), ) def arg_true(self: Self) -> Self: @@ -976,9 +973,9 @@ def arg_true(self: Self) -> Self: """ return self.__class__( lambda plx: self._to_compliant_expr(plx).arg_true(), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) def sample( @@ -1012,9 +1009,9 @@ def sample( lambda plx: self._to_compliant_expr(plx).sample( n, fraction=fraction, with_replacement=with_replacement, seed=seed ), - is_order_dependent=True, - changes_length=True, - aggregates=self._aggregates, + ExprMetadata( + {**self._metadata, "changes_length": True, "is_order_dependent": True} + ), ) @@ -1059,12 +1056,7 @@ def _stableify( level=obj._level, ) if isinstance(obj, NwExpr): - return Expr( - 
obj._to_compliant_expr, - is_order_dependent=obj._is_order_dependent, - changes_length=obj._changes_length, - aggregates=obj._aggregates, - ) + return Expr(obj._to_compliant_expr, obj._metadata) return obj @@ -2022,12 +2014,7 @@ def then(self: Self, value: Any) -> Then: class Then(NwThen, Expr): @classmethod def from_then(cls: type, then: NwThen) -> Then: - return cls( # type: ignore[no-any-return] - then._to_compliant_expr, - is_order_dependent=then._is_order_dependent, - changes_length=then._changes_length, - aggregates=then._aggregates, - ) + return cls(then._to_compliant_expr, then._metadata) # type: ignore[no-any-return] def otherwise(self: Self, value: Any) -> Expr: return _stableify(super().otherwise(value)) diff --git a/tests/expr_and_series/double_selected_test.py b/tests/expr_and_series/double_selected_test.py index a99c90163..862c1e0d3 100644 --- a/tests/expr_and_series/double_selected_test.py +++ b/tests/expr_and_series/double_selected_test.py @@ -1,6 +1,9 @@ from __future__ import annotations +import pytest + import narwhals.stable.v1 as nw +from narwhals.exceptions import MultiOutputExprError from tests.utils import Constructor from tests.utils import assert_equal_data @@ -17,6 +20,5 @@ def test_double_selected(constructor: Constructor) -> None: expected = {"z": [7, 8, 9], "a": [2, 6, 4], "b": [8, 8, 12]} assert_equal_data(result, expected) - result = df.select("a").select(nw.col("a") + nw.all()) - expected = {"a": [2, 6, 4]} - assert_equal_data(result, expected) + with pytest.raises(MultiOutputExprError): + df.select("a").select(nw.col("a") + nw.all()) diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index e0ebf97a9..743bf36bf 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -6,6 +6,7 @@ import polars as pl import narwhals as nw +from narwhals._expression_parsing import ExprMetadata from narwhals.utils import remove_prefix from narwhals.utils import remove_suffix @@ -44,6 +45,12 @@ "OrderedDict", 
"Mapping", } +PLACEHOLDER_EXPR_METADATA = ExprMetadata( + is_order_dependent=False, + aggregates=False, + changes_length=False, + is_multi_output=False, +) files = {remove_suffix(i, ".py") for i in os.listdir("narwhals")} @@ -161,9 +168,7 @@ # Expr methods expr_methods = [ i - for i in nw.Expr( - lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False - ).__dir__() + for i in nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA).__dir__() if not i[0].isupper() and i[0] != "_" ] with open("docs/api-reference/expr.md") as fd: @@ -187,12 +192,7 @@ expr_methods = [ i for i in getattr( - nw.Expr( - lambda: 0, - is_order_dependent=False, - changes_length=False, - aggregates=False, - ), + nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_" @@ -236,9 +236,7 @@ # Check Expr vs Series expr = [ i - for i in nw.Expr( - lambda: 0, is_order_dependent=False, changes_length=False, aggregates=False - ).__dir__() + for i in nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA).__dir__() if not i[0].isupper() and i[0] != "_" ] series = [ @@ -260,12 +258,7 @@ expr_internal = [ i for i in getattr( - nw.Expr( - lambda: 0, - is_order_dependent=False, - changes_length=False, - aggregates=False, - ), + nw.Expr(lambda: 0, PLACEHOLDER_EXPR_METADATA), namespace, ).__dir__() if not i[0].isupper() and i[0] != "_"