From 6ec1b30912836b593bc03fe721021c0b0ea036be Mon Sep 17 00:00:00 2001 From: doki Date: Fri, 25 Nov 2022 11:59:35 +0800 Subject: [PATCH 1/4] to pyarrow with schema --- arrow/src/pyarrow.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 7c365a4344a5..264a6ac93d3f 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -195,9 +195,12 @@ impl PyArrowConvert for RecordBatch { py_names.push(field.name()); } + let py_schema = schema.to_pyarrow(py)?; + let module = py.import("pyarrow")?; let class = module.getattr("RecordBatch")?; - let record = class.call_method1("from_arrays", (py_arrays, py_names))?; + let record = + class.call_method1("from_arrays", (py_arrays, py_names, py_schema))?; Ok(PyObject::from(record)) } From 4a240c6dce71f725dfc190c135737fb3d0f1b0b3 Mon Sep 17 00:00:00 2001 From: doki Date: Fri, 25 Nov 2022 12:05:36 +0800 Subject: [PATCH 2/4] only use schema --- arrow/src/pyarrow.rs | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 264a6ac93d3f..3a367db7c9c2 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -184,7 +184,6 @@ impl PyArrowConvert for RecordBatch { fn to_pyarrow(&self, py: Python) -> PyResult { let mut py_arrays = vec![]; - let mut py_names = vec![]; let schema = self.schema(); let fields = schema.fields().iter(); @@ -192,15 +191,13 @@ impl PyArrowConvert for RecordBatch { for (array, field) in columns.zip(fields) { py_arrays.push(array.data().to_pyarrow(py)?); - py_names.push(field.name()); } let py_schema = schema.to_pyarrow(py)?; let module = py.import("pyarrow")?; let class = module.getattr("RecordBatch")?; - let record = - class.call_method1("from_arrays", (py_arrays, py_names, py_schema))?; + let record = class.call_method1("from_arrays", (py_arrays, py_schema))?; Ok(PyObject::from(record)) } From af95eacce5340f5b75dee98a1827b081c66e63b9 Mon Sep 17 00:00:00 2001 From: doki Date: Fri, 25 Nov 2022 12:55:59 +0800 Subject: [PATCH 3/4] add test --- arrow/Cargo.toml | 4 ++++ arrow/src/pyarrow.rs | 3 +-- arrow/tests/pyarrow.rs | 42 ++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 arrow/tests/pyarrow.rs diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index ab8963b9c300..7ab5720296cd 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -268,3 +268,7 @@ required-features = ["test_utils", "ipc"] [[test]] name = "csv" required-features = ["csv", "chrono-tz"] + +[[test]] +name = "pyarrow" +required-features = ["pyarrow"] diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs index 3a367db7c9c2..5ddc3105a4ad 100644 --- a/arrow/src/pyarrow.rs +++ b/arrow/src/pyarrow.rs @@ -186,10 +186,9 @@ impl PyArrowConvert for RecordBatch { let mut py_arrays = vec![]; let schema = self.schema(); - let fields = schema.fields().iter(); let columns = self.columns().iter(); - for (array, field) in columns.zip(fields) { + for array in columns { py_arrays.push(array.data().to_pyarrow(py)?); } diff --git a/arrow/tests/pyarrow.rs b/arrow/tests/pyarrow.rs new file mode 100644 index 000000000000..4b1226c738f5 --- /dev/null +++ b/arrow/tests/pyarrow.rs @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{ArrayRef, Int32Array, StringArray}; +use arrow::pyarrow::PyArrowConvert; +use arrow::record_batch::RecordBatch; +use pyo3::Python; +use std::sync::Arc; + +#[test] +fn test_to_pyarrow() { + pyo3::prepare_freethreaded_python(); + + let a: ArrayRef = Arc::new(Int32Array::from(vec![1, 2])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b"])); + let input = RecordBatch::try_from_iter(vec![("a", a), ("b", b)]).unwrap(); + println!("input: {:?}", input); + + let res = Python::with_gil(|py| { + let py_input = input.to_pyarrow(py)?; + let records = RecordBatch::from_pyarrow(py_input.as_ref(py))?; + let py_records = records.to_pyarrow(py)?; + RecordBatch::from_pyarrow(py_records.as_ref(py)) + }) + .unwrap(); + + assert_eq!(input, res); +} From 89de3c465979ced0d1a613f5e9f10e94d3ace4b0 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies Date: Sat, 26 Nov 2022 12:28:05 +0000 Subject: [PATCH 4/4] Run python tests in CI --- .github/workflows/integration.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index 3ece06b29238..656e56a652ca 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -149,13 +149,13 @@ jobs: virtualenv venv source venv/bin/activate pip install maturin toml pytest pytz pyarrow>=5.0 + - name: Run Rust tests + run: | + source venv/bin/activate + cargo test -p arrow --test pyarrow --features pyarrow - name: Run tests - env: - CARGO_HOME: "/home/runner/.cargo" - CARGO_TARGET_DIR: "/home/runner/target" run: | source venv/bin/activate - pushd arrow-pyarrow-integration-testing + cd arrow-pyarrow-integration-testing maturin develop pytest -v . - popd