diff --git a/python/python/lance/dataset.py b/python/python/lance/dataset.py
index d55596e6b3..fd70f45d71 100644
--- a/python/python/lance/dataset.py
+++ b/python/python/lance/dataset.py
@@ -99,6 +99,7 @@ def scanner(
         limit: Optional[int] = None,
         offset: Optional[int] = None,
         nearest: Optional[dict] = None,
+        batch_size: Optional[int] = None,
         batch_readahead: Optional[int] = None,
         fragment_readahead: Optional[int] = None,
         scan_in_order: bool = True,
@@ -131,7 +132,8 @@ def scanner(
                   "nprobes": 1,
                   "refine_factor": 1
                 }
-
+        batch_size: int, optional
+            The number of rows to fetch per batch. Must be positive.
         batch_readahead: int, optional
             The number of batches to read ahead.
         fragment_readahead: int, optional
@@ -172,6 +174,7 @@ def scanner(
             .limit(limit)
             .offset(offset)
             .nearest(**(nearest or {}))
+            .batch_size(batch_size)
             .batch_readahead(batch_readahead)
             .fragment_readahead(fragment_readahead)
             .scan_in_order(scan_in_order)
@@ -684,11 +687,23 @@ def __init__(self, ds: LanceDataset):
         self._offset = None
         self._columns = None
         self._nearest = None
+        self._batch_size = None
         self._batch_readahead = None
         self._fragment_readahead = None
         self._scan_in_order = True
         self._fragments = None
 
+    def batch_size(self, batch_size: Optional[int] = None) -> ScannerBuilder:
+        """Set the number of rows per batch for the Scanner.
+
+        ``None`` stores no explicit batch size. Any other value must be a
+        positive integer, otherwise ``ValueError`` is raised.
+        """
+        if batch_size is not None and int(batch_size) <= 0:
+            raise ValueError("batch_size must be positive")
+        self._batch_size = batch_size
+        return self
+
     def batch_readahead(self, nbatches: Optional[int] = None) -> ScannerBuilder:
         if nbatches is not None and int(nbatches) < 0:
             raise ValueError("batch_readahead must be non-negative")
@@ -799,6 +814,7 @@ def to_scanner(self) -> LanceScanner:
             self._limit,
             self._offset,
             self._nearest,
+            self._batch_size,
             self._batch_readahead,
             self._fragment_readahead,
             self._scan_in_order,
diff --git a/python/python/tests/test_dataset.py b/python/python/tests/test_dataset.py
index 407e45ce5a..dfdbc7b9cb 100644
--- a/python/python/tests/test_dataset.py
+++ b/python/python/tests/test_dataset.py
@@ -502,3 +502,21 @@ def test_create_update_empty_dataset(tmp_path: Path, provide_pandas: bool):
     assert dataset.to_table() == pa.table(
         {"a": ["foo"], "b": [1], "c": [2.0]}, schema=expected_schema
     )
+
+
+def test_scan_with_batch_size(tmp_path: Path):
+    """Scanning with ``batch_size`` yields full, in-order batches of that size."""
+    num_rows = 10000
+    batch_size = 16  # divides num_rows evenly, so every batch must be full
+    base_dir = tmp_path / "dataset"
+    df = pd.DataFrame({"a": range(num_rows), "b": range(num_rows)})
+    dataset = lance.write_dataset(df, base_dir)
+
+    scanner = dataset.scanner(batch_size=batch_size, scan_in_order=True)
+    batches = list(scanner.to_batches())
+
+    # Guard against the loop below passing vacuously on an empty scan.
+    assert len(batches) == num_rows // batch_size
+    for idx, batch in enumerate(batches):
+        assert batch.num_rows == batch_size
+        assert batch.to_pandas()["a"].iloc[0] == idx * batch_size
diff --git a/python/src/dataset.rs b/python/src/dataset.rs
index 259f60ab1e..2bf517e170 100644
--- a/python/src/dataset.rs
+++ b/python/src/dataset.rs
@@ -138,6 +138,7 @@ impl Dataset {
         limit: Option<i64>,
         offset: Option<i64>,
         nearest: Option<&PyDict>,
+        batch_size: Option<usize>,
         batch_readahead: Option<usize>,
         fragment_readahead: Option<usize>,
         scan_in_order: Option<bool>,
@@ -159,6 +160,9 @@ impl Dataset {
             .limit(limit, offset)
             .map_err(|err| PyValueError::new_err(err.to_string()))?;
 
+        if let Some(batch_size) = batch_size {
+            scanner.batch_size(batch_size);
+        }
         if let Some(batch_readahead) = batch_readahead {
             scanner.batch_readahead(batch_readahead);
         }