From 82e85e82f13dcd03ea96a96b0d0a1be702b9f710 Mon Sep 17 00:00:00 2001
From: Alexander Beedie
Date: Thu, 3 Nov 2022 02:00:31 +0400
Subject: [PATCH] feat(python): support series init from generators

Allow `pl.Series` to be constructed from generators and other (non-mapping)
iterables. Values are consumed in slices of `chunk_size` items by the new
`iterable_to_pyseries` helper; an empty iterable yields an empty Series
instead of failing on an unset result.

`pl.arange` also gains an optional `dtype` parameter, which Series init from
a `range` now passes through.

---
 py-polars/polars/internals/construction.py   | 65 ++++++++++++++++++-
 py-polars/polars/internals/lazy_functions.py | 23 +++--
 py-polars/polars/internals/series/series.py  | 33 +++++++++-
 py-polars/tests/unit/test_series.py          | 48 +++++++++++--
 4 files changed, 152 insertions(+), 17 deletions(-)

diff --git a/py-polars/polars/internals/construction.py b/py-polars/polars/internals/construction.py
index 12caf86df267..40a9791b66e0 100644
--- a/py-polars/polars/internals/construction.py
+++ b/py-polars/polars/internals/construction.py
@@ -3,9 +3,17 @@
 from contextlib import suppress
 from dataclasses import astuple, is_dataclass
 from datetime import date, datetime, time, timedelta
-from itertools import zip_longest
+from itertools import islice, zip_longest
 from sys import version_info
