From b0432bb2ec315535d472d669fe42e93a1a9e615a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 9 Jun 2021 18:55:02 +0200 Subject: [PATCH 01/12] FFI bridge for Schema, Field and DataType --- arrow/src/datatypes/ffi.rs | 171 +++++++++++++++++++++++++++++++++++++ arrow/src/datatypes/mod.rs | 2 + arrow/src/ffi.rs | 123 ++------------------------ 3 files changed, 180 insertions(+), 116 deletions(-) create mode 100644 arrow/src/datatypes/ffi.rs diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs new file mode 100644 index 000000000000..1bf337ec1dfe --- /dev/null +++ b/arrow/src/datatypes/ffi.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Contains functionality to load an ArrayData from the C Data Interface + +use std::convert::TryFrom; + +use crate::{ + datatypes::{DataType, Field, Schema, TimeUnit}, + error::{ArrowError, Result}, + ffi, +}; + +type CArrowSchema = ffi::FFI_ArrowSchema; + +impl TryFrom<&CArrowSchema> for DataType { + type Error = ArrowError; + + /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + fn try_from(c_schema: &CArrowSchema) -> Result { + let dtype = match c_schema.format() { + "n" => DataType::Null, + "b" => DataType::Boolean, + "c" => DataType::Int8, + "C" => DataType::UInt8, + "s" => DataType::Int16, + "S" => DataType::UInt16, + "i" => DataType::Int32, + "I" => DataType::UInt32, + "l" => DataType::Int64, + "L" => DataType::UInt64, + "e" => DataType::Float16, + "f" => DataType::Float32, + "g" => DataType::Float64, + "z" => DataType::Binary, + "Z" => DataType::LargeBinary, + "u" => DataType::Utf8, + "U" => DataType::LargeUtf8, + "tdD" => DataType::Date32, + "tdm" => DataType::Date64, + "tts" => DataType::Time32(TimeUnit::Second), + "ttm" => DataType::Time32(TimeUnit::Millisecond), + "ttu" => DataType::Time64(TimeUnit::Microsecond), + "ttn" => DataType::Time64(TimeUnit::Nanosecond), + "+l" => { + let c_child = c_schema.child(0); + DataType::List(Box::new(Field::try_from(c_child)?)) + } + "+L" => { + let c_child = c_schema.child(0); + DataType::LargeList(Box::new(Field::try_from(c_child)?)) + } + "+s" => { + let fields = c_schema.children().map(Field::try_from); + DataType::Struct(fields.collect::>>()?) 
+ } + // Parametrized types, requiring string parse + other => { + match other.splitn(2, ':').collect::>().as_slice() { + // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" + ["d", extra] => { + match extra.splitn(3, ',').collect::>().as_slice() { + [precision, scale] => { + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + }, + [precision, scale, bits] => { + if *bits != "128" { + return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); + } + let parsed_precision = precision.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer precision".to_string(), + ) + })?; + let parsed_scale = scale.parse::().map_err(|_| { + ArrowError::CDataInterface( + "The decimal type requires an integer scale".to_string(), + ) + })?; + DataType::Decimal(parsed_precision, parsed_scale) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation", + extra + ))) + } + } + } + + // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp. + ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), + ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), + ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), + ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), + ["tss", tz] => { + DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) + } + ["tsm", tz] => { + DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) + } + ["tsu", tz] => { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) + } + ["tsn", tz] => { + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) + } + _ => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))) + } + } + } + }; + Ok(dtype) + } +} + +impl TryFrom<&CArrowSchema> for Field { + type Error = ArrowError; + + fn try_from(c_schema: &CArrowSchema) -> Result { + let field = Field::new( + c_schema.name(), + DataType::try_from(c_schema)?, + c_schema.nullable(), + ); + Ok(field) + } +} + +impl TryFrom<&CArrowSchema> for Schema { + type Error = ArrowError; + + fn try_from(c_schema: &CArrowSchema) -> Result { + let fields = c_schema.children().map(Field::try_from); + let schema = Schema::new(fields.collect::>>()?); + Ok(schema) + } +} + +#[cfg(test)] +mod tests { + // TODO +} diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs index 6a2d0dcfe27e..51b33dc667e3 100644 --- a/arrow/src/datatypes/mod.rs +++ b/arrow/src/datatypes/mod.rs @@ -36,6 +36,8 @@ mod types; pub use types::*; mod datatype; pub use datatype::*; +mod ffi; +pub use ffi::*; /// A reference-counted reference to a [`Schema`](crate::datatypes::Schema). pub type SchemaRef = Arc; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index b804dd2db74a..183d825f6d6f 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -77,6 +77,7 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new]. 
*/ use std::{ + convert::TryFrom, ffi::CStr, ffi::CString, iter, @@ -217,6 +218,10 @@ impl FFI_ArrowSchema { unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } } + pub fn children(&self) -> impl Iterator { + (0..self.n_children as usize).map(move |i| self.child(i)) + } + pub fn nullable(&self) -> bool { (self.flags / 2) & 1 == 1 } @@ -231,120 +236,6 @@ impl Drop for FFI_ArrowSchema { } } -/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings -fn to_field(schema: &FFI_ArrowSchema) -> Result { - let data_type = match schema.format() { - "n" => DataType::Null, - "b" => DataType::Boolean, - "c" => DataType::Int8, - "C" => DataType::UInt8, - "s" => DataType::Int16, - "S" => DataType::UInt16, - "i" => DataType::Int32, - "I" => DataType::UInt32, - "l" => DataType::Int64, - "L" => DataType::UInt64, - "e" => DataType::Float16, - "f" => DataType::Float32, - "g" => DataType::Float64, - "z" => DataType::Binary, - "Z" => DataType::LargeBinary, - "u" => DataType::Utf8, - "U" => DataType::LargeUtf8, - "tdD" => DataType::Date32, - "tdm" => DataType::Date64, - "tts" => DataType::Time32(TimeUnit::Second), - "ttm" => DataType::Time32(TimeUnit::Millisecond), - "ttu" => DataType::Time64(TimeUnit::Microsecond), - "ttn" => DataType::Time64(TimeUnit::Nanosecond), - "+l" => { - let child = schema.child(0); - DataType::List(Box::new(to_field(child)?)) - } - "+L" => { - let child = schema.child(0); - DataType::LargeList(Box::new(to_field(child)?)) - } - "+s" => { - let children = (0..schema.n_children as usize) - .map(|x| to_field(schema.child(x))) - .collect::>>()?; - DataType::Struct(children) - } - // Parametrized types, requiring string parse - other => { - match other.splitn(2, ':').collect::>().as_slice() { - // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth" - ["d", extra] => { - match extra.splitn(3, ',').collect::>().as_slice() { - [precision, scale] => { - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - DataType::Decimal(parsed_precision, parsed_scale) - }, - [precision, scale, bits] => { - if *bits != "128" { - return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string())); - } - let parsed_precision = precision.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer precision".to_string(), - ) - })?; - let parsed_scale = scale.parse::().map_err(|_| { - ArrowError::CDataInterface( - "The decimal type requires an integer scale".to_string(), - ) - })?; - DataType::Decimal(parsed_precision, parsed_scale) - } - _ => { - return Err(ArrowError::CDataInterface(format!( - "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation", - extra - ))) - } - } - } - - // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp. 
- ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None), - ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None), - ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None), - ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None), - ["tss", tz] => { - DataType::Timestamp(TimeUnit::Second, Some(tz.to_string())) - } - ["tsm", tz] => { - DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string())) - } - ["tsu", tz] => { - DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string())) - } - ["tsn", tz] => { - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string())) - } - - _ => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - other - ))) - } - } - } - }; - Ok(Field::new(schema.name(), data_type, schema.nullable())) -} - /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings fn to_format(data_type: &DataType) -> Result { Ok(match data_type { @@ -814,7 +705,7 @@ pub struct ArrowArrayChild<'a> { impl ArrowArrayRef for ArrowArray { /// the data_type as declared in the schema fn data_type(&self) -> Result { - to_field(&self.schema).map(|x| x.data_type().clone()) + DataType::try_from(self.schema.as_ref()) } fn array(&self) -> &FFI_ArrowArray { @@ -833,7 +724,7 @@ impl ArrowArrayRef for ArrowArray { impl<'a> ArrowArrayRef for ArrowArrayChild<'a> { /// the data_type as declared in the schema fn data_type(&self) -> Result { - to_field(self.schema).map(|x| x.data_type().clone()) + DataType::try_from(self.schema) } fn array(&self) -> &FFI_ArrowArray { From 22a55badbb3530799abc54f614df96a74cff9542 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 10 Jun 2021 04:58:38 +0200 Subject: [PATCH 02/12] Factor out conversion to datatypes/ffi.rs --- arrow/src/datatypes/ffi.rs | 82 ++++++++++++++++- arrow/src/ffi.rs | 181 ++++++++++++------------------------- 2 files changed, 135 insertions(+), 128 deletions(-) diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 1bf337ec1dfe..7d3e87cb84ce 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -142,19 +142,91 @@ impl TryFrom<&CArrowSchema> for DataType { } } +impl TryFrom<&DataType> for CArrowSchema { + type Error = ArrowError; + + /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings + fn try_from(dtype: &DataType) -> Result { + let format = match dtype { + DataType::Null => "n".to_string(), + DataType::Boolean => "b".to_string(), + DataType::Int8 => "c".to_string(), + DataType::UInt8 => "C".to_string(), + DataType::Int16 => "s".to_string(), + DataType::UInt16 => "S".to_string(), + DataType::Int32 => "i".to_string(), + DataType::UInt32 => "I".to_string(), + DataType::Int64 => "l".to_string(), + DataType::UInt64 => "L".to_string(), + DataType::Float16 => "e".to_string(), + DataType::Float32 => "f".to_string(), + DataType::Float64 => "g".to_string(), + DataType::Binary => "z".to_string(), + DataType::LargeBinary => "Z".to_string(), + DataType::Utf8 => "u".to_string(), + DataType::LargeUtf8 => "U".to_string(), + DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale), + DataType::Date32 => "tdD".to_string(), + DataType::Date64 => "tdm".to_string(), + DataType::Time32(TimeUnit::Second) => "tts".to_string(), + DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(), + DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(), + 
DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(), + DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(), + DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(), + DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(), + DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(), + DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz), + DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz), + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz), + DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz), + DataType::List(_) => "+l".to_string(), + DataType::LargeList(_) => "+L".to_string(), + DataType::Struct(_) => "+s".to_string(), + other => { + return Err(ArrowError::CDataInterface(format!( + "The datatype \"{:?}\" is still not supported in Rust implementation", + other + ))) + } + }; + // allocate and hold the children + let children = match dtype { + DataType::List(child) | DataType::LargeList(child) => { + vec![CArrowSchema::try_from(child.as_ref())?] + } + DataType::Struct(fields) => fields + .iter() + .map(CArrowSchema::try_from) + .collect::>>()?, + _ => vec![], + }; + CArrowSchema::try_new(&format, children) + } +} + impl TryFrom<&CArrowSchema> for Field { type Error = ArrowError; fn try_from(c_schema: &CArrowSchema) -> Result { - let field = Field::new( - c_schema.name(), - DataType::try_from(c_schema)?, - c_schema.nullable(), - ); + let dtype = DataType::try_from(c_schema)?; + // TODO: validate that it has a struct type + let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); Ok(field) } } +impl TryFrom<&Field> for CArrowSchema { + type Error = ArrowError; + + fn try_from(field: &Field) -> Result { + CArrowSchema::try_from(field.data_type()) + .unwrap() + .with_name(field.name()) + // with_flags + } +} + impl TryFrom<&CArrowSchema> for Schema { type Error = ArrowError; diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index 183d825f6d6f..c08e4b2ab8b1 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -82,6 +82,7 @@ use std::{ ffi::CString, iter, mem::size_of, + os::raw::{c_char, c_void}, ptr::{self, NonNull}, sync::Arc, }; @@ -92,27 +93,25 @@ use crate::datatypes::{DataType, Field, TimeUnit}; use crate::error::{ArrowError, Result}; use crate::util::bit_util; -#[allow(dead_code)] -struct SchemaPrivateData { - field: Field, - children_ptr: Box<[*mut FFI_ArrowSchema]>, -} - /// ABI-compatible struct for `ArrowSchema` from C Data Interface /// See /// This was created by bindgen #[repr(C)] #[derive(Debug)] pub struct FFI_ArrowSchema { - format: *const ::std::os::raw::c_char, - name: *const ::std::os::raw::c_char, - metadata: *const ::std::os::raw::c_char, + format: *const c_char, + name: *const c_char, + metadata: *const c_char, flags: i64, n_children: i64, children: *mut *mut FFI_ArrowSchema, dictionary: *mut FFI_ArrowSchema, - release: ::std::option::Option, - private_data: *mut ::std::os::raw::c_void, + release: Option, + private_data: *mut c_void, +} + +struct SchemaPrivateData { + children: Box<[*mut FFI_ArrowSchema]>, } // callback used to drop [FFI_ArrowSchema] when it is exported. @@ -123,11 +122,16 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { let schema = &mut *schema; // take ownership back to release it. 
- CString::from_raw(schema.format as *mut std::os::raw::c_char); - CString::from_raw(schema.name as *mut std::os::raw::c_char); - let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData); - for child in private.children_ptr.iter() { - let _ = Box::from_raw(*child); + CString::from_raw(schema.format as *mut c_char); + if !schema.name.is_null() { + CString::from_raw(schema.name as *mut c_char); + } + if !schema.private_data.is_null() { + let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData); + for child in private_data.children.iter() { + drop(Box::from_raw(*child)) + } + drop(private_data); } schema.release = None; @@ -135,54 +139,43 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) { impl FFI_ArrowSchema { /// create a new [`Ffi_ArrowSchema`]. This fails if the fields' [`DataType`] is not supported. - fn try_new(field: Field) -> Result { - let format = to_format(field.data_type())?; - let name = field.name().clone(); - - // allocate (and hold) the children - let children_vec = match field.data_type() { - DataType::List(field) => { - vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)] - } - DataType::LargeList(field) => { - vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)] - } - DataType::Struct(fields) => fields - .iter() - .map(|field| Ok(Box::new(FFI_ArrowSchema::try_new(field.clone())?))) - .collect::>>()?, - _ => vec![], - }; - // note: this cannot be done along with the above because the above is fallible and this op leaks. - let children_ptr = children_vec + pub fn try_new(format: &str, children: Vec) -> Result { + let mut this = Self::empty(); + + // note: this op leaks. + let mut children_ptr = children .into_iter() + .map(Box::new) .map(Box::into_raw) .collect::>(); - let n_children = children_ptr.len() as i64; - let flags = field.is_nullable() as i64 * 2; + this.format = CString::new(format).unwrap().into_raw(); + this.release = Some(release_schema); + this.n_children = children_ptr.len() as i64; + this.children = children_ptr.as_mut_ptr(); - let mut private = Box::new(SchemaPrivateData { - field, - children_ptr, + let private_data = Box::new(SchemaPrivateData { + children: children_ptr, }); + this.private_data = Box::into_raw(private_data) as *mut c_void; + if this.n_children > 0 { + // perhaps move the set command here + } - // - Ok(FFI_ArrowSchema { - format: CString::new(format).unwrap().into_raw(), - name: CString::new(name).unwrap().into_raw(), - metadata: std::ptr::null_mut(), - flags, - n_children, - children: private.children_ptr.as_mut_ptr(), - dictionary: std::ptr::null_mut(), - release: Some(release_schema), - private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void, - }) + Ok(this) } + pub fn with_name(mut self, name: &str) -> Result { + self.name = CString::new(name).unwrap().into_raw(); + Ok(self) + } + + // pub fn with_flags() {} + // pub fn with_dictionary() {} + // pub fn with_metadata() {} + /// create an empty [FFI_ArrowSchema] - fn empty() -> Self { + pub fn empty() -> Self { Self { format: std::ptr::null_mut(), name: std::ptr::null_mut(), @@ -209,12 +202,14 @@ impl FFI_ArrowSchema { pub fn name(&self) -> &str { assert!(!self.name.is_null()); // safe because the lifetime of `self.name` equals `self` - unsafe { CStr::from_ptr(self.name) }.to_str().unwrap() + unsafe { CStr::from_ptr(self.name) } + .to_str() + .expect("The external API has a non-utf8 as name") } pub fn child(&self, index: usize) -> &Self { assert!(index < self.n_children as usize); - 
assert!(!self.name.is_null()); + // assert!(!self.name.is_null()); unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } } @@ -236,64 +231,6 @@ impl Drop for FFI_ArrowSchema { } } -/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings -fn to_format(data_type: &DataType) -> Result { - Ok(match data_type { - DataType::Null => "n", - DataType::Boolean => "b", - DataType::Int8 => "c", - DataType::UInt8 => "C", - DataType::Int16 => "s", - DataType::UInt16 => "S", - DataType::Int32 => "i", - DataType::UInt32 => "I", - DataType::Int64 => "l", - DataType::UInt64 => "L", - DataType::Float16 => "e", - DataType::Float32 => "f", - DataType::Float64 => "g", - DataType::Binary => "z", - DataType::LargeBinary => "Z", - DataType::Utf8 => "u", - DataType::LargeUtf8 => "U", - DataType::Decimal(precision, scale) => { - return Ok(format!("d:{},{}", precision, scale)) - } - DataType::Date32 => "tdD", - DataType::Date64 => "tdm", - DataType::Time32(TimeUnit::Second) => "tts", - DataType::Time32(TimeUnit::Millisecond) => "ttm", - DataType::Time64(TimeUnit::Microsecond) => "ttu", - DataType::Time64(TimeUnit::Nanosecond) => "ttn", - DataType::Timestamp(TimeUnit::Second, None) => "tss:", - DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:", - DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:", - DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:", - DataType::Timestamp(TimeUnit::Second, Some(tz)) => { - return Ok(format!("tss:{}", tz)) - } - DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => { - return Ok(format!("tsm:{}", tz)) - } - DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { - return Ok(format!("tsu:{}", tz)) - } - DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => { - return Ok(format!("tsn:{}", tz)) - } - DataType::List(_) => "+l", - DataType::LargeList(_) => "+L", - DataType::Struct(_) => "+s", - z => { - return Err(ArrowError::CDataInterface(format!( - "The datatype \"{:?}\" is still not supported in Rust implementation", - z - ))) - } - } - .to_string()) -} - // returns the number of bits that buffer `i` (in the C data interface) is expected to have. // This is set by the Arrow specification fn bit_width(data_type: &DataType, i: usize) -> Result { @@ -373,16 +310,16 @@ pub struct FFI_ArrowArray { pub(crate) offset: i64, pub(crate) n_buffers: i64, pub(crate) n_children: i64, - pub(crate) buffers: *mut *const ::std::os::raw::c_void, + pub(crate) buffers: *mut *const c_void, children: *mut *mut FFI_ArrowArray, dictionary: *mut FFI_ArrowArray, - release: ::std::option::Option, + release: Option, // When exported, this MUST contain everything that is owned by this array. // for example, any buffer pointed to in `buffers` must be here, as well as the `buffers` pointer // itself. // In other words, everything in [FFI_ArrowArray] must be owned by `private_data` and can assume // that they do not outlive `private_data`. 
- private_data: *mut ::std::os::raw::c_void, + private_data: *mut c_void, } impl Drop for FFI_ArrowArray { @@ -412,7 +349,7 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { struct PrivateData { buffers: Vec>, - buffers_ptr: Box<[*const std::os::raw::c_void]>, + buffers_ptr: Box<[*const c_void]>, children: Box<[*mut FFI_ArrowArray]>, } @@ -433,7 +370,7 @@ impl FFI_ArrowArray { .iter() .map(|maybe_buffer| match maybe_buffer { // note that `raw_data` takes into account the buffer's offset - Some(b) => b.as_ptr() as *const std::os::raw::c_void, + Some(b) => b.as_ptr() as *const c_void, None => std::ptr::null(), }) .collect::>(); @@ -463,7 +400,7 @@ impl FFI_ArrowArray { children: private_data.children.as_mut_ptr(), dictionary: std::ptr::null_mut(), release: Some(release_array), - private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void, + private_data: Box::into_raw(private_data) as *mut c_void, } } @@ -746,10 +683,8 @@ impl ArrowArray { /// See safety of [ArrowArray] #[allow(clippy::too_many_arguments)] pub unsafe fn try_new(data: ArrayData) -> Result { - let field = Field::new("", data.data_type().clone(), data.null_count() != 0); let array = Arc::new(FFI_ArrowArray::new(&data)); - let schema = Arc::new(FFI_ArrowSchema::try_new(field)?); - + let schema = Arc::new(FFI_ArrowSchema::try_from(data.data_type())?); Ok(ArrowArray { array, schema }) } From d7819d5d89263eef3058e7f9dcc674a483c41fcd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 10 Jun 2021 21:40:57 +0200 Subject: [PATCH 03/12] Add flags --- arrow/Cargo.toml | 1 + arrow/src/datatypes/ffi.rs | 8 +++++++- arrow/src/ffi.rs | 25 ++++++++++++++++++++----- 3 files changed, 28 insertions(+), 6 deletions(-) diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index 0ed2a4526211..4a1016aab026 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -52,6 +52,7 @@ hex = "0.4" prettytable-rs = { version = "0.8.0", optional = true } lexical-core = "^0.7" multiversion = "0.6.1" +bitflags = "1.2.1" [features] default = ["csv", "ipc"] diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 7d3e87cb84ce..78b44756a44a 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -220,10 +220,16 @@ impl TryFrom<&Field> for CArrowSchema { type Error = ArrowError; fn try_from(field: &Field) -> Result { + let flags = if field.is_nullable() { + ffi::Flags::NULLABLE + } else { + ffi::Flags::empty() + }; CArrowSchema::try_from(field.data_type()) .unwrap() .with_name(field.name()) - // with_flags + .unwrap() + .with_flags(flags) } } diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index c08e4b2ab8b1..cca710aa9f9c 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -87,12 +87,22 @@ use std::{ sync::Arc, }; +use bitflags::bitflags; + use crate::array::ArrayData; use crate::buffer::Buffer; -use crate::datatypes::{DataType, Field, TimeUnit}; +use crate::datatypes::DataType; use crate::error::{ArrowError, Result}; use crate::util::bit_util; +bitflags! 
{ + pub struct Flags: i64 { + const DICTIONARY_ORDERED = 0b00000001; + const NULLABLE = 0b00000010; + const MAP_KEYS_SORTED = 0b00000100; + } +} + /// ABI-compatible struct for `ArrowSchema` from C Data Interface /// See /// This was created by bindgen @@ -158,9 +168,6 @@ impl FFI_ArrowSchema { children: children_ptr, }); this.private_data = Box::into_raw(private_data) as *mut c_void; - if this.n_children > 0 { - // perhaps move the set command here - } Ok(this) } @@ -170,7 +177,11 @@ impl FFI_ArrowSchema { Ok(self) } - // pub fn with_flags() {} + pub fn with_flags(mut self, flags: Flags) -> Result { + self.flags = flags.bits(); + Ok(self) + } + // pub fn with_dictionary() {} // pub fn with_metadata() {} @@ -207,6 +218,10 @@ impl FFI_ArrowSchema { .expect("The external API has a non-utf8 as name") } + pub fn flags(&self) -> Option { + Flags::from_bits(self.flags) + } + pub fn child(&self, index: usize) -> &Self { assert!(index < self.n_children as usize); // assert!(!self.name.is_null()); From 99a443d54ec711f97ceac668351b898547495b37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Fri, 11 Jun 2021 18:19:45 +0200 Subject: [PATCH 04/12] Rust tests --- arrow/src/datatypes/ffi.rs | 110 ++++++++++++++++++++++++++++++++++--- 1 file changed, 101 insertions(+), 9 deletions(-) diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 78b44756a44a..ddca8d8bea16 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -22,11 +22,9 @@ use std::convert::TryFrom; use crate::{ datatypes::{DataType, Field, Schema, TimeUnit}, error::{ArrowError, Result}, - ffi, + ffi::{FFI_ArrowSchema as CArrowSchema, Flags}, }; -type CArrowSchema = ffi::FFI_ArrowSchema; - impl TryFrom<&CArrowSchema> for DataType { type Error = ArrowError; @@ -210,7 +208,6 @@ impl TryFrom<&CArrowSchema> for Field { fn try_from(c_schema: &CArrowSchema) -> Result { let dtype = DataType::try_from(c_schema)?; - // TODO: validate that it has a struct type let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); Ok(field) } @@ -221,9 +218,9 @@ impl TryFrom<&Field> for CArrowSchema { fn try_from(field: &Field) -> Result { let flags = if field.is_nullable() { - ffi::Flags::NULLABLE + Flags::NULLABLE } else { - ffi::Flags::empty() + Flags::empty() }; CArrowSchema::try_from(field.data_type()) .unwrap() @@ -237,13 +234,108 @@ impl TryFrom<&CArrowSchema> for Schema { type Error = ArrowError; fn try_from(c_schema: &CArrowSchema) -> Result { - let fields = c_schema.children().map(Field::try_from); - let schema = Schema::new(fields.collect::>>()?); - Ok(schema) + // interpret it as a struct type then extract its fields + let dtype = DataType::try_from(c_schema)?; + if let DataType::Struct(fields) = dtype { + Ok(Schema::new(fields)) + } else { + Err(ArrowError::CDataInterface(format!( + "Unable to interpret C data struct as a Schema" + ))) + } + } +} + +impl TryFrom<&Schema> for CArrowSchema { + type Error = ArrowError; + + fn try_from(schema: &Schema) -> Result { + let dtype = DataType::Struct(schema.fields().clone()); + let c_schema = CArrowSchema::try_from(&dtype)?; + Ok(c_schema) } } #[cfg(test)] mod tests { + use super::*; + use crate::datatypes::{DataType, Field, TimeUnit}; + use crate::error::Result; + use std::convert::TryFrom; + + fn round_trip_type(dtype: DataType) -> Result<()> { + let c_schema = CArrowSchema::try_from(&dtype)?; + let restored = DataType::try_from(&c_schema)?; + assert_eq!(restored, dtype); + Ok(()) + } + + fn round_trip_field(field: Field) -> 
Result<()> { + let c_schema = CArrowSchema::try_from(&field)?; + let restored = Field::try_from(&c_schema)?; + assert_eq!(restored, field); + Ok(()) + } + + fn round_trip_schema(schema: Schema) -> Result<()> { + let c_schema = CArrowSchema::try_from(&schema)?; + let restored = Schema::try_from(&c_schema)?; + assert_eq!(restored, schema); + Ok(()) + } + + #[test] + fn test_type() -> Result<()> { + round_trip_type(DataType::Int64)?; + round_trip_type(DataType::UInt64)?; + round_trip_type(DataType::Float64)?; + round_trip_type(DataType::Date64)?; + round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?; + round_trip_type(DataType::Utf8)?; + round_trip_type(DataType::List(Box::new(Field::new( + "a", + DataType::Int16, + false, + ))))?; + round_trip_type(DataType::Struct(vec![Field::new( + "a", + DataType::Utf8, + true, + )]))?; + Ok(()) + } + + #[test] + fn test_field() -> Result<()> { + let dtype = DataType::Struct(vec![Field::new("a", DataType::Utf8, true)]); + round_trip_field(Field::new("test", dtype, true))?; + Ok(()) + } + + #[test] + fn test_schema() -> Result<()> { + let schema = Schema::new(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("address", DataType::Utf8, false), + Field::new("priority", DataType::UInt8, false), + ]); + round_trip_schema(schema)?; + + // test that we can interpret struct types as schema + let dtype = DataType::Struct(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Int16, false), + ]); + let c_schema = CArrowSchema::try_from(&dtype)?; + let schema = Schema::try_from(&c_schema)?; + assert_eq!(schema.fields().len(), 2); + + // test that we assert the input type + let c_schema = CArrowSchema::try_from(&DataType::Float64)?; + let result = Schema::try_from(&c_schema); + assert_eq!(result.is_err(), true); + Ok(()) + } + // TODO } From bcb0b55d9a926ff831942812220955db83cc0a4c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 14 Jun 2021 15:53:30 +0200 Subject: [PATCH 05/12] Test datatypes from the python test suite --- .github/workflows/rust.yml | 4 +- arrow-pyarrow-integration-testing/src/lib.rs | 109 ++++- .../tests/test_sql.py | 380 +++++++++++------- arrow/src/array/ffi.rs | 2 +- arrow/src/datatypes/ffi.rs | 112 ++++-- arrow/src/ffi.rs | 14 +- 6 files changed, 405 insertions(+), 216 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 559c7c8a3961..c8fa6c840f07 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -325,9 +325,9 @@ jobs: python -m venv venv source venv/bin/activate - pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz + pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz pytest maturin develop - python -m unittest discover tests + pytest -v . # test the arrow crate builds against wasm32 in stable rust wasm32-build: diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 5b5462d9c151..98bf5a1b62fd 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -18,6 +18,7 @@ //! This library demonstrates a minimal usage of Rust's C data interface to pass //! arrays from and to Python. 
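
The pyo3 wrappers in this file are thin shims over the `TryFrom` conversions added earlier in the series. As a rough, self-contained sketch of that Rust-side layer on its own (illustration only, not part of the patch):

```rust
use std::convert::TryFrom;

use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::Result;
use arrow::ffi::FFI_ArrowSchema;

fn main() -> Result<()> {
    // A schema crosses the C Data Interface as a struct-typed ArrowSchema
    // ("+s") whose children carry the individual fields.
    let schema = Schema::new(vec![
        Field::new("id", DataType::Int64, false),
        Field::new("name", DataType::Utf8, true),
    ]);

    // Rust -> C struct: the direction used when handing data to pyarrow.
    let c_schema = FFI_ArrowSchema::try_from(&schema)?;

    // C struct -> Rust: the direction used when importing from pyarrow.
    let roundtripped = Schema::try_from(&c_schema)?;
    assert_eq!(roundtripped, schema);
    Ok(())
}
```

The `Field` and `DataType` wrappers below follow the same shape; only the outermost Rust type changes.
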
+use std::convert::TryFrom; use std::error; use std::fmt; use std::sync::Arc; @@ -28,8 +29,10 @@ use pyo3::{libc::uintptr_t, prelude::*}; use arrow::array::{make_array_from_raw, ArrayRef, Int64Array}; use arrow::compute::kernels; +use arrow::datatypes::{DataType, Field}; use arrow::error::ArrowError; use arrow::ffi; +use arrow::ffi::FFI_ArrowSchema; /// an error that bridges ArrowError with a Python error #[derive(Debug)] @@ -68,7 +71,78 @@ impl From for PyErr { } } -fn to_rust(ob: PyObject, py: Python) -> PyResult { +#[pyclass] +struct PyDataType { + inner: DataType, +} + +#[pymethods] +impl PyDataType { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(PyDataType { inner: dtype }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("DataType")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +#[pyclass] +struct PyField { + inner: Field, +} + +#[pymethods] +impl PyField { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(PyField { inner: field }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Field")?; + let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(dtype.into()) + } +} + +impl<'source> FromPyObject<'source> for PyDataType { + fn extract(value: &'source PyAny) -> PyResult { + PyDataType::from_pyarrow(value) + } +} + +impl<'source> FromPyObject<'source> for PyField { + fn extract(value: &'source PyAny) -> PyResult { + PyField::from_pyarrow(value) + } +} + +// struct PyField(Field); +// struct PySchema(Schema); + +// fn type_to_rust(ob: PyObject, py: Python) -> PyResult { + +fn array_to_rust(ob: PyObject, py: Python) -> PyResult { // prepare a pointer to receive the Array struct let (array_pointer, schema_pointer) = ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() }); @@ -86,7 +160,7 @@ fn to_rust(ob: PyObject, py: Python) -> PyResult { Ok(array) } -fn to_py(array: ArrayRef, py: Python) -> PyResult { +fn array_to_py(array: ArrayRef, py: Python) -> PyResult { let (array_pointer, schema_pointer) = array.to_raw().map_err(|e| PyO3ArrowError::from(e))?; @@ -99,11 +173,14 @@ fn to_py(array: ArrayRef, py: Python) -> PyResult { Ok(array.to_object(py)) } +/// Casts `array` to the target type +//#[pyfunction] + /// Returns `array + array` of an int64 array. 
#[pyfunction] fn double(array: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // perform some operation let array = @@ -118,7 +195,7 @@ fn double(array: PyObject, py: Python) -> PyResult { let array = Arc::new(array); // export - to_py(array, py) + array_to_py(array, py) } /// calls a lambda function that receives and returns an array @@ -130,11 +207,9 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult { let expected = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)])) as ArrayRef; // to py - let array = to_py(array, py)?; - - let array = lambda.call1(py, (array,))?; - - let array = to_rust(array, py)?; + let pyarray = array_to_py(array, py)?; + let pyarray = lambda.call1(py, (pyarray,))?; + let array = array_to_rust(pyarray, py)?; Ok(array == expected) } @@ -143,42 +218,44 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult { #[pyfunction] fn substring(array: PyObject, start: i64, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // substring let array = kernels::substring::substring(array.as_ref(), start, &None) .map_err(|e| PyO3ArrowError::from(e))?; // export - to_py(array, py) + array_to_py(array, py) } /// Returns the concatenate #[pyfunction] fn concatenate(array: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(array, py)?; // concat let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) .map_err(|e| PyO3ArrowError::from(e))?; // export - to_py(array, py) + array_to_py(array, py) } /// Converts to rust and back to python #[pyfunction] -fn round_trip(array: PyObject, py: Python) -> PyResult { +fn round_trip(pyarray: PyObject, py: Python) -> PyResult { // import - let array = to_rust(array, py)?; + let array = array_to_rust(pyarray, py)?; // export - to_py(array, py) + array_to_py(array, py) } #[pymodule] fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index 5524c54ec178..e90cd376e51a 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -16,156 +16,242 @@ # specific language governing permissions and limitations # under the License. 
-import unittest -from datetime import date, datetime -from decimal import Decimal +import contextlib + +import pytest +import pyarrow as pa + +from arrow_pyarrow_integration_testing import PyDataType, PyField +import arrow_pyarrow_integration_testing as rust -import arrow_pyarrow_integration_testing -import pyarrow from pytz import timezone -class TestCase(unittest.TestCase): - def test_primitive_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([1, 2, 3]) - b = arrow_pyarrow_integration_testing.double(a) - self.assertEqual(b, pyarrow.array([2, 4, 6])) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_primitive_rust(self): - """ - Rust -> Python -> Rust - """ - old_allocated = pyarrow.total_allocated_bytes() - - def double(array): - array = array.to_pylist() - return pyarrow.array([x * 2 if x is not None else None for x in array]) - - is_correct = arrow_pyarrow_integration_testing.double_py(double) - self.assertTrue(is_correct) - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_string_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array(["a", None, "ccc"]) - b = arrow_pyarrow_integration_testing.substring(a, 1) - self.assertEqual(b, pyarrow.array(["", None, "cc"])) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_time32_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([None, 1, 2], pyarrow.time32("s")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s")) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_date32_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [None, date(1990, 3, 9), date(2021, 6, 20)] - a = pyarrow.array(py_array, pyarrow.date32()) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array(py_array + py_array, pyarrow.date32()) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_timestamp_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ - None, - datetime(2021, 1, 1, 1, 1, 1, 1), - datetime(2020, 3, 9, 1, 1, 1, 1), +@contextlib.contextmanager +def no_pyarrow_leak(): + # No leak of C++ memory + old_allocation = pa.total_allocated_bytes() + try: + yield + finally: + assert pa.total_allocated_bytes() == old_allocation + + +@pytest.fixture(autouse=True) +def assert_pyarrow_leak(): + # automatically applied to all test cases + with no_pyarrow_leak(): + yield + + +_supported_pyarrow_types = [ + pa.null(), + pa.bool_(), + pa.int32(), + pa.time32("s"), + pa.time64("us"), + pa.date32(), + pa.float16(), + pa.float32(), + pa.float64(), + pa.string(), + pa.binary(), + pa.large_string(), + pa.large_binary(), + pa.list_(pa.int32()), + pa.large_list(pa.uint16()), + pa.struct( + [ + pa.field("a", pa.int32()), + pa.field("b", pa.int8()), + pa.field("c", pa.string()), ] - a = pyarrow.array(py_array, 
pyarrow.timestamp("us")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us")) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_timestamp_tz_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ - None, - datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), - datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + ), + pa.struct( + [ + pa.field("a", pa.int32(), nullable=False), + pa.field("b", pa.int8(), nullable=False), + pa.field("c", pa.string()), ] - a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array( - py_array + py_array, pyarrow.timestamp("us", tz="America/New_York") - ) - self.assertEqual(b, expected) - del a - del b - del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_decimal_python(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None] - a = pyarrow.array(py_array, pyarrow.decimal128(6, 2)) - b = arrow_pyarrow_integration_testing.round_trip(a) - self.assertEqual(a, b) - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) - - def test_list_array(self): - """ - Python -> Rust -> Python - """ - old_allocated = pyarrow.total_allocated_bytes() - a = pyarrow.array([[], None, [1, 2], [4, 5, 6]], pyarrow.list_(pyarrow.int64())) - b = arrow_pyarrow_integration_testing.round_trip(a) - - b.validate(full=True) - assert a.to_pylist() == b.to_pylist() - assert a.type == b.type - del a - del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + ), +] + +_unsupported_pyarrow_types = [ + pa.timestamp("us"), + pa.timestamp("us", tz="UTC"), + pa.timestamp("us", tz="Europe/Paris"), + pa.duration("s"), + pa.decimal128(19, 4), + pa.decimal256(76, 38), + pa.binary(10), + pa.list_(pa.int32(), 2), + pa.map_(pa.string(), pa.int32()), + pa.union( + [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_DENSE, + ), + pa.union( + [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_DENSE, + type_codes=[4, 8], + ), + pa.union( + [pa.field("a", pa.binary(10)), pa.field("b", pa.string())], + mode=pa.lib.UnionMode_SPARSE, + ), + pa.union( + [ + pa.field("a", pa.binary(10), nullable=False), + pa.field("b", pa.string()), + ], + mode=pa.lib.UnionMode_SPARSE, + ), +] + + +@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str) +def test_type_roundtrip(pyarrow_type): + ty = PyDataType.from_pyarrow(pyarrow_type) + restored = ty.to_pyarrow() + assert restored == pyarrow_type + assert restored is not pyarrow_type + + +@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str) +def test_type_roundtrip_raises(pyarrow_type): + with pytest.raises(Exception): + PyDataType.from_pyarrow(pyarrow_type) + + +def test_dictionary_type_roundtrip(): + # the dictionary type conversion is incomplete + pyarrow_type = pa.dictionary(pa.int32(), pa.string()) + ty = PyDataType.from_pyarrow(pyarrow_type) + assert ty.to_pyarrow() == pa.int32() + + +# Missing implementation in 
pyarrow +# @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) +# def test_field_roundtrip(pyarrow_type): +# for nullable in [True, False]: +# pyarrow_field = pa.field("test", pyarrow_type, nullable=nullable) +# field = PyField.from_pyarrow(pyarrow_field) +# assert field.to_pyarrow() == pyarrow_field + + +def test_primitive_python(): + """ + Python -> Rust -> Python + """ + a = pa.array([1, 2, 3]) + b = rust.double(a) + assert b == pa.array([2, 4, 6]) + del a + del b + + +def test_primitive_rust(): + """ + Rust -> Python -> Rust + """ + + def double(array): + array = array.to_pylist() + return pa.array([x * 2 if x is not None else None for x in array]) + + is_correct = rust.double_py(double) + assert is_correct + + +def test_string_python(): + """ + Python -> Rust -> Python + """ + a = pa.array(["a", None, "ccc"]) + b = rust.substring(a, 1) + assert b == pa.array(["", None, "cc"]) + del a + del b + + +def test_time32_python(): + """ + Python -> Rust -> Python + """ + a = pa.array([None, 1, 2], pa.time32("s")) + b = rust.concatenate(a) + expected = pa.array([None, 1, 2] + [None, 1, 2], pa.time32("s")) + assert b == expected + del a + del b + del expected + + +def test_list_array(): + """ + Python -> Rust -> Python + """ + a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64())) + b = rust.round_trip(a) + b.validate(full=True) + assert a.to_pylist() == b.to_pylist() + assert a.type == b.type + del a + del b + + +def test_timestamp_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [ + None, + datetime(2021, 1, 1, 1, 1, 1, 1), + datetime(2020, 3, 9, 1, 1, 1, 1), + ] + a = pyarrow.array(py_array, pyarrow.timestamp("us")) + b = arrow_pyarrow_integration_testing.concatenate(a) + expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us")) + self.assertEqual(b, expected) + del a + del b + del expected + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + +def test_timestamp_tz_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [ + None, + datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + ] + a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York")) + b = arrow_pyarrow_integration_testing.concatenate(a) + expected = pyarrow.array( + py_array + py_array, pyarrow.timestamp("us", tz="America/New_York") + ) + self.assertEqual(b, expected) + del a + del b + del expected + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) + +def test_decimal_python(self): + """ + Python -> Rust -> Python + """ + old_allocated = pyarrow.total_allocated_bytes() + py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None] + a = pyarrow.array(py_array, pyarrow.decimal128(6, 2)) + b = arrow_pyarrow_integration_testing.round_trip(a) + self.assertEqual(a, b) + del a + del b + # No leak of C++ memory + self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index 847649ce1264..a404186a8f18 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -22,7 +22,7 @@ use std::convert::TryFrom; use crate::{ error::{ArrowError, Result}, ffi, - ffi::ArrowArrayRef, + ffi::ArrowArrayRef }; use super::ArrayData; diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 
ddca8d8bea16..489128991e7f 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -22,14 +22,14 @@ use std::convert::TryFrom; use crate::{ datatypes::{DataType, Field, Schema, TimeUnit}, error::{ArrowError, Result}, - ffi::{FFI_ArrowSchema as CArrowSchema, Flags}, + ffi::{FFI_ArrowSchema, Flags}, }; -impl TryFrom<&CArrowSchema> for DataType { +impl TryFrom<&FFI_ArrowSchema> for DataType { type Error = ArrowError; /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings - fn try_from(c_schema: &CArrowSchema) -> Result { + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { let dtype = match c_schema.format() { "n" => DataType::Null, "b" => DataType::Boolean, @@ -140,7 +140,33 @@ impl TryFrom<&CArrowSchema> for DataType { } } -impl TryFrom<&DataType> for CArrowSchema { +impl TryFrom<&FFI_ArrowSchema> for Field { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + let dtype = DataType::try_from(c_schema)?; + let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); + Ok(field) + } +} + +impl TryFrom<&FFI_ArrowSchema> for Schema { + type Error = ArrowError; + + fn try_from(c_schema: &FFI_ArrowSchema) -> Result { + // interpret it as a struct type then extract its fields + let dtype = DataType::try_from(c_schema)?; + if let DataType::Struct(fields) = dtype { + Ok(Schema::new(fields)) + } else { + Err(ArrowError::CDataInterface(format!( + "Unable to interpret C data struct as a Schema" + ))) + } + } +} + +impl TryFrom<&DataType> for FFI_ArrowSchema { type Error = ArrowError; /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings @@ -191,29 +217,19 @@ impl TryFrom<&DataType> for CArrowSchema { // allocate and hold the children let children = match dtype { DataType::List(child) | DataType::LargeList(child) => { - vec![CArrowSchema::try_from(child.as_ref())?] + vec![FFI_ArrowSchema::try_from(child.as_ref())?] } DataType::Struct(fields) => fields .iter() - .map(CArrowSchema::try_from) + .map(FFI_ArrowSchema::try_from) .collect::>>()?, _ => vec![], }; - CArrowSchema::try_new(&format, children) - } -} - -impl TryFrom<&CArrowSchema> for Field { - type Error = ArrowError; - - fn try_from(c_schema: &CArrowSchema) -> Result { - let dtype = DataType::try_from(c_schema)?; - let field = Field::new(c_schema.name(), dtype, c_schema.nullable()); - Ok(field) + FFI_ArrowSchema::try_new(&format, children) } } -impl TryFrom<&Field> for CArrowSchema { +impl TryFrom<&Field> for FFI_ArrowSchema { type Error = ArrowError; fn try_from(field: &Field) -> Result { @@ -222,37 +238,43 @@ impl TryFrom<&Field> for CArrowSchema { } else { Flags::empty() }; - CArrowSchema::try_from(field.data_type()) - .unwrap() - .with_name(field.name()) - .unwrap() + FFI_ArrowSchema::try_from(field.data_type())? + .with_name(field.name())? 
.with_flags(flags) } } -impl TryFrom<&CArrowSchema> for Schema { +impl TryFrom<&Schema> for FFI_ArrowSchema { type Error = ArrowError; - fn try_from(c_schema: &CArrowSchema) -> Result { - // interpret it as a struct type then extract its fields - let dtype = DataType::try_from(c_schema)?; - if let DataType::Struct(fields) = dtype { - Ok(Schema::new(fields)) - } else { - Err(ArrowError::CDataInterface(format!( - "Unable to interpret C data struct as a Schema" - ))) - } + fn try_from(schema: &Schema) -> Result { + let dtype = DataType::Struct(schema.fields().clone()); + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; + Ok(c_schema) } } -impl TryFrom<&Schema> for CArrowSchema { +impl TryFrom for FFI_ArrowSchema { type Error = ArrowError; - fn try_from(schema: &Schema) -> Result { - let dtype = DataType::Struct(schema.fields().clone()); - let c_schema = CArrowSchema::try_from(&dtype)?; - Ok(c_schema) + fn try_from(dtype: DataType) -> Result { + FFI_ArrowSchema::try_from(&dtype) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(field: Field) -> Result { + FFI_ArrowSchema::try_from(&field) + } +} + +impl TryFrom for FFI_ArrowSchema { + type Error = ArrowError; + + fn try_from(schema: Schema) -> Result { + FFI_ArrowSchema::try_from(&schema) } } @@ -264,26 +286,30 @@ mod tests { use std::convert::TryFrom; fn round_trip_type(dtype: DataType) -> Result<()> { - let c_schema = CArrowSchema::try_from(&dtype)?; + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; let restored = DataType::try_from(&c_schema)?; assert_eq!(restored, dtype); Ok(()) } fn round_trip_field(field: Field) -> Result<()> { - let c_schema = CArrowSchema::try_from(&field)?; + let c_schema = FFI_ArrowSchema::try_from(&field)?; let restored = Field::try_from(&c_schema)?; assert_eq!(restored, field); Ok(()) } fn round_trip_schema(schema: Schema) -> Result<()> { - let c_schema = CArrowSchema::try_from(&schema)?; + let c_schema = FFI_ArrowSchema::try_from(&schema)?; let restored = Schema::try_from(&c_schema)?; assert_eq!(restored, schema); Ok(()) } + // fn roundtrip>(expected: T) { + // let c_schema: FFI_ArrowSchema = expected.try_into().unwrap(); + // } + #[test] fn test_type() -> Result<()> { round_trip_type(DataType::Int64)?; @@ -326,12 +352,12 @@ mod tests { Field::new("a", DataType::Utf8, true), Field::new("b", DataType::Int16, false), ]); - let c_schema = CArrowSchema::try_from(&dtype)?; + let c_schema = FFI_ArrowSchema::try_from(&dtype)?; let schema = Schema::try_from(&c_schema)?; assert_eq!(schema.fields().len(), 2); // test that we assert the input type - let c_schema = CArrowSchema::try_from(&DataType::Float64)?; + let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64)?; let result = Schema::try_from(&c_schema); assert_eq!(result.is_err(), true); Ok(()) diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index cca710aa9f9c..f1f208e089fe 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -330,10 +330,10 @@ pub struct FFI_ArrowArray { dictionary: *mut FFI_ArrowArray, release: Option, // When exported, this MUST contain everything that is owned by this array. - // for example, any buffer pointed to in `buffers` must be here, as well as the `buffers` pointer - // itself. - // In other words, everything in [FFI_ArrowArray] must be owned by `private_data` and can assume - // that they do not outlive `private_data`. + // for example, any buffer pointed to in `buffers` must be here, as well + // as the `buffers` pointer itself. 
+ // In other words, everything in [FFI_ArrowArray] must be owned by + // `private_data` and can assume that they do not outlive `private_data`. private_data: *mut c_void, } @@ -354,7 +354,7 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { let array = &mut *array; // take ownership of `private_data`, therefore dropping it` - let private = Box::from_raw(array.private_data as *mut PrivateData); + let private = Box::from_raw(array.private_data as *mut ArrayPrivateData); for child in private.children.iter() { let _ = Box::from_raw(*child); } @@ -362,7 +362,7 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) { array.release = None; } -struct PrivateData { +struct ArrayPrivateData { buffers: Vec>, buffers_ptr: Box<[*const c_void]>, children: Box<[*mut FFI_ArrowArray]>, @@ -399,7 +399,7 @@ impl FFI_ArrowArray { // create the private data owning everything. // any other data must be added here, e.g. via a struct, to track lifetime. - let mut private_data = Box::new(PrivateData { + let mut private_data = Box::new(ArrayPrivateData { buffers, buffers_ptr, children, From 4a400c3a83054c871373f2ac3fada99090cc2eca Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 16 Jun 2021 19:16:29 +0200 Subject: [PATCH 06/12] Install a pinned nightly pyarrow wheel --- .github/workflows/integration.yml | 53 ++++++++++++++++++++++++++++++- .github/workflows/rust.yml | 46 --------------------------- 2 files changed, 52 insertions(+), 47 deletions(-) diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml index cab6dd34caac..a713d05e04bf 100644 --- a/.github/workflows/integration.yml +++ b/.github/workflows/integration.yml @@ -23,7 +23,7 @@ on: jobs: - docker: + integration: name: Integration Test runs-on: ubuntu-latest steps: @@ -46,3 +46,54 @@ jobs: run: pip install -e dev/archery[docker] - name: Execute Docker Build run: archery docker run -e ARCHERY_INTEGRATION_WITH_RUST=1 conda-integration + + # test FFI against the C-Data interface exposed by pyarrow + pyarrow-integration-test: + name: Test Pyarrow C Data Interface + runs-on: ubuntu-latest + strategy: + matrix: + rust: [stable] + steps: + - uses: actions/checkout@v2 + with: + submodules: true + - name: Setup Rust toolchain + run: | + rustup toolchain install ${{ matrix.rust }} + rustup default ${{ matrix.rust }} + rustup component add rustfmt clippy + - name: Cache Cargo + uses: actions/cache@v2 + with: + path: /home/runner/.cargo + key: cargo-maturin-cache- + - name: Cache Rust dependencies + uses: actions/cache@v2 + with: + path: /home/runner/target + # this key is not equal because maturin uses different compilation flags. 
+ key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
+ - uses: actions/setup-python@v2
+ with:
+ python-version: '3.7'
+ - name: Upgrade pip and setuptools
+ run: pip install --upgrade pip setuptools wheel
+ - name: Install Python dependencies
+ run: pip install maturin==0.8.2 toml==0.10.1 pytest pytz
+ - name: Install nightly pyarrow wheel
+ # this points to a nightly pyarrow build containing the necessary
+ # API for integration testing (https://github.com/apache/arrow/pull/10529)
+ # the hardcoded version is wrong and should be removed either
+ # after https://issues.apache.org/jira/browse/ARROW-13083
+ # gets fixed or pyarrow 5.0 gets released
+ run: pip install --index-url https://pypi.fury.io/arrow-nightlies/ pyarrow==3.1.0.dev1030
+ - name: Run tests
+ env:
+ CARGO_HOME: "/home/runner/.cargo"
+ CARGO_TARGET_DIR: "/home/runner/target"
+ working-directory: arrow-pyarrow-integration-testing
+ run: |
+ maturin develop
+ pytest -v .
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index c8fa6c840f07..a041afc8b217 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -283,52 +283,6 @@ jobs:
continue-on-error: true
run: bash <(curl -s https://codecov.io/bash)
- # test FFI against the C-Data interface exposed by pyarrow
- pyarrow-integration-test:
- name: Test Pyarrow C Data Interface
- runs-on: ubuntu-latest
- strategy:
- matrix:
- rust: [stable]
- steps:
- - uses: actions/checkout@v2
- with:
- submodules: true
- - name: Setup Rust toolchain
- run: |
- rustup toolchain install ${{ matrix.rust }}
- rustup default ${{ matrix.rust }}
- rustup component add rustfmt clippy
- - name: Cache Cargo
- uses: actions/cache@v2
- with:
- path: /home/runner/.cargo
- key: cargo-maturin-cache-
- - name: Cache Rust dependencies
- uses: actions/cache@v2
- with:
- path: /home/runner/target
- # this key is not equal because maturin uses different compilation flags.
- key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
- - uses: actions/setup-python@v2
- with:
- python-version: '3.7'
- - name: Install Python dependencies
- run: python -m pip install --upgrade pip setuptools wheel
- - name: Run tests
- run: |
- export CARGO_HOME="/home/runner/.cargo"
- export CARGO_TARGET_DIR="/home/runner/target"
-
- cd arrow-pyarrow-integration-testing
-
- python -m venv venv
- source venv/bin/activate
-
- pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz pytest
- maturin develop
- pytest -v .
- # test the arrow crate builds against wasm32 in stable rust wasm32-build: name: Build wasm32 on AMD64 Rust ${{ matrix.rust }} From 2b80ad2c3fc47ea783f9b7379698e88743c83fac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 17 Jun 2021 17:31:07 +0200 Subject: [PATCH 07/12] Python tests for Field and Schema --- arrow-pyarrow-integration-testing/src/lib.rs | 76 ++++++++++++------- .../tests/test_sql.py | 28 +++++-- arrow/src/array/ffi.rs | 2 +- arrow/src/datatypes/ffi.rs | 12 +-- 4 files changed, 74 insertions(+), 44 deletions(-) diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 98bf5a1b62fd..13dee1d691a3 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -29,7 +29,7 @@ use pyo3::{libc::uintptr_t, prelude::*}; use arrow::array::{make_array_from_raw, ArrayRef, Int64Array}; use arrow::compute::kernels; -use arrow::datatypes::{DataType, Field}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::error::ArrowError; use arrow::ffi; use arrow::ffi::FFI_ArrowSchema; @@ -76,6 +76,16 @@ struct PyDataType { inner: DataType, } +#[pyclass] +struct PyField { + inner: Field, +} + +#[pyclass] +struct PySchema { + inner: Schema, +} + #[pymethods] impl PyDataType { #[staticmethod] @@ -84,7 +94,7 @@ impl PyDataType { let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(PyDataType { inner: dtype }) + Ok(Self { inner: dtype }) } fn to_pyarrow(&self, py: Python) -> PyResult { @@ -98,11 +108,6 @@ impl PyDataType { } } -#[pyclass] -struct PyField { - inner: Field, -} - #[pymethods] impl PyField { #[staticmethod] @@ -111,7 +116,7 @@ impl PyField { let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?; - Ok(PyField { inner: field }) + Ok(Self { inner: field }) } fn to_pyarrow(&self, py: Python) -> PyResult { @@ -125,6 +130,29 @@ impl PyField { } } +#[pymethods] +impl PySchema { + #[staticmethod] + fn from_pyarrow(value: &PyAny) -> PyResult { + let c_schema = FFI_ArrowSchema::empty(); + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?; + let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?; + Ok(Self { inner: schema }) + } + + fn to_pyarrow(&self, py: Python) -> PyResult { + let c_schema = + FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?; + let c_schema_ptr = &c_schema as *const FFI_ArrowSchema; + let module = py.import("pyarrow")?; + let class = module.getattr("Schema")?; + let schema = + class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?; + Ok(schema.into()) + } +} + impl<'source> FromPyObject<'source> for PyDataType { fn extract(value: &'source PyAny) -> PyResult { PyDataType::from_pyarrow(value) @@ -137,10 +165,11 @@ impl<'source> FromPyObject<'source> for PyField { } } -// struct PyField(Field); -// struct PySchema(Schema); - -// fn type_to_rust(ob: PyObject, py: Python) -> PyResult { +impl<'source> FromPyObject<'source> for PySchema { + fn extract(value: &'source PyAny) -> PyResult { + PySchema::from_pyarrow(value) + } +} fn array_to_rust(ob: PyObject, py: Python) -> PyResult { // prepare a pointer to receive the Array struct @@ -156,13 
+185,12 @@ fn array_to_rust(ob: PyObject, py: Python) -> PyResult { )?; let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) } - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; Ok(array) } fn array_to_py(array: ArrayRef, py: Python) -> PyResult { - let (array_pointer, schema_pointer) = - array.to_raw().map_err(|e| PyO3ArrowError::from(e))?; + let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?; let pa = py.import("pyarrow")?; @@ -183,15 +211,10 @@ fn double(array: PyObject, py: Python) -> PyResult { let array = array_to_rust(array, py)?; // perform some operation - let array = - array - .as_any() - .downcast_ref::() - .ok_or(PyO3ArrowError::ArrowError(ArrowError::ParseError( - "Expects an int64".to_string(), - )))?; - let array = - kernels::arithmetic::add(&array, &array).map_err(|e| PyO3ArrowError::from(e))?; + let array = array.as_any().downcast_ref::().ok_or_else(|| { + PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string())) + })?; + let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?; let array = Arc::new(array); // export @@ -222,7 +245,7 @@ fn substring(array: PyObject, start: i64, py: Python) -> PyResult { // substring let array = kernels::substring::substring(array.as_ref(), start, &None) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export array_to_py(array, py) @@ -236,7 +259,7 @@ fn concatenate(array: PyObject, py: Python) -> PyResult { // concat let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]) - .map_err(|e| PyO3ArrowError::from(e))?; + .map_err(PyO3ArrowError::from)?; // export array_to_py(array, py) @@ -256,6 +279,7 @@ fn round_trip(pyarray: PyObject, py: Python) -> PyResult { fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pyfunction!(double))?; m.add_wrapped(wrap_pyfunction!(double_py))?; m.add_wrapped(wrap_pyfunction!(substring))?; diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index e90cd376e51a..b06653435826 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -17,11 +17,12 @@ # under the License. 
import contextlib +import string import pytest import pyarrow as pa -from arrow_pyarrow_integration_testing import PyDataType, PyField +from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema import arrow_pyarrow_integration_testing as rust from pytz import timezone @@ -130,13 +131,24 @@ def test_dictionary_type_roundtrip(): assert ty.to_pyarrow() == pa.int32() -# Missing implementation in pyarrow -# @pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) -# def test_field_roundtrip(pyarrow_type): -# for nullable in [True, False]: -# pyarrow_field = pa.field("test", pyarrow_type, nullable=nullable) -# field = PyField.from_pyarrow(pyarrow_field) -# assert field.to_pyarrow() == pyarrow_field +@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str) +def test_field_roundtrip(pyarrow_type): + pyarrow_field = pa.field("test", pyarrow_type, nullable=True) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + if pyarrow_type != pa.null(): + # A null type field may not be non-nullable + pyarrow_field = pa.field("test", pyarrow_type, nullable=False) + field = PyField.from_pyarrow(pyarrow_field) + assert field.to_pyarrow() == pyarrow_field + + +def test_schema_roundtrip(): + pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types) + pyarrow_schema = pa.schema(pyarrow_fields) + schema = PySchema.from_pyarrow(pyarrow_schema) + assert schema.to_pyarrow() == pyarrow_schema def test_primitive_python(): diff --git a/arrow/src/array/ffi.rs b/arrow/src/array/ffi.rs index a404186a8f18..847649ce1264 100644 --- a/arrow/src/array/ffi.rs +++ b/arrow/src/array/ffi.rs @@ -22,7 +22,7 @@ use std::convert::TryFrom; use crate::{ error::{ArrowError, Result}, ffi, - ffi::ArrowArrayRef + ffi::ArrowArrayRef, }; use super::ArrayData; diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 489128991e7f..30a31998abef 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -159,9 +159,9 @@ impl TryFrom<&FFI_ArrowSchema> for Schema { if let DataType::Struct(fields) = dtype { Ok(Schema::new(fields)) } else { - Err(ArrowError::CDataInterface(format!( - "Unable to interpret C data struct as a Schema" - ))) + Err(ArrowError::CDataInterface( + "Unable to interpret C data struct as a Schema".to_string(), + )) } } } @@ -306,10 +306,6 @@ mod tests { Ok(()) } - // fn roundtrip>(expected: T) { - // let c_schema: FFI_ArrowSchema = expected.try_into().unwrap(); - // } - #[test] fn test_type() -> Result<()> { round_trip_type(DataType::Int64)?; @@ -362,6 +358,4 @@ mod tests { assert_eq!(result.is_err(), true); Ok(()) } - - // TODO } From 6ad1edd5921cabffe11124f1ed215746113654fa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Thu, 17 Jun 2021 17:43:37 +0200 Subject: [PATCH 08/12] Cleanup --- arrow-pyarrow-integration-testing/src/lib.rs | 3 --- arrow/src/datatypes/ffi.rs | 2 -- 2 files changed, 5 deletions(-) diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs index 13dee1d691a3..a601654d0bcd 100644 --- a/arrow-pyarrow-integration-testing/src/lib.rs +++ b/arrow-pyarrow-integration-testing/src/lib.rs @@ -201,9 +201,6 @@ fn array_to_py(array: ArrayRef, py: Python) -> PyResult { Ok(array.to_object(py)) } -/// Casts `array` to the target type -//#[pyfunction] - /// Returns `array + array` of an int64 array. 
#[pyfunction] fn double(array: PyObject, py: Python) -> PyResult { diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 30a31998abef..3807edeba9d0 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -15,8 +15,6 @@ // specific language governing permissions and limitations // under the License. -//! Contains functionality to load an ArrayData from the C Data Interface - use std::convert::TryFrom; use crate::{ From f50339c685b0b907fb98b1eb0c99a125b5b1f08c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 21 Jun 2021 18:02:11 +0200 Subject: [PATCH 09/12] Remove comment --- arrow/src/ffi.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index f1f208e089fe..a33e72c53657 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -152,7 +152,6 @@ impl FFI_ArrowSchema { pub fn try_new(format: &str, children: Vec) -> Result { let mut this = Self::empty(); - // note: this op leaks. let mut children_ptr = children .into_iter() .map(Box::new) From 018063ffef14a27cf2c4e7b8549ccd416a8c0bdd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Mon, 21 Jun 2021 18:09:12 +0200 Subject: [PATCH 10/12] cleanup --- arrow/src/ffi.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs index a33e72c53657..e3589cacdd43 100644 --- a/arrow/src/ffi.rs +++ b/arrow/src/ffi.rs @@ -181,10 +181,6 @@ impl FFI_ArrowSchema { Ok(self) } - // pub fn with_dictionary() {} - // pub fn with_metadata() {} - - /// create an empty [FFI_ArrowSchema] pub fn empty() -> Self { Self { format: std::ptr::null_mut(), @@ -223,7 +219,6 @@ impl FFI_ArrowSchema { pub fn child(&self, index: usize) -> &Self { assert!(index < self.n_children as usize); - // assert!(!self.name.is_null()); unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() } } From fe7258e89c535d479385780359501e174ae0190e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Kriszti=C3=A1n=20Sz=C5=B1cs?= Date: Wed, 23 Jun 2021 15:25:46 +0200 Subject: [PATCH 11/12] Fix python tests after rebase --- .../tests/test_sql.py | 76 +++++++++---------- 1 file changed, 37 insertions(+), 39 deletions(-) diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py index b06653435826..301eac8d2a09 100644 --- a/arrow-pyarrow-integration-testing/tests/test_sql.py +++ b/arrow-pyarrow-integration-testing/tests/test_sql.py @@ -17,16 +17,17 @@ # under the License. 
import contextlib +import datetime +import decimal import string import pytest import pyarrow as pa +import pytz from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema import arrow_pyarrow_integration_testing as rust -from pytz import timezone - @contextlib.contextmanager def no_pyarrow_leak(): @@ -52,9 +53,13 @@ def assert_pyarrow_leak(): pa.time32("s"), pa.time64("us"), pa.date32(), + pa.timestamp("us"), + pa.timestamp("us", tz="UTC"), + pa.timestamp("us", tz="Europe/Paris"), pa.float16(), pa.float32(), pa.float64(), + pa.decimal128(19, 4), pa.string(), pa.binary(), pa.large_string(), @@ -78,12 +83,8 @@ def assert_pyarrow_leak(): ] _unsupported_pyarrow_types = [ - pa.timestamp("us"), - pa.timestamp("us", tz="UTC"), - pa.timestamp("us", tz="Europe/Paris"), - pa.duration("s"), - pa.decimal128(19, 4), pa.decimal256(76, 38), + pa.duration("s"), pa.binary(10), pa.list_(pa.int32(), 2), pa.map_(pa.string(), pa.int32()), @@ -212,58 +213,55 @@ def test_list_array(): del b -def test_timestamp_python(self): +def test_timestamp_python(): """ Python -> Rust -> Python """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ + data = [ None, - datetime(2021, 1, 1, 1, 1, 1, 1), - datetime(2020, 3, 9, 1, 1, 1, 1), + datetime.datetime(2021, 1, 1, 1, 1, 1, 1), + datetime.datetime(2020, 3, 9, 1, 1, 1, 1), ] - a = pyarrow.array(py_array, pyarrow.timestamp("us")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us")) - self.assertEqual(b, expected) + a = pa.array(data, pa.timestamp("us")) + b = rust.concatenate(a) + expected = pa.array(data + data, pa.timestamp("us")) + assert b == expected del a del b del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) -def test_timestamp_tz_python(self): + +def test_timestamp_tz_python(): """ Python -> Rust -> Python """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [ + tzinfo = pytz.timezone("America/New_York") + pyarrow_type = pa.timestamp("us", tz="America/New_York") + data = [ None, - datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), - datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")), + datetime.datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=tzinfo), + datetime.datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=tzinfo), ] - a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York")) - b = arrow_pyarrow_integration_testing.concatenate(a) - expected = pyarrow.array( - py_array + py_array, pyarrow.timestamp("us", tz="America/New_York") - ) - self.assertEqual(b, expected) + a = pa.array(data, type=pyarrow_type) + b = rust.concatenate(a) + expected = pa.array(data * 2, type=pyarrow_type) + assert b == expected del a del b del expected - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) -def test_decimal_python(self): + +def test_decimal_python(): """ Python -> Rust -> Python """ - old_allocated = pyarrow.total_allocated_bytes() - py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None] - a = pyarrow.array(py_array, pyarrow.decimal128(6, 2)) - b = arrow_pyarrow_integration_testing.round_trip(a) - self.assertEqual(a, b) + data = [ + round(decimal.Decimal(123.45), 2), + round(decimal.Decimal(-123.45), 2), + None + ] + a = pa.array(data, pa.decimal128(6, 2)) + b = rust.round_trip(a) + assert a == b del a del b - # No leak of C++ memory - self.assertEqual(old_allocated, pyarrow.total_allocated_bytes()) From 
7acd3a5536b723ac2136a2f4eb8cb118605a40a2 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Thu, 1 Jul 2021 17:30:26 -0400 Subject: [PATCH 12/12] fix clippy --- arrow/src/datatypes/ffi.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs index 3807edeba9d0..7e98508cf090 100644 --- a/arrow/src/datatypes/ffi.rs +++ b/arrow/src/datatypes/ffi.rs @@ -353,7 +353,7 @@ mod tests { // test that we assert the input type let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64)?; let result = Schema::try_from(&c_schema); - assert_eq!(result.is_err(), true); + assert!(result.is_err()); Ok(()) } }
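
As a reading aid (not part of the patch series), the following is a minimal sketch of how the conversions introduced by these patches are meant to compose on the Rust side. The `main` harness and the field names are illustrative only; it assumes the `arrow` crate with the `TryFrom` implementations added above.

    use std::convert::TryFrom;

    use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
    use arrow::error::Result;
    use arrow::ffi::FFI_ArrowSchema;

    fn main() -> Result<()> {
        // Rust schema -> C Data Interface schema (the direction used by
        // `PySchema::to_pyarrow` before handing the struct's address to
        // pyarrow's `Schema._import_from_c`)
        let schema = Schema::new(vec![
            Field::new("a", DataType::Utf8, true),
            Field::new("b", DataType::Timestamp(TimeUnit::Microsecond, None), false),
        ]);
        let c_schema = FFI_ArrowSchema::try_from(&schema)?;

        // C Data Interface schema -> Rust schema (the direction used by
        // `PySchema::from_pyarrow` after pyarrow's `_export_to_c` has
        // filled in the struct)
        let restored = Schema::try_from(&c_schema)?;
        assert_eq!(restored, schema);
        Ok(())
    }

The same round trip is what `test_schema_roundtrip` in `tests/test_sql.py` exercises from the Python side via `PySchema`.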