diff --git a/python/tests/core/test_runners.py b/python/tests/core/test_runners.py index 9a1c7bb..8e5e573 100644 --- a/python/tests/core/test_runners.py +++ b/python/tests/core/test_runners.py @@ -246,7 +246,7 @@ def test_add_read_with_branch(self, sample_dataset): local_runner.append(sample_data1) ds.add_branch("exp1") - + assert local_runner.read_all() == sample_data1 create_time0 = datetime.utcfromtimestamp( diff --git a/python/tests/ray/test_runners.py b/python/tests/ray/test_runners.py index e85249a..9cfaeb1 100644 --- a/python/tests/ray/test_runners.py +++ b/python/tests/ray/test_runners.py @@ -86,82 +86,81 @@ def test_write_read_dataset(self, sample_dataset, enable_row_range_block, sample_dataset.add_branch("branch1") runner = sample_dataset.ray(ray_options=RayOptions( max_parallelism=4, enable_row_range_block=enable_row_range_block)) + input_data0 = generate_data([1, 2, 3]) + input_data1 = generate_data([4, 5]) + input_data2 = generate_data([6, 7]) + input_data3 = generate_data([8]) + input_data4 = generate_data([9, 10, 11]) for branch in ["branch1", "main"]: - sample_dataset.set_current_branch(branch) - # Test append. - input_data0 = generate_data([1, 2, 3]) - runner.append(input_data0) - - assert_equal( - runner.read_all(batch_size=batch_size, version=branch if branch !="main" else None).sort_by("int64"), - input_data0.sort_by("int64")) - - input_data1 = generate_data([4, 5]) - input_data2 = generate_data([6, 7]) - input_data3 = generate_data([8]) - input_data4 = generate_data([9, 10, 11]) - - runner.append_from([ - lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]), - lambda: iter([input_data4]) + sample_dataset.set_current_branch(branch) + # Test append. + runner.append(input_data0) + + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + input_data0.sort_by("int64")) + + + runner.append_from([ + lambda: iter([input_data1, input_data2]), lambda: iter([input_data3]), + lambda: iter([input_data4]) ]) - assert_equal( - runner.read_all(batch_size=batch_size, version=branch if branch !="main" else None).sort_by("int64"), + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), pa.concat_tables( [input_data0, input_data1, input_data2, input_data3, input_data4]).sort_by("int64")) # Test insert. - result = runner.insert(generate_data([7, 12])) - assert result.state == JobResult.State.FAILED - assert "Primary key to insert already exist" in result.error_message - - runner.upsert(generate_data([7, 12])) - assert_equal( - runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"), - pa.concat_tables([ - input_data0, input_data1, input_data2, input_data3, input_data4, - generate_data([12]) - ]).sort_by("int64")) - - # Test delete. - runner.delete(pc.field("int64") < 10) - assert_equal( - runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"), - pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64")) - - # Test reading views. - view = sample_dataset.map_batches(fn=_sample_map_udf, - output_schema=sample_dataset.schema, - output_record_fields=["binary"]) - assert_equal( - view.ray(DEFAULT_RAY_OPTIONS).read_all( - batch_size=batch_size).sort_by("int64"), - pa.concat_tables([ - pa.Table.from_pydict({ - "int64": [10, 11, 12], - "float64": [v / 10 + 1 for v in [10, 11, 12]], - "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] - }) - ]).sort_by("int64")) - - # Test a transform on a view. - transform_on_view = view.map_batches(fn=_sample_map_udf, - output_schema=view.schema, - output_record_fields=["binary"]) - assert_equal( - transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all( - batch_size=batch_size).sort_by("int64"), + result = runner.insert(generate_data([7, 12])) + assert result.state == JobResult.State.FAILED + assert "Primary key to insert already exist" in result.error_message + + runner.upsert(generate_data([7, 12])) + assert_equal( + runner.read_all(batch_size=batch_size).sort_by("int64"), + pa.concat_tables([ + input_data0, input_data1, input_data2, input_data3, input_data4, + generate_data([12]) + ]).sort_by("int64")) + + # Test delete. + runner.delete(pc.field("int64") < 10) + assert_equal( + runner.read_all(batch_size=batch_size,version=branch if branch !="main" else None).sort_by("int64"), + pa.concat_tables([generate_data([10, 11, 12])]).sort_by("int64")) + + # Test reading views. + view = sample_dataset.map_batches(fn=_sample_map_udf, + output_schema=sample_dataset.schema, + output_record_fields=["binary"]) + + assert_equal( + view.ray(DEFAULT_RAY_OPTIONS).read_all( + batch_size=batch_size).sort_by("int64"), + pa.concat_tables([ + pa.Table.from_pydict({ + "int64": [10, 11, 12], + "float64": [v / 10 + 1 for v in [10, 11, 12]], + "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] + }) + ]).sort_by("int64")) + + # Test a transform on a view. + transform_on_view = view.map_batches(fn=_sample_map_udf, + output_schema=view.schema, + output_record_fields=["binary"]) + assert_equal( + transform_on_view.ray(DEFAULT_RAY_OPTIONS).read_all( + batch_size=batch_size).sort_by("int64"), pa.concat_tables([ - pa.Table.from_pydict({ - "int64": [10, 11, 12], - "float64": [v / 10 + 2 for v in [10, 11, 12]], - "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] + pa.Table.from_pydict({ + "int64": [10, 11, 12], + "float64": [v / 10 + 2 for v in [10, 11, 12]], + "binary": [f"b{v}".encode("utf-8") for v in [10, 11, 12]] }) - ]).sort_by("int64")) - - + ]).sort_by("int64")) @pytest.mark.parametrize("enable_row_range_block", [(True,), (False,)]) def test_read_batch_size(self, tmp_path, sample_schema,