add unit tests
coufon committed Jan 25, 2024
1 parent 183c349 commit b347a75
Showing 6 changed files with 42 additions and 21 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/python-ci.yml
@@ -35,7 +35,7 @@ jobs:
       - name: Install test dependencies
        run: |
          python -m pip install --upgrade pip
-         pip install mypy pylint pytest mock
+         pip install mypy pylint pytest pytest-xdist mock
      - name: Install runtime dependencies and Space
        working-directory: ./python
        run: |
@@ -58,4 +58,4 @@ jobs:
      - name: Running tests
        working-directory: ./python
        run: |
-         pytest
+         pytest -n auto
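Note: with pytest-xdist installed, pytest -n auto starts one worker process per available CPU core and distributes the test modules across them, presumably to keep the larger parametrized test matrix below from slowing down CI.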
15 changes: 11 additions & 4 deletions python/src/space/core/ops/read.py
@@ -90,14 +90,16 @@ def __init__(self,

   def __iter__(self) -> Iterator[pa.Table]:
     for file in self._file_set.index_files:
+      row_range_read = (file.selected_rows.end > 0)
+
       # TODO: always loading the whole table is inefficient, to only load the
       # required row groups.
       index_data = pq.read_table(
           self.full_path(file.path),
           columns=self._selected_fields,
           filters=self._options.filter_)  # type: ignore[arg-type]

-      if file.selected_rows.end > 0:
+      if row_range_read:
         length = file.selected_rows.end - file.selected_rows.start
         index_data = index_data.slice(file.selected_rows.start, length)

@@ -116,10 +118,15 @@ def __iter__(self) -> Iterator[pa.Table]:
             (column_id,
              arrow.binary_field(self._record_fields_dict[field_id])))

-      for batch in index_data.to_batches(
-          max_chunksize=self._options.batch_size):
-        yield self._read_index_and_record(pa.table(batch), index_column_ids,
+      # The batch size enforcement is applied as row range.
+      if row_range_read:
+        yield self._read_index_and_record(index_data, index_column_ids,
                                           record_columns)
+      else:
+        for batch in index_data.to_batches(
+            max_chunksize=self._options.batch_size):
+          yield self._read_index_and_record(pa.table(batch), index_column_ids,
+                                            record_columns)

   def _read_index_and_record(
       self, index_data: pa.Table, index_column_ids: List[int],
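Note: with this change, a read that carries a row range (selected_rows) yields the sliced table as a single block, and only full-file reads are still chunked by batch_size. A minimal standalone sketch of that control flow, using plain pyarrow and a hypothetical iter_blocks helper in place of the real __iter__ / _read_index_and_record:

import pyarrow as pa

def iter_blocks(table: pa.Table, selected_start: int, selected_end: int,
                batch_size: int):
  # Yields pa.Table blocks the way the patched __iter__ does (simplified).
  row_range_read = selected_end > 0
  if row_range_read:
    # A row range was assigned (e.g. by a Ray read task): emit it whole so
    # the block boundary matches the assigned range exactly.
    yield table.slice(selected_start, selected_end - selected_start)
  else:
    # No row range: fall back to chunking the whole file by batch_size.
    for batch in table.to_batches(max_chunksize=batch_size):
      yield pa.table(batch)

# A 10-row table read as the range [2, 7) comes back as one 5-row block.
t = pa.table({"int64": list(range(10))})
blocks = list(iter_blocks(t, 2, 7, batch_size=3))
assert len(blocks) == 1 and blocks[0].num_rows == 5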
3 changes: 1 addition & 2 deletions python/src/space/core/storage.py
@@ -375,8 +375,7 @@ def ray_dataset(self, ray_options: RayOptions,
         read_options=read_options,
         parallelism=ray_options.max_parallelism)

-    if (not ray_options.enable_index_file_row_range_block and
-        read_options.batch_size):
+    if (not ray_options.enable_row_range_block and read_options.batch_size):
       return ds.repartition(math.ceil(ds.count() / read_options.batch_size))

     return ds
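Note: this is the fallback path when row-range blocks are off; block sizing then happens after the read by repartitioning the whole Ray dataset. For example, 60 rows read with batch_size=8 become ceil(60 / 8) = 8 blocks. With enable_row_range_block on, the splitting instead happens per index file in the Ray data source (next file).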
2 changes: 1 addition & 1 deletion python/src/space/ray/data_sources.py
@@ -81,7 +81,7 @@ def get_read_tasks(self, parallelism: int) -> List[ReadTask]:
     for index_file in file_set.index_files:
       num_rows = index_file.storage_statistics.num_rows

-      if (self._ray_options.enable_index_file_row_range_block and
+      if (self._ray_options.enable_row_range_block and
           self._read_options.batch_size):
         batch_size = self._read_options.batch_size
         num_blocks = math.ceil(num_rows / batch_size)
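Note: in this mode each index file is split into ceil(num_rows / batch_size) blocks, each read task covering one row range of the file. A rough sketch of that splitting, with the Range/ReadTask plumbing omitted and row_ranges as a hypothetical stand-in:

import math
from typing import Iterator, Tuple

def row_ranges(num_rows: int, batch_size: int) -> Iterator[Tuple[int, int]]:
  # Yields [start, end) ranges so each block holds at most batch_size rows.
  num_blocks = math.ceil(num_rows / batch_size)
  for i in range(num_blocks):
    start = i * batch_size
    yield start, min(start + batch_size, num_rows)

# An index file with 10 rows and batch_size=3 becomes 4 blocks:
assert list(row_ranges(10, 3)) == [(0, 3), (3, 6), (6, 9), (9, 10)]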
2 changes: 1 addition & 1 deletion python/src/space/ray/options.py
@@ -34,4 +34,4 @@ class RayOptions:
   # If enabled, a Ray block size is capped by the provided read batch size.
   # The cost is possible duplicated read of index files. It should be disabled
   # when most data are stored in index files.
-  enable_index_file_row_range_block: bool = True
+  enable_row_range_block: bool = True
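Note: based on the tests below, the renamed flag is passed through RayOptions when creating a runner; a usage sketch (the top-level import path is an assumption, and ds stands for an existing space Dataset):

from space import RayOptions  # assumed import path

# Mostly record-file data: keep row-range blocks on (the default) so every
# Ray block holds at most batch_size rows.
runner = ds.ray(ray_options=RayOptions(max_parallelism=4,
                                       enable_row_range_block=True))
table = runner.read_all(batch_size=100)

# Mostly index-file data: turn it off to avoid re-reading index files; the
# dataset is then repartitioned to the batch size after the read instead.
runner = ds.ray(ray_options=RayOptions(max_parallelism=4,
                                       enable_row_range_block=False))
table = runner.read_all(batch_size=100)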
37 changes: 26 additions & 11 deletions python/tests/ray/test_runners.py
@@ -65,15 +65,26 @@ def _sample_partition_fn(range_: Range) -> List[Range]:

 class TestRayReadWriteRunner:

-  def test_write_read_dataset(self, sample_dataset):
-    runner = sample_dataset.ray(ray_options=RayOptions(max_parallelism=4))
+  @pytest.mark.parametrize("enable_row_range_block,batch_size", [
+      (True, None),
+      (False, None),
+      (True, 1),
+      (False, 1),
+      (True, 3),
+      (False, 3),
+  ])
+  def test_write_read_dataset(self, sample_dataset, enable_row_range_block,
+                              batch_size):
+    runner = sample_dataset.ray(ray_options=RayOptions(
+        max_parallelism=4, enable_row_range_block=enable_row_range_block))

     # Test append.
     input_data0 = generate_data([1, 2, 3])
     runner.append(input_data0)

-    assert_equal(runner.read_all().sort_by("int64"),
-                 input_data0.sort_by("int64"))
+    assert_equal(
+        runner.read_all(batch_size=batch_size).sort_by("int64"),
+        input_data0.sort_by("int64"))

     input_data1 = generate_data([4, 5])
     input_data2 = generate_data([6, 7])
@@ -86,7 +97,7 @@ def test_write_read_dataset(self, sample_dataset):
     ])

     assert_equal(
-        runner.read_all().sort_by("int64"),
+        runner.read_all(batch_size=batch_size).sort_by("int64"),
         pa.concat_tables(
             [input_data0, input_data1, input_data2, input_data3,
              input_data4]).sort_by("int64"))
@@ -98,7 +109,7 @@ def test_write_read_dataset(self, sample_dataset):

     runner.upsert(generate_data([7, 12]))
     assert_equal(
-        runner.read_all().sort_by("int64"),
+        runner.read_all(batch_size=batch_size).sort_by("int64"),
         pa.concat_tables([
             input_data0, input_data1, input_data2, input_data3, input_data4,
             generate_data([12])
@@ -107,7 +118,7 @@ def test_write_read_dataset(self, sample_dataset):
     # Test delete.
     runner.delete(pc.field("int64") < 10)
     assert_equal(
-        runner.read_all().sort_by("int64"),
+        runner.read_all(batch_size=batch_size).sort_by("int64"),
         pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64"))

     # Test reading views.
@@ -120,7 +131,7 @@ def _sample_map_udf(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         output_schema=sample_dataset.schema,
         output_record_fields=["binary"])
     assert_equal(
-        view.ray().read_all().sort_by("int64"),
+        view.ray().read_all(batch_size=batch_size).sort_by("int64"),
         pa.concat_tables([
             pa.Table.from_pydict({
                 "int64": [10, 11, 12],
@@ -134,7 +145,8 @@ def _sample_map_udf(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
         output_schema=view.schema,
         output_record_fields=["binary"])
     assert_equal(
-        transform_on_view.ray().read_all().sort_by("int64"),
+        transform_on_view.ray().read_all(
+            batch_size=batch_size).sort_by("int64"),
         pa.concat_tables([
             pa.Table.from_pydict({
                 "int64": [10, 11, 12],
@@ -143,12 +155,15 @@ def _sample_map_udf(batch: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
             })
         ]).sort_by("int64"))

-  def test_read_batch_size(self, tmp_path, sample_schema):
+  @pytest.mark.parametrize("enable_row_range_block", [(True,), (False,)])
+  def test_read_batch_size(self, tmp_path, sample_schema,
+                           enable_row_range_block):
     ds = Dataset.create(str(tmp_path / f"dataset_{random_id()}"),
                         sample_schema,
                         primary_keys=["int64"],
                         record_fields=["binary"])
-    runner = ds.ray(ray_options=RayOptions(max_parallelism=1))
+    runner = ds.ray(ray_options=RayOptions(
+        max_parallelism=1, enable_row_range_block=enable_row_range_block))
     data = generate_data(range(0, 60))
     runner.append(data)
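Note: test_write_read_dataset now repeats every read assertion under both enable_row_range_block settings with batch sizes of None, 1, and 3, and test_read_batch_size runs for both settings, so both the row-range path and the to_batches fallback in read.py are exercised.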
