Skip to content

Commit

Permalink
Add C data interface for decimal128 and timestamp (#453) (#495)
Browse files Browse the repository at this point in the history
* Add C data interface for decimal128

* Add timestamp support to C data interface

Add extra date32 inttegration test

Co-authored-by: Ádám Lippai <[email protected]>
  • Loading branch information
alamb and alippai authored Jun 23, 2021
1 parent c1f9083 commit c0f06d4
Show file tree
Hide file tree
Showing 3 changed files with 237 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,7 @@ jobs:
python -m venv venv
source venv/bin/activate
pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0
pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz
maturin develop
python -m unittest discover tests
Expand Down
84 changes: 78 additions & 6 deletions arrow-pyarrow-integration-testing/tests/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
# under the License.

import unittest
from datetime import date, datetime
from decimal import Decimal

import pyarrow
import arrow_pyarrow_integration_testing
import pyarrow
from pytz import timezone


class TestCase(unittest.TestCase):
Expand Down Expand Up @@ -69,16 +72,88 @@ def test_time32_python(self):
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
a = pyarrow.array([None, 1, 2], pyarrow.time32('s'))
a = pyarrow.array([None, 1, 2], pyarrow.time32("s"))
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s"))
self.assertEqual(b, expected)
del a
del b
del expected
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_date32_python(self):
"""
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
py_array = [None, date(1990, 3, 9), date(2021, 6, 20)]
a = pyarrow.array(py_array, pyarrow.date32())
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array(py_array + py_array, pyarrow.date32())
self.assertEqual(b, expected)
del a
del b
del expected
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_timestamp_python(self):
"""
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
py_array = [
None,
datetime(2021, 1, 1, 1, 1, 1, 1),
datetime(2020, 3, 9, 1, 1, 1, 1),
]
a = pyarrow.array(py_array, pyarrow.timestamp("us"))
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32('s'))
expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us"))
self.assertEqual(b, expected)
del a
del b
del expected
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_timestamp_tz_python(self):
"""
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
py_array = [
None,
datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")),
datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")),
]
a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York"))
b = arrow_pyarrow_integration_testing.concatenate(a)
expected = pyarrow.array(
py_array + py_array, pyarrow.timestamp("us", tz="America/New_York")
)
self.assertEqual(b, expected)
del a
del b
del expected
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_decimal_python(self):
"""
Python -> Rust -> Python
"""
old_allocated = pyarrow.total_allocated_bytes()
py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None]
a = pyarrow.array(py_array, pyarrow.decimal128(6, 2))
b = arrow_pyarrow_integration_testing.round_trip(a)
self.assertEqual(a, b)
del a
del b
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())

def test_list_array(self):
"""
Python -> Rust -> Python
Expand All @@ -94,6 +169,3 @@ def test_list_array(self):
del b
# No leak of C++ memory
self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())



166 changes: 158 additions & 8 deletions arrow/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,11 +268,75 @@ fn to_field(schema: &FFI_ArrowSchema) -> Result<Field> {
.collect::<Result<Vec<_>>>()?;
DataType::Struct(children)
}
// Parametrized types, requiring string parse
other => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
)))
match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
// Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
["d", extra] => {
match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
[precision, scale] => {
let parsed_precision = precision.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer precision".to_string(),
)
})?;
let parsed_scale = scale.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer scale".to_string(),
)
})?;
DataType::Decimal(parsed_precision, parsed_scale)
},
[precision, scale, bits] => {
if *bits != "128" {
return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string()));
}
let parsed_precision = precision.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer precision".to_string(),
)
})?;
let parsed_scale = scale.parse::<usize>().map_err(|_| {
ArrowError::CDataInterface(
"The decimal type requires an integer scale".to_string(),
)
})?;
DataType::Decimal(parsed_precision, parsed_scale)
}
_ => {
return Err(ArrowError::CDataInterface(format!(
"The decimal pattern \"d:{:?}\" is not supported in the Rust implementation",
extra
)))
}
}
}

// Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp.
["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
["tss", tz] => {
DataType::Timestamp(TimeUnit::Second, Some(tz.to_string()))
}
["tsm", tz] => {
DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string()))
}
["tsu", tz] => {
DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string()))
}
["tsn", tz] => {
DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string()))
}

_ => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" is still not supported in Rust implementation",
other
)))
}
}
}
};
Ok(Field::new(schema.name(), data_type, schema.nullable()))
Expand All @@ -298,12 +362,31 @@ fn to_format(data_type: &DataType) -> Result<String> {
DataType::LargeBinary => "Z",
DataType::Utf8 => "u",
DataType::LargeUtf8 => "U",
DataType::Decimal(precision, scale) => {
return Ok(format!("d:{},{}", precision, scale))
}
DataType::Date32 => "tdD",
DataType::Date64 => "tdm",
DataType::Time32(TimeUnit::Second) => "tts",
DataType::Time32(TimeUnit::Millisecond) => "ttm",
DataType::Time64(TimeUnit::Microsecond) => "ttu",
DataType::Time64(TimeUnit::Nanosecond) => "ttn",
DataType::Timestamp(TimeUnit::Second, None) => "tss:",
DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:",
DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:",
DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:",
DataType::Timestamp(TimeUnit::Second, Some(tz)) => {
return Ok(format!("tss:{}", tz))
}
DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => {
return Ok(format!("tsm:{}", tz))
}
DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
return Ok(format!("tsu:{}", tz))
}
DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => {
return Ok(format!("tsn:{}", tz))
}
DataType::List(_) => "+l",
DataType::LargeList(_) => "+L",
DataType::Struct(_) => "+s",
Expand Down Expand Up @@ -335,6 +418,8 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
(DataType::Int64, 1) | (DataType::Date64, 1) | (DataType::Time64(_), 1) => size_of::<i64>() * 8,
(DataType::Float32, 1) => size_of::<f32>() * 8,
(DataType::Float64, 1) => size_of::<f64>() * 8,
(DataType::Decimal(..), 1) => size_of::<i128>() * 8,
(DataType::Timestamp(..), 1) => size_of::<i64>() * 8,
// primitive types have a single buffer
(DataType::Boolean, _) |
(DataType::UInt8, _) |
Expand All @@ -346,7 +431,9 @@ fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
(DataType::Int32, _) | (DataType::Date32, _) | (DataType::Time32(_), _) |
(DataType::Int64, _) | (DataType::Date64, _) | (DataType::Time64(_), _) |
(DataType::Float32, _) |
(DataType::Float64, _) => {
(DataType::Float64, _) |
(DataType::Decimal(..), _) |
(DataType::Timestamp(..), _) => {
return Err(ArrowError::CDataInterface(format!(
"The datatype \"{:?}\" expects 2 buffers, but requested {}. Please verify that the C data interface is correctly implemented.",
data_type, i
Expand Down Expand Up @@ -828,9 +915,10 @@ impl<'a> ArrowArrayChild<'a> {
mod tests {
use super::*;
use crate::array::{
make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray,
GenericBinaryArray, GenericListArray, GenericStringArray, Int32Array,
OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
make_array, Array, ArrayData, BinaryOffsetSizeTrait, BooleanArray, DecimalArray,
DecimalBuilder, GenericBinaryArray, GenericListArray, GenericStringArray,
Int32Array, OffsetSizeTrait, StringOffsetSizeTrait, Time32MillisecondArray,
TimestampMillisecondArray,
};
use crate::compute::kernels;
use crate::datatypes::Field;
Expand Down Expand Up @@ -858,6 +946,32 @@ mod tests {
// (drop/release)
Ok(())
}

#[test]
fn test_decimal_round_trip() -> Result<()> {
// create an array natively
let mut builder = DecimalBuilder::new(5, 6, 2);
builder.append_value(12345_i128).unwrap();
builder.append_value(-12345_i128).unwrap();
builder.append_null().unwrap();
let original_array = builder.finish();

// export it
let array = ArrowArray::try_from(original_array.data().clone())?;

// (simulate consumer) import it
let data = ArrayData::try_from(array)?;
let array = make_array(data);

// perform some operation
let array = array.as_any().downcast_ref::<DecimalArray>().unwrap();

// verify
assert_eq!(array, &original_array);

// (drop/release)
Ok(())
}
// case with nulls is tested in the docs, through the example on this module.

fn test_generic_string<Offset: StringOffsetSizeTrait>() -> Result<()> {
Expand Down Expand Up @@ -1076,4 +1190,40 @@ mod tests {
// (drop/release)
Ok(())
}

#[test]
fn test_timestamp() -> Result<()> {
// create an array natively
let array = TimestampMillisecondArray::from(vec![None, Some(1), Some(2)]);

// export it
let array = ArrowArray::try_from(array.data().clone())?;

// (simulate consumer) import it
let data = ArrayData::try_from(array)?;
let array = make_array(data);

// perform some operation
let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()]).unwrap();
let array = array
.as_any()
.downcast_ref::<TimestampMillisecondArray>()
.unwrap();

// verify
assert_eq!(
array,
&TimestampMillisecondArray::from(vec![
None,
Some(1),
Some(2),
None,
Some(1),
Some(2)
])
);

// (drop/release)
Ok(())
}
}

0 comments on commit c0f06d4

Please sign in to comment.