-from typing import TYPE_CHECKING, Any, Iterable, Mapping, Sequence, get_type_hints
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Generator,
+    Iterable,
+    Mapping,
+    Sequence,
+    get_type_hints,
+)
 
 from polars import internals as pli
 from polars.datatypes import (
@@ -178,6 +186,59 @@ def sequence_from_anyvalue_or_object(name: str, values: Sequence[Any]) -> PySeri
     return PySeries.new_object(name, values, False)
 
 
+def iterable_to_pyseries(
+    name: str,
+    values: Iterable[Any],
+    dtype: PolarsDataType | None = None,
+    strict: bool = True,
+    dtype_if_empty: PolarsDataType | None = None,
+    chunk_size: int = 1_000_000,
+) -> PySeries:
+    """
+    Construct a PySeries from an iterable/generator.
+
+    Values are consumed in slices of at most ``chunk_size`` items; each slice is
+    built into a Series and appended to the first one. The dtype of the first
+    slice is reused for every later slice. An empty iterable gives an empty
+    Series rather than raising.
+
+    """
+    if not isinstance(values, Generator):
+        values = iter(values)
+
+    def to_series_chunk(values: list[Any], dtype: PolarsDataType | None) -> pli.Series:
+        return pli.Series(
+            name=name,
+            values=values,
+            dtype=dtype,
+            strict=strict,
+            dtype_if_empty=dtype_if_empty,
+        )
+
+    n_chunks = 0
+    series: pli.Series = None  # type: ignore[assignment]
+    while True:
+        slice_values = list(islice(values, chunk_size))
+        if not slice_values:
+            break
+        schunk = to_series_chunk(slice_values, dtype)
+        if series is None:
+            series = schunk
+            dtype = series.dtype
+        else:
+            series.append(schunk, append_chunks=True)
+            n_chunks += 1
+
+    if series is None:
+        # the iterable yielded nothing, so no chunk was ever built; construct an
+        # empty Series (honouring dtype/dtype_if_empty) instead of failing on None
+        series = to_series_chunk([], dtype)
+    elif n_chunks > 0:
+        series.rechunk(in_place=True)
+
+    return series._s
+
+
 def sequence_to_pyseries(
     name: str,
     values: Sequence[Any],
diff --git a/py-polars/polars/internals/lazy_functions.py b/py-polars/polars/internals/lazy_functions.py
index ed62b9c80e78..c7430b7d50be 100644
--- a/py-polars/polars/internals/lazy_functions.py
+++ b/py-polars/polars/internals/lazy_functions.py
@@ -1340,6 +1340,7 @@ def arange(
     step: int = ...,
     *,
     eager: Literal[True],
+    dtype: PolarsDataType | None = ...,
 ) -> pli.Series:
     ...
 
@@ -1351,6 +1352,7 @@ def arange(
     step: int = ...,
     *,
     eager: bool = False,
+    dtype: PolarsDataType | None = ...,
 ) -> pli.Expr | pli.Series:
     ...
 
@@ -1361,13 +1363,13 @@ def arange(
     step: int = 1,
     *,
     eager: bool = False,
+    dtype: PolarsDataType | None = None,
 ) -> pli.Expr | pli.Series:
     """
-    Create a range expression.
+    Create a range expression (or Series).
 
-    This can be used in a `select`, `with_column` etc.
-
-    Be sure that the range size is equal to the DataFrame you are collecting.
+    This can be used in a `select`, `with_column` etc. Be sure that the resulting
+    range size is equal to the length of the DataFrame you are collecting.
 
     Examples
     --------
@@ -1383,18 +1385,25 @@ def arange(
         Step size of the range.
     eager
         If eager evaluation is `True`, a Series is returned instead of an Expr.
+    dtype
+        Apply an explicit integer dtype to the resulting expression (default is Int64).
 
     """
     low = pli.expr_to_lit_or_expr(low, str_to_lit=False)
     high = pli.expr_to_lit_or_expr(high, str_to_lit=False)
-    if eager:
+    range_expr = pli.wrap_expr(pyarange(low._pyexpr, high._pyexpr, step))
+
+    if dtype is not None and dtype != Int64:
+        range_expr = range_expr.cast(dtype)
+    if not eager:
+        return range_expr
+    else:
         return (
             pli.DataFrame()
-            .select(arange(low, high, step))
+            .select(range_expr)
             .to_series()
             .rename("arange", in_place=True)
         )
-    return pli.wrap_expr(pyarange(low._pyexpr, high._pyexpr, step))
 
 
 def argsort_by(
diff --git a/py-polars/polars/internals/series/series.py b/py-polars/polars/internals/series/series.py
index d58e84ffca4b..9d747f304d90 100644
--- a/py-polars/polars/internals/series/series.py
+++ b/py-polars/polars/internals/series/series.py
@@ -3,7 +3,18 @@
 import math
 import warnings
 from datetime import date, datetime, time, timedelta
-from typing import TYPE_CHECKING, Any, Callable, NoReturn, Sequence, Union, overload
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Generator,
+    Iterable,
+    Mapping,
+    NoReturn,
+    Sequence,
+    Union,
+    overload,
+)
 from warnings import warn
 
 from polars import internals as pli
@@ -46,6 +57,7 @@
 from polars.dependencies import pyarrow as pa
 from polars.internals.construction import (
     arrow_to_pyseries,
+    iterable_to_pyseries,
     numpy_to_pyseries,
     pandas_to_pyseries,
     sequence_to_pyseries,
@@ -228,7 +240,13 @@ def __init__(
 
         elif isinstance(values, range):
             self._s = (
-                pli.arange(values.start, values.stop, values.step, eager=True)
+                pli.arange(
+                    low=values.start,
+                    high=values.stop,
+                    step=values.step,
+                    eager=True,
+                    dtype=dtype,
+                )
                 .rename(name, in_place=True)
                 ._s
             )
@@ -238,6 +256,17 @@ def __init__(
             )
         elif _PANDAS_TYPE(values) and isinstance(values, (pd.Series, pd.DatetimeIndex)):
             self._s = pandas_to_pyseries(name, values)
+
+        elif isinstance(values, (Generator, Iterable)) and not isinstance(
+            values, Mapping
+        ):
+            self._s = iterable_to_pyseries(
+                name,
+                values,
+                dtype=dtype,
+                strict=strict,
+                dtype_if_empty=dtype_if_empty,
+            )
         else:
             raise ValueError(f"Series constructor not called properly. Got {values}.")
 
diff --git a/py-polars/tests/unit/test_series.py b/py-polars/tests/unit/test_series.py
index b4b3eabe61c9..769c35cb0ee8 100644
--- a/py-polars/tests/unit/test_series.py
+++ b/py-polars/tests/unit/test_series.py
@@ -2,7 +2,7 @@
 
 import math
 from datetime import date, datetime, time, timedelta
-from typing import TYPE_CHECKING, Any, cast
+from typing import TYPE_CHECKING, Any, Iterator, cast
 
 import numpy as np
 import pandas as pd
@@ -22,6 +22,7 @@
     UInt32,
     UInt64,
 )
+from polars.internals.construction import iterable_to_pyseries
 from polars.internals.type_aliases import EpochTimeUnit
 from polars.testing import assert_frame_equal, assert_series_equal
 from polars.testing._private import verify_series_and_expr_api
@@ -107,11 +108,6 @@ def test_init_inputs(monkeypatch: Any) -> None:
     with pytest.raises(OverflowError):
         pl.Series("bigint", [2**64])
 
-    # numpy not available
-    monkeypatch.setattr(pl.internals.series.series, "_NUMPY_TYPE", lambda x: False)
-    with pytest.raises(ValueError):
-        pl.DataFrame(np.array([1, 2, 3]), columns=["a"])
-
 
 def test_init_dataclass_namedtuple() -> None:
     from dataclasses import dataclass
@@ -1379,6 +1375,46 @@ def test_to_numpy(monkeypatch: Any) -> None:
                 assert np_array_with_missing_values.flags.writeable == writable
 
 
+def test_from_generator_or_iterable() -> None:
+    # iterable object
+    class Data:
+        def __init__(self, n: int):
+            self._n = n
+
+        def __iter__(self) -> Iterator[int]:
+            yield from range(self._n)
+
+    # generator function
+    def gen(n: int) -> Iterator[int]:
+        yield from range(n)
+
+    expected = pl.Series("s", range(10))
+    assert expected.dtype == pl.Int64
+
+    for generated_series in (
+        pl.Series("s", values=gen(10)),
+        pl.Series("s", values=Data(10)),
+        pl.Series("s", values=(x for x in gen(10))),
+    ):
+        assert_series_equal(expected, generated_series)
+
+    # an empty generator must give an empty series rather than raise
+    assert_series_equal(pl.Series("s", []), pl.Series("s", values=gen(0)))
+
+    # test 'iterable_to_pyseries' directly to validate 'chunk_size' behaviour
+    ps1 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8)
+    ps2 = iterable_to_pyseries("s", gen(10), dtype=pl.UInt8, chunk_size=3)
+    ps3 = iterable_to_pyseries("s", Data(10), dtype=pl.UInt8, chunk_size=6)
+
+    expected = pl.Series("s", range(10), dtype=pl.UInt8)
+    assert expected.dtype == pl.UInt8
+
+    for ps in (ps1, ps2, ps3):
+        generated_series = pl.Series("s")
+        generated_series._s = ps
+        assert_series_equal(expected, generated_series)
+
+
 def test_from_sequences(monkeypatch: Any) -> None:
     # test int, str, bool, flt
     values = [