From 2fdf08cf7dba8eb347100c6e8a1c65cd9c257d43 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:20:27 -0700 Subject: [PATCH 1/6] pass batch size --- python/python/lance/dataset.py | 11 ++++++++++- python/src/dataset.rs | 4 ++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index d55596e6b3..ba1e355bfc 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -99,6 +99,7 @@ def scanner( limit: Optional[int] = None, offset: Optional[int] = None, nearest: Optional[dict] = None, + batch_size: Optional[int] = None, batch_readahead: Optional[int] = None, fragment_readahead: Optional[int] = None, scan_in_order: bool = True, @@ -131,7 +132,8 @@ def scanner( "nprobes": 1, "refine_factor": 1 } - + batch_size: int, default None + The number of rows to fetch per batch. batch_readahead: int, optional The number of batches to read ahead. fragment_readahead: int, optional @@ -684,11 +686,17 @@ def __init__(self, ds: LanceDataset): self._offset = None self._columns = None self._nearest = None + self._batch_size = None self._batch_readahead = None self._fragment_readahead = None self._scan_in_order = True self._fragments = None + def batch_size(self, batch_size: int) -> ScannerBuilder: + """Set batch size for Scanner""" + self._batch_size = batch_size + return self + def batch_readahead(self, nbatches: Optional[int] = None) -> ScannerBuilder: if nbatches is not None and int(nbatches) < 0: raise ValueError("batch_readahead must be non-negative") @@ -799,6 +807,7 @@ def to_scanner(self) -> LanceScanner: self._limit, self._offset, self._nearest, + self._batch_size, self._batch_readahead, self._fragment_readahead, self._scan_in_order, diff --git a/python/src/dataset.rs b/python/src/dataset.rs index 259f60ab1e..2bf517e170 100644 --- a/python/src/dataset.rs +++ b/python/src/dataset.rs @@ -138,6 +138,7 @@ impl Dataset { limit: Option, offset: Option, nearest: Option<&PyDict>, + batch_size: Option, batch_readahead: Option, fragment_readahead: Option, scan_in_order: Option, @@ -159,6 +160,9 @@ impl Dataset { .limit(limit, offset) .map_err(|err| PyValueError::new_err(err.to_string()))?; + if let Some(batch_size) = batch_size { + scanner.batch_size(batch_size); + } if let Some(batch_readahead) = batch_readahead { scanner.batch_readahead(batch_readahead); } From 81798cff145a85ac9e540cab2045cff6569a0c47 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:31:40 -0700 Subject: [PATCH 2/6] dataset --- python/python/lance/dataset.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index ba1e355bfc..9dba46ecce 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -23,7 +23,6 @@ import numpy as np import pyarrow as pa -import pyarrow.dataset from pyarrow import RecordBatch, Schema from pyarrow._compute import Expression From f1fe590df84ee7de6651ed07be21bc253e48a8d0 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:38:10 -0700 Subject: [PATCH 3/6] fix pyarrow dataset --- python/python/lance/dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index 9dba46ecce..bc8b00d5c2 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -23,7 +23,7 @@ import numpy as np import pyarrow as pa -from pyarrow import RecordBatch, Schema +from pyarrow import RecordBatch, Schema, dataset from pyarrow._compute import Expression from .fragment import LanceFragment From 0e6286d9594c48f9e41ff4c486ecb93a93c0df66 Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:45:13 -0700 Subject: [PATCH 4/6] fix lint --- python/python/lance/dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index bc8b00d5c2..ba1e355bfc 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -23,7 +23,8 @@ import numpy as np import pyarrow as pa -from pyarrow import RecordBatch, Schema, dataset +import pyarrow.dataset +from pyarrow import RecordBatch, Schema from pyarrow._compute import Expression from .fragment import LanceFragment From 2fbeeee6d4c0f77c7349f23c846f6a884556d99c Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:54:09 -0700 Subject: [PATCH 5/6] pass batch size --- python/python/lance/dataset.py | 1 + python/python/tests/test_dataset.py | 11 +++++++++++ 2 files changed, 12 insertions(+) diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py index ba1e355bfc..fd70f45d71 100644 --- a/python/python/lance/dataset.py +++ b/python/python/lance/dataset.py @@ -174,6 +174,7 @@ def scanner( .limit(limit) .offset(offset) .nearest(**(nearest or {})) + .batch_size(batch_size) .batch_readahead(batch_readahead) .fragment_readahead(fragment_readahead) .scan_in_order(scan_in_order) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 407e45ce5a..22150e13a8 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -502,3 +502,14 @@ def test_create_update_empty_dataset(tmp_path: Path, provide_pandas: bool): assert dataset.to_table() == pa.table( {"a": ["foo"], "b": [1], "c": [2.0]}, schema=expected_schema ) + + +def test_scan_with_batch_size(tmp_path: Path): + base_dir = tmp_path / "dataset" + df = pd.DataFrame({"a": range(10000), "b": range(10000)}) + dataset = lance.write_dataset(df, base_dir) + + batches = dataset.scanner(batch_size=16).to_batches() + batch = next(batches) + + assert batch.num_rows == 16 From 1e67f6d4d5e7b5e3e060c1cbfd3997340f0c921d Mon Sep 17 00:00:00 2001 From: Lei Xu Date: Mon, 24 Jul 2023 10:58:42 -0700 Subject: [PATCH 6/6] fix test --- python/python/tests/test_dataset.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py index 22150e13a8..dfdbc7b9cb 100644 --- a/python/python/tests/test_dataset.py +++ b/python/python/tests/test_dataset.py @@ -509,7 +509,9 @@ def test_scan_with_batch_size(tmp_path: Path): df = pd.DataFrame({"a": range(10000), "b": range(10000)}) dataset = lance.write_dataset(df, base_dir) - batches = dataset.scanner(batch_size=16).to_batches() - batch = next(batches) + batches = dataset.scanner(batch_size=16, scan_in_order=True).to_batches() - assert batch.num_rows == 16 + for idx, batch in enumerate(batches): + assert batch.num_rows == 16 + df = batch.to_pandas() + assert df["a"].iloc[0] == idx * 16