Skip to content

Commit

Permalink
[Python] Set batch_size in Dataset::scanner() (#1088)
Browse files Browse the repository at this point in the history
  • Loading branch information
eddyxu authored Jul 24, 2023
1 parent ed2cb0d commit 656ace2
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 1 deletion.
12 changes: 11 additions & 1 deletion python/python/lance/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def scanner(
limit: Optional[int] = None,
offset: Optional[int] = None,
nearest: Optional[dict] = None,
batch_size: Optional[int] = None,
batch_readahead: Optional[int] = None,
fragment_readahead: Optional[int] = None,
scan_in_order: bool = True,
Expand Down Expand Up @@ -131,7 +132,8 @@ def scanner(
"nprobes": 1,
"refine_factor": 1
}
batch_size: int, default None
The number of rows to fetch per batch.
batch_readahead: int, optional
The number of batches to read ahead.
fragment_readahead: int, optional
Expand Down Expand Up @@ -172,6 +174,7 @@ def scanner(
.limit(limit)
.offset(offset)
.nearest(**(nearest or {}))
.batch_size(batch_size)
.batch_readahead(batch_readahead)
.fragment_readahead(fragment_readahead)
.scan_in_order(scan_in_order)
Expand Down Expand Up @@ -684,11 +687,17 @@ def __init__(self, ds: LanceDataset):
self._offset = None
self._columns = None
self._nearest = None
self._batch_size = None
self._batch_readahead = None
self._fragment_readahead = None
self._scan_in_order = True
self._fragments = None

def batch_size(self, batch_size: Optional[int] = None) -> ScannerBuilder:
    """Set the batch size for the Scanner.

    Parameters
    ----------
    batch_size: int, optional
        The number of rows to fetch per batch. ``None`` leaves the batch
        size unset on this builder.

    Returns
    -------
    ScannerBuilder
        This builder, so that calls can be chained.

    Raises
    ------
    ValueError
        If ``batch_size`` is given and is not a positive integer.
    """
    # Validate eagerly, in the same style as batch_readahead(): the native
    # scanner takes an unsigned size, so a bad value would otherwise only
    # surface later, when the scanner is actually built.
    if batch_size is not None and int(batch_size) <= 0:
        raise ValueError("batch_size must be positive")
    self._batch_size = batch_size
    return self

def batch_readahead(self, nbatches: Optional[int] = None) -> ScannerBuilder:
if nbatches is not None and int(nbatches) < 0:
raise ValueError("batch_readahead must be non-negative")
Expand Down Expand Up @@ -799,6 +808,7 @@ def to_scanner(self) -> LanceScanner:
self._limit,
self._offset,
self._nearest,
self._batch_size,
self._batch_readahead,
self._fragment_readahead,
self._scan_in_order,
Expand Down
13 changes: 13 additions & 0 deletions python/python/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,3 +502,16 @@ def test_create_update_empty_dataset(tmp_path: Path, provide_pandas: bool):
assert dataset.to_table() == pa.table(
{"a": ["foo"], "b": [1], "c": [2.0]}, schema=expected_schema
)


def test_scan_with_batch_size(tmp_path: Path):
    """Scanning with ``batch_size`` yields in-order batches of that size."""
    num_rows = 10000
    batch_size = 16  # divides num_rows evenly, so every batch is full
    base_dir = tmp_path / "dataset"
    df = pd.DataFrame({"a": range(num_rows), "b": range(num_rows)})
    dataset = lance.write_dataset(df, base_dir)

    batches = dataset.scanner(batch_size=batch_size, scan_in_order=True).to_batches()

    num_batches = 0
    for idx, batch in enumerate(batches):
        assert batch.num_rows == batch_size
        # Use a separate name so the input frame `df` is not shadowed.
        batch_df = batch.to_pandas()
        assert batch_df["a"].iloc[0] == idx * batch_size
        num_batches += 1

    # Without this check the loop above passes vacuously when the scanner
    # yields no batches, or stops before covering every row.
    assert num_batches == num_rows // batch_size
4 changes: 4 additions & 0 deletions python/src/dataset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ impl Dataset {
limit: Option<i64>,
offset: Option<i64>,
nearest: Option<&PyDict>,
batch_size: Option<usize>,
batch_readahead: Option<usize>,
fragment_readahead: Option<usize>,
scan_in_order: Option<bool>,
Expand All @@ -159,6 +160,9 @@ impl Dataset {
.limit(limit, offset)
.map_err(|err| PyValueError::new_err(err.to_string()))?;

if let Some(batch_size) = batch_size {
scanner.batch_size(batch_size);
}
if let Some(batch_readahead) = batch_readahead {
scanner.batch_readahead(batch_readahead);
}
Expand Down

0 comments on commit 656ace2

Please sign in to comment.