Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Series.hist #1859

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api-reference/series.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
- filter
- gather_every
- head
- hist
- implementation
- is_between
- is_duplicated
Expand Down
112 changes: 112 additions & 0 deletions narwhals/_arrow/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import pad_series
from narwhals.exceptions import InvalidOperationError
from narwhals.typing import CompliantSeries
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
Expand Down Expand Up @@ -1004,6 +1005,117 @@ def rank(
result = pc.if_else(null_mask, pa.scalar(None), rank)
return self._from_native_series(result)

def hist(
self: Self,
bins: list[float | int] | None,
*,
bin_count: int | None,
include_category: bool,
include_breakpoint: bool,
) -> ArrowDataFrame:
if self._backend_version < (13,):
msg = f"`Series.hist` requires PyArrow>=13.0.0, found PyArrow version: {self._backend_version}"
raise NotImplementedError(msg)
import numpy as np # ignore-banned-import

from narwhals._arrow.dataframe import ArrowDataFrame

def _hist_from_bin_count(
bin_count: int,
) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
d = pc.min_max(self._native_series)
lower, upper = d["min"], d["max"]
if lower == upper:
lower -= 0.001 * abs(lower) if lower != 0 else 0.001
upper += 0.001 * abs(upper) if upper != 0 else 0.001

range_ = pc.subtract(upper, lower)
width = pc.divide(range_.cast("float"), float(bin_count))
bin_proportions = pc.divide(pc.subtract(self._native_series, lower), width)
bin_indices = pc.floor(bin_proportions)

bin_indices = pc.if_else( # shift bins so they are right-closed
pc.and_(
pc.equal(bin_indices, bin_proportions),
pc.greater(bin_indices, 0),
),
pc.subtract(bin_indices, 1),
bin_indices,
)
counts = ( # count bin id occurrences
pa.Table.from_arrays(
pc.value_counts(bin_indices)
.cast(pa.struct({"values": pa.int64(), "counts": pa.int64()}))
.flatten(),
names=["values", "counts"],
)
.join( # align bin ids to all possible bin ids (populate in missing bins)
pa.Table.from_arrays(
[np.arange(bin_count, dtype="int64")], ["values"]
),
keys="values",
join_type="right outer",
)
.sort_by("values")
)
counts = counts.set_column( # empty bin intervals should have a 0 count
0, "counts", pc.coalesce(counts.column("counts"), 0)
)

# extract left/right side of the intervals
bin_left = pc.multiply(counts.column("values"), width)
bin_right = pc.add(bin_left, width)
bin_left = pa.chunked_array(
[ # pad lowest bin by 1% of range
[pc.subtract(bin_left[0], pc.multiply(range_.cast("float"), 0.001))],
bin_left[1:], # pyarrow==0.11.0 needs to infer
]
)
counts = counts.column("counts")
return counts, bin_left, bin_right

def _hist_from_bins(
bins: Sequence[int | float],
) -> tuple[Sequence[int], Sequence[int | float], Sequence[int | float]]:
bin_indices = np.searchsorted(bins, self._native_series, side="left")
obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
obj_cats = np.arange(1, len(bins))
counts = np.zeros_like(obj_cats)
counts[np.isin(obj_cats, obs_cats)] = obs_counts[np.isin(obs_cats, obj_cats)]

bin_right = bins[1:]
bin_left = bins[:-1]
return counts, bin_left, bin_right

if bins is not None:
counts, bin_left, bin_right = _hist_from_bins(bins)

elif bin_count is not None:
if bin_count == 0:
counts, bin_left, bin_right = [], [], []
else:
counts, bin_left, bin_right = _hist_from_bin_count(bin_count)

else: # pragma: no cover
# caller guarantees that either bins or bin_count is specified
msg = "must provide one of `bin_count` or `bins`"
raise InvalidOperationError(msg)

data: dict[str, Sequence[int | float | str]] = {}
if include_breakpoint:
data["breakpoint"] = bin_right
if include_category:
data["category"] = [
f"({left}, {right}]" for left, right in zip(bin_left, bin_right)
]
data["count"] = counts

return ArrowDataFrame(
pa.Table.from_pydict(data),
backend_version=self._backend_version,
version=self._version,
)

def __iter__(self: Self) -> Iterator[Any]:
yield from (
maybe_extract_py_scalar(x, return_py_scalar=True)
Expand Down
53 changes: 53 additions & 0 deletions narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,59 @@ def rank(

return self._from_native_series(ranked_series)

def hist(
self: Self,
bins: list[float | int] | None,
*,
bin_count: int | None,
include_category: bool,
include_breakpoint: bool,
) -> PandasLikeDataFrame:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame

ns = self.__native_namespace__()
data: dict[str, Sequence[int | float | str]]

if bin_count == 0:
data = {}
if include_breakpoint:
data["breakpoint"] = []
if include_category:
data["category"] = []
data["count"] = []

return PandasLikeDataFrame(
ns.DataFrame(data),
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
)

# pandas (2.2.*) .value_counts(bins=int) adjusts the lowest bin is resulting in improper counts.
# pandas (2.2.*) .value_counts(bins=[...]) adjusts the lowest bin which should not happen since
# the bins were explicitly passed in.
categories = ns.cut(
self._native_series, bins=bins if bin_count is None else bin_count
)
# modin (0.32.0) .value_counts(...) silently drops bins with empty observations, .reindex
# is necessary to restore these bins.
result = categories.value_counts(dropna=True, sort=False).reindex(
categories.cat.categories, fill_value=0
)
data = {}
if include_breakpoint:
data["breakpoint"] = result.index.right
if include_category:
data["category"] = ns.Categorical(result.index.astype(str))
data["count"] = result.reset_index(drop=True)

return PandasLikeDataFrame(
ns.DataFrame(data),
implementation=self._implementation,
backend_version=self._backend_version,
version=self._version,
)

@property
def str(self: Self) -> PandasLikeSeriesStringNamespace:
return PandasLikeSeriesStringNamespace(self)
Expand Down
66 changes: 66 additions & 0 deletions narwhals/_polars/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
from typing import Union
from typing import cast
from typing import overload

import polars as pl
Expand Down Expand Up @@ -476,6 +478,70 @@ def __contains__(self: Self, other: Any) -> bool:
msg = f"Unable to compare other of type {type(other)} with series of type {self.dtype}."
raise InvalidOperationError(msg) from exc

def hist(
self: Self,
bins: list[float | int] | None,
*,
bin_count: int | None,
include_category: bool,
include_breakpoint: bool,
) -> PolarsDataFrame:
from narwhals._polars.dataframe import PolarsDataFrame

# polars<1.15 returned bins -inf to inf OR fails in these conditions
if (self._backend_version < (1, 15)) and (
(bins is not None and len(bins) == 0) or (bin_count == 0)
): # pragma: no cover
data: list[pl.Series] = []
if include_breakpoint:
data.append(pl.Series("breakpoint", [], dtype=pl.Float64))
if include_category:
data.append(pl.Series("category", [], dtype=pl.Categorical))
data.append(pl.Series("count", [], dtype=pl.UInt32))
return PolarsDataFrame(
pl.DataFrame(data),
backend_version=self._backend_version,
version=self._version,
)

# polars <1.5 with bin_count=...
# returns bins that range from -inf to +inf and has bin_count + 1 bins.
# for compat: convert `bin_count=` call to `bins=`
if (self._backend_version < (1, 5)) and (
bin_count is not None
): # pragma: no cover
lower = cast(Union[int, float], self._native_series.min())
upper = cast(Union[int, float], self._native_series.max())
if lower == upper:
lower -= 0.001 * abs(lower) if lower != 0 else 0.001
upper += 0.001 * abs(upper) if upper != 0 else 0.001
width = (upper - lower) / bin_count

bins = (pl.int_range(0, bin_count + 1, eager=True) * width).to_list()
bins[0] -= (upper - lower) * 0.001
bin_count = None

df = self._native_series.hist(
bins=bins,
bin_count=bin_count,
include_category=include_category,
include_breakpoint=include_breakpoint,
)
if not include_category and not include_breakpoint:
df.columns = ["count"]

# polars<1.15 implicitly adds -inf and inf to either end of bins
if self._backend_version < (1, 15) and bins is not None: # pragma: no cover
r = pl.int_range(0, len(df))
df = df.filter((r > 0) & (r < len(df) - 1))

if self._backend_version < (1, 0) and include_breakpoint:
df = df.rename({"break_point": "breakpoint"})

return PolarsDataFrame(
df, backend_version=self._backend_version, version=self._version
)

def to_polars(self: Self) -> pl.Series:
return self._native_series

Expand Down
4 changes: 4 additions & 0 deletions narwhals/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def from_missing_and_available_column_names(
return ColumnNotFoundError(message)


class ComputeError(Exception):
"""Exception raised when the underlying computation could not be evaluated."""


class ShapeError(Exception):
"""Exception raised when trying to perform operations on data structures with incompatible shapes."""

Expand Down
46 changes: 46 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from narwhals.dependencies import is_numpy_scalar
from narwhals.dtypes import _validate_dtype
from narwhals.exceptions import ComputeError
from narwhals.series_cat import SeriesCatNamespace
from narwhals.series_dt import SeriesDateTimeNamespace
from narwhals.series_list import SeriesListNamespace
Expand Down Expand Up @@ -4900,6 +4901,51 @@ def rank(
self._compliant_series.rank(method=method, descending=descending)
)

def hist(
self: Self,
bins: list[float | int] | None = None,
*,
bin_count: int | None = None,
include_category: bool = True,
include_breakpoint: bool = True,
) -> DataFrame[Any]:
"""Bin values into buckets and count their occurrences.

!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.

Arguments:
bins: A monotonically increasing sequence of values.
bin_count: If no bins provided, this will be used to determine the distance of the bins.
include_category: Include a column that indicates the upper value of each bin.
include_breakpoint: Include a column that shows the intervals as categories.

Returns:
A new DataFrame containing the counts of values that occur within each passed bin.
"""
if bins is not None and bin_count is not None:
msg = "can only provide one of `bin_count` or `bins`"
raise ComputeError(msg)
if bins is None and bin_count is None:
bin_count = 10 # polars (v1.20) sets bin=10 if neither are provided.

if bins is not None:
for i in range(1, len(bins)):
if bins[i - 1] >= bins[i]:
msg = "bins must increase monotonically"
raise ComputeError(msg)

return self._dataframe(
self._compliant_series.hist(
bins=bins,
bin_count=bin_count,
include_category=include_category,
include_breakpoint=include_breakpoint,
),
level=self._level,
)

@property
def str(self: Self) -> SeriesStringNamespace[Self]:
return SeriesStringNamespace(self)
Expand Down
38 changes: 38 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,6 +629,44 @@ def rolling_std(
ddof=ddof,
)

def hist(
self: Self,
bins: list[float | int] | None = None,
*,
bin_count: int | None = None,
include_category: bool = True,
include_breakpoint: bool = True,
) -> DataFrame[Any]:
"""Bin values into buckets and count their occurrences.

!!! warning
This functionality is considered **unstable**. It may be changed at any point
without it being considered a breaking change.

Arguments:
bins: A monotonically increasing sequence of values.
bin_count: If no bins provided, this will be used to determine the distance of the bins.
include_category: Include a column that indicates the upper value of each bin.
include_breakpoint: Include a column that shows the intervals as categories.

Returns:
A new DataFrame containing the counts of values that occur within each passed bin.
"""
from narwhals.exceptions import NarwhalsUnstableWarning
from narwhals.utils import find_stacklevel

msg = (
"`Series.hist` is being called from the stable API although considered "
"an unstable feature."
)
warn(message=msg, category=NarwhalsUnstableWarning, stacklevel=find_stacklevel())
return super().hist( # type: ignore[return-value]
bins=bins,
bin_count=bin_count,
include_category=include_category,
include_breakpoint=include_breakpoint,
)


class Expr(NwExpr):
def _l1_norm(self: Self) -> Self:
Expand Down
Loading
Loading