diff --git a/py-polars/polars/io/database/_executor.py b/py-polars/polars/io/database/_executor.py index 13e5b2cbe037..ef044d70d139 100644 --- a/py-polars/polars/io/database/_executor.py +++ b/py-polars/polars/io/database/_executor.py @@ -162,14 +162,14 @@ def _fetch_arrow( fetch_method = driver_properties["fetch_all"] yield getattr(self.result, fetch_method)() else: - size = batch_size if driver_properties["exact_batch_size"] else None + size = [batch_size] if driver_properties["exact_batch_size"] else [] repeat_batch_calls = driver_properties["repeat_batch_calls"] fetchmany_arrow = getattr(self.result, fetch_batches) if not repeat_batch_calls: - yield from fetchmany_arrow(size) + yield from fetchmany_arrow(*size) else: while True: - arrow = fetchmany_arrow(size) + arrow = fetchmany_arrow(*size) if not arrow: break yield arrow @@ -213,6 +213,13 @@ def _from_arrow( if re.match(f"^{driver}$", self.driver_name): if ver := driver_properties["minimum_version"]: self._check_module_version(self.driver_name, ver) + + if iter_batches and ( + driver_properties["exact_batch_size"] and not batch_size + ): + msg = f"Cannot set `iter_batches` for {self.driver_name} without also setting a non-zero `batch_size`" + raise ValueError(msg) # noqa: TRY301 + frames = ( self._apply_overrides(batch, (schema_overrides or {})) if isinstance(batch, DataFrame) @@ -247,6 +254,12 @@ def _from_rows( """Return resultset data row-wise for frame init.""" from polars import DataFrame + if iter_batches and not batch_size: + msg = ( + "Cannot set `iter_batches` without also setting a non-zero `batch_size`" + ) + raise ValueError(msg) + if is_async := isinstance(original_result := self.result, Coroutine): self.result = _run_async(self.result) try: @@ -506,11 +519,6 @@ def to_polars( if self.result is None: msg = "Cannot return a frame before executing a query" raise RuntimeError(msg) - elif iter_batches and not batch_size: - msg = ( - "Cannot set `iter_batches` without also setting a non-zero `batch_size`" - ) - raise ValueError(msg) can_close = self.can_close_cursor diff --git a/py-polars/polars/io/database/functions.py b/py-polars/polars/io/database/functions.py index cc033a3cfa5c..098dabfdd90f 100644 --- a/py-polars/polars/io/database/functions.py +++ b/py-polars/polars/io/database/functions.py @@ -98,16 +98,17 @@ def read_database( data returned by the query; this can be useful for processing large resultsets in a memory-efficient manner. If supported by the backend, this value is passed to the underlying query execution method (note that very low values will - typically result in poor performance as it will result in many round-trips to - the database as the data is returned). If the backend does not support changing + typically result in poor performance as it will cause many round-trips to the + database as the data is returned). If the backend does not support changing the batch size then a single DataFrame is yielded from the iterator. batch_size Indicate the size of each batch when `iter_batches` is True (note that you can still set this when `iter_batches` is False, in which case the resulting DataFrame is constructed internally using batched return before being returned - to you. Note that some backends may support batched operation but not allow for - an explicit size; in this case you will still receive batches, but their exact - size will be determined by the backend (so may not equal the value set here). + to you. Note that some backends (such as Snowflake) may support batch operation + but not allow for an explicit size to be set; in this case you will still + receive batches but their size is determined by the backend (in which case any + value set here will be ignored). schema_overrides A dictionary mapping column names to dtypes, used to override the schema inferred from the query cursor or given by the incoming Arrow data (depending @@ -242,7 +243,7 @@ def read_database( connection = ODBCCursorProxy(connection) elif "://" in connection: # otherwise looks like a mistaken call to read_database_uri - msg = "Use of string URI is invalid here; call `read_database_uri` instead" + msg = "use of string URI is invalid here; call `read_database_uri` instead" raise ValueError(msg) else: msg = "unable to identify string connection as valid ODBC (no driver)" diff --git a/py-polars/tests/unit/io/database/test_read.py b/py-polars/tests/unit/io/database/test_read.py index 52c4d5345b6a..3157e09cc3dd 100644 --- a/py-polars/tests/unit/io/database/test_read.py +++ b/py-polars/tests/unit/io/database/test_read.py @@ -46,12 +46,14 @@ def __init__( self, driver: str, batch_size: int | None, + exact_batch_size: bool, test_data: pa.Table, repeat_batch_calls: bool, ) -> None: self.__class__.__module__ = driver self._cursor = MockCursor( repeat_batch_calls=repeat_batch_calls, + exact_batch_size=exact_batch_size, batched=(batch_size is not None), test_data=test_data, ) @@ -69,10 +71,17 @@ class MockCursor: def __init__( self, batched: bool, + exact_batch_size: bool, test_data: pa.Table, repeat_batch_calls: bool, ) -> None: - self.resultset = MockResultSet(test_data, batched, repeat_batch_calls) + self.resultset = MockResultSet( + test_data=test_data, + batched=batched, + exact_batch_size=exact_batch_size, + repeat_batch_calls=repeat_batch_calls, + ) + self.exact_batch_size = exact_batch_size self.called: list[str] = [] self.batched = batched self.n_calls = 1 @@ -94,14 +103,21 @@ class MockResultSet: """Mock resultset class for databases we can't test in CI.""" def __init__( - self, test_data: pa.Table, batched: bool, repeat_batch_calls: bool = False + self, + test_data: pa.Table, + batched: bool, + exact_batch_size: bool, + repeat_batch_calls: bool = False, ): self.test_data = test_data self.repeat_batched_calls = repeat_batch_calls + self.exact_batch_size = exact_batch_size self.batched = batched self.n_calls = 1 def __call__(self, *args: Any, **kwargs: Any) -> Any: + if not self.exact_batch_size: + assert len(args) == 0 if self.repeat_batched_calls: res = self.test_data[: None if self.n_calls else 0] self.n_calls -= 1 @@ -478,13 +494,17 @@ def test_read_database_mocked( # since we don't have access to snowflake/databricks/etc from CI we # mock them so we can check that we're calling the expected methods arrow = pl.DataFrame({"x": [1, 2, 3], "y": ["aa", "bb", "cc"]}).to_arrow() + + reg = ARROW_DRIVER_REGISTRY.get(driver, {}) # type: ignore[var-annotated] + exact_batch_size = reg.get("exact_batch_size", False) + repeat_batch_calls = reg.get("repeat_batch_calls", False) + mc = MockConnection( driver, batch_size, test_data=arrow, - repeat_batch_calls=ARROW_DRIVER_REGISTRY.get(driver, {}).get( # type: ignore[call-overload] - "repeat_batch_calls", False - ), + repeat_batch_calls=repeat_batch_calls, + exact_batch_size=exact_batch_size, # type: ignore[arg-type] ) res = pl.read_database( query="SELECT * FROM test_data",