From 70ee34045798f3ec42ff5903de2e1ed7649f4265 Mon Sep 17 00:00:00 2001 From: Loic Diridollou Date: Thu, 21 Nov 2024 18:14:34 -0500 Subject: [PATCH] GH1045 Split overload of groupby on as_index for all cases (#1046) * GH1045 Split overload of groupby on as_index for all cases * GH1045 PR Feedback --- pandas-stubs/core/frame.pyi | 100 +++++++++++++++++++++++++++++++----- tests/test_frame.py | 41 +++++++++++++-- 2 files changed, 123 insertions(+), 18 deletions(-) diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 295aad8b..b574fb21 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -1112,7 +1112,7 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Timestamp, Literal[True]]: ... @overload - def groupby( + def groupby( # pyright: ignore reportOverlappingOverload self, by: DatetimeIndex, axis: AxisIndex | NoDefault = ..., @@ -1124,77 +1124,149 @@ class DataFrame(NDFrame, OpsMixin): dropna: _bool = ..., ) -> DataFrameGroupBy[Timestamp, Literal[False]]: ... @overload - def groupby( + def groupby( # pyright: ignore reportOverlappingOverload self, by: TimedeltaIndex, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Timedelta, bool]: ... + ) -> DataFrameGroupBy[Timedelta, Literal[True]]: ... @overload def groupby( + self, + by: TimedeltaIndex, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Timedelta, Literal[False]]: ... + @overload + def groupby( # pyright: ignore reportOverlappingOverload self, by: PeriodIndex, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Period, bool]: ... + ) -> DataFrameGroupBy[Period, Literal[True]]: ... @overload def groupby( + self, + by: PeriodIndex, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[Period, Literal[False]]: ... + @overload + def groupby( # pyright: ignore reportOverlappingOverload self, by: IntervalIndex[IntervalT], axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[IntervalT, bool]: ... + ) -> DataFrameGroupBy[IntervalT, Literal[True]]: ... @overload def groupby( + self, + by: IntervalIndex[IntervalT], + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[IntervalT, Literal[False]]: ... + @overload + def groupby( # type: ignore[overload-overlap] # pyright: ignore reportOverlappingOverload self, by: MultiIndex | GroupByObjectNonScalar | None = ..., axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[tuple, bool]: ... + ) -> DataFrameGroupBy[tuple, Literal[True]]: ... + @overload + def groupby( # type: ignore[overload-overlap] + self, + by: MultiIndex | GroupByObjectNonScalar | None = ..., + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[tuple, Literal[False]]: ... + @overload + def groupby( # pyright: ignore reportOverlappingOverload + self, + by: Series[SeriesByT], + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[True] = True, + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[SeriesByT, Literal[True]]: ... @overload def groupby( self, by: Series[SeriesByT], axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[False] = ..., + sort: _bool = ..., + group_keys: _bool = ..., + observed: _bool | NoDefault = ..., + dropna: _bool = ..., + ) -> DataFrameGroupBy[SeriesByT, Literal[False]]: ... + @overload + def groupby( + self, + by: CategoricalIndex | Index | Series, + axis: AxisIndex | NoDefault = ..., + level: IndexLabel | None = ..., + as_index: Literal[True] = True, sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[SeriesByT, bool]: ... + ) -> DataFrameGroupBy[Any, Literal[True]]: ... @overload def groupby( self, by: CategoricalIndex | Index | Series, axis: AxisIndex | NoDefault = ..., level: IndexLabel | None = ..., - as_index: _bool = ..., + as_index: Literal[False] = ..., sort: _bool = ..., group_keys: _bool = ..., observed: _bool | NoDefault = ..., dropna: _bool = ..., - ) -> DataFrameGroupBy[Any, bool]: ... + ) -> DataFrameGroupBy[Any, Literal[False]]: ... def pivot( self, *, diff --git a/tests/test_frame.py b/tests/test_frame.py index f67294d5..1d7e0468 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -504,8 +504,8 @@ def test_types_mean() -> None: s2: pd.Series = df.mean(axis=0) df2: pd.DataFrame = df.groupby(level=0).mean() if TYPE_CHECKING_INVALID_USAGE: - df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType] - df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + df3: pd.DataFrame = df.groupby(axis=1, level=0).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] + df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).mean() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] s3: pd.Series = df.mean(axis=1, skipna=True, numeric_only=False) @@ -515,8 +515,8 @@ def test_types_median() -> None: s2: pd.Series = df.median(axis=0) df2: pd.DataFrame = df.groupby(level=0).median() if TYPE_CHECKING_INVALID_USAGE: - df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType] - df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType] + df3: pd.DataFrame = df.groupby(axis=1, level=0).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] + df4: pd.DataFrame = df.groupby(axis=1, level=0, dropna=True).median() # type: ignore[call-overload] # pyright: ignore[reportArgumentType, reportCallIssue] s3: pd.Series = df.median(axis=1, skipna=True, numeric_only=False) @@ -1064,6 +1064,39 @@ def test_types_groupby_as_index() -> None: ), pd.Series, ) + check( + assert_type( + df.groupby("a").size(), + "pd.Series[int]", + ), + pd.Series, + ) + + +def test_types_groupby_as_index_list() -> None: + """Test type of groupby.size method depending on list of grouper GH1045.""" + df = pd.DataFrame({"a": [1, 1, 2], "b": [2, 3, 2]}) + check( + assert_type( + df.groupby(["a", "b"], as_index=False).size(), + pd.DataFrame, + ), + pd.DataFrame, + ) + check( + assert_type( + df.groupby(["a", "b"], as_index=True).size(), + "pd.Series[int]", + ), + pd.Series, + ) + check( + assert_type( + df.groupby(["a", "b"]).size(), + "pd.Series[int]", + ), + pd.Series, + ) def test_types_groupby_as_index_value_counts() -> None: