Skip to content

Commit

Permalink
fix lint for test
Browse files: browse the repository at this point in the history
  • Loading branch information
Xin Huang committed Feb 9, 2024
1 parent 0a15b65 commit c3f6d49
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 67 deletions.
2 changes: 1 addition & 1 deletion python/tests/core/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ def test_add_read_with_branch(self, sample_dataset):
local_runner.append(sample_data1)

ds.add_branch("exp1")

assert local_runner.read_all() == sample_data1

create_time0 = datetime.utcfromtimestamp(
Expand Down
131 changes: 65 additions & 66 deletions python/tests/ray/test_runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,82 +86,81 @@ def test_write_read_dataset(self, sample_dataset, enable_row_range_block,
sample_dataset.add_branch("branch1")
runner = sample_dataset.ray(ray_options=RayOptions(
max_parallelism=4, enable_row_range_block=enable_row_range_block))
input_data0 = generate_data([1, 2, 3])
input_data1 = generate_data([4, 5])
input_data2 = generate_data([6, 7])
input_data3 = generate_data([8])
input_data4 = generate_data([9, 10, 11])
for branch in ["branch1", "main"]:
sample_dataset.set_current_branch(branch)
# Test append.
input_data0 = generate_data([1, 2, 3])
runner.append(input_data0)

assert_equal(
runner.read_all(batch_size=batch_size, version=branch if branch !="main" else None).sort_by("int64"),
input_data0.sort_by("int64"))

input_data1 = generate_data([4, 5])
input_data2 = generate_data([6, 7])
input_data3 = generate_data([8])
input_data4 = generate_data([9, 10, 11])

runner.append_from([
lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]),
lambda: iter([input_data4])
sample_dataset.set_current_branch(branch)
# Test append.
runner.append(input_data0)

assert_equal(
runner.read_all(batch_size=batch_size).sort_by("int64"),
input_data0.sort_by("int64"))


runner.append_from([
lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]),
lambda: iter([input_data4])
])

assert_equal(
runner.read_all(batch_size=batch_size, version=branch if branch !="main" else None).sort_by("int64"),
assert_equal(
runner.read_all(batch_size=batch_size).sort_by("int64"),
pa.concat_tables(
[input_data0, input_data1, input_data2, input_data3,
input_data4]).sort_by("int64"))

# Test insert.
result = runner.insert(generate_data([7, 12]))
assert result.state == JobResult.State.FAILED
assert "Primary key to insert already exist" in result.error_message

runner.upsert(generate_data([7, 12]))
assert_equal(
runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"),
pa.concat_tables([
input_data0, input_data1, input_data2, input_data3, input_data4,
generate_data([12])
]).sort_by("int64"))

# Test delete.
runner.delete(pc.field("int64") < 10)
assert_equal(
runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"),
pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64"))

# Test reading views.
view = sample_dataset.map_batches(fn=_sample_map_udf,
output_schema=sample_dataset.schema,
output_record_fields=["binary"])
assert_equal(
view.ray(DEFAULT_RAY_OPTIONS).read_all(
batch_size=batch_size).sort_by("int64"),
pa.concat_tables([
pa.Table.from_pydict({
"int64": [10, 11, 12],
"float64": [v / 10 + 1 for v in [10, 11, 12]],
"binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]]
})
]).sort_by("int64"))

# Test a transform on a view.
transform_on_view = view.map_batches(fn=_sample_map_udf,
output_schema=view.schema,
output_record_fields=["binary"])
assert_equal(
transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all(
batch_size=batch_size).sort_by("int64"),
result = runner.insert(generate_data([7, 12]))
assert result.state == JobResult.State.FAILED
assert "Primary key to insert already exist" in result.error_message

runner.upsert(generate_data([7, 12]))
assert_equal(
runner.read_all(batch_size=batch_size).sort_by("int64"),
pa.concat_tables([
input_data0, input_data1, input_data2, input_data3, input_data4,
generate_data([12])
]).sort_by("int64"))

# Test delete.
runner.delete(pc.field("int64") < 10)
assert_equal(
runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"),
pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64"))

# Test reading views.
view = sample_dataset.map_batches(fn=_sample_map_udf,
output_schema=sample_dataset.schema,
output_record_fields=["binary"])

assert_equal(
view.ray(DEFAULT_RAY_OPTIONS).read_all(
batch_size=batch_size).sort_by("int64"),
pa.concat_tables([
pa.Table.from_pydict({
"int64": [10, 11, 12],
"float64": [v / 10 + 1 for v in [10, 11, 12]],
"binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]]
})
]).sort_by("int64"))

# Test a transform on a view.
transform_on_view = view.map_batches(fn=_sample_map_udf,
output_schema=view.schema,
output_record_fields=["binary"])
assert_equal(
transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all(
batch_size=batch_size).sort_by("int64"),
pa.concat_tables([
pa.Table.from_pydict({
"int64": [10, 11, 12],
"float64": [v / 10 + 2 for v in [10, 11, 12]],
"binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]]
pa.Table.from_pydict({
"int64": [10, 11, 12],
"float64": [v / 10 + 2 for v in [10, 11, 12]],
"binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]]
})
]).sort_by("int64"))


]).sort_by("int64"))

@pytest.mark.parametrize("enable_row_range_block", [(True,), (False,)])
def test_read_batch_size(self, tmp_path, sample_schema,
Expand Down

0 comments on commit c3f6d49

Please sign in to comment.