From a81ac6efb89c5c13f058e7821088b020c7cbd5bc Mon Sep 17 00:00:00 2001 From: Kyle Barron Date: Wed, 25 Sep 2024 17:39:58 -0400 Subject: [PATCH] Arith functions (#194) * Arith functions * add simple test --- arro3-compute/python/arro3/compute/_arith.pyi | 41 ++++++++++++++ .../python/arro3/compute/_compute.pyi | 10 ++++ arro3-compute/src/arith.rs | 55 +++++++++++++++++++ arro3-compute/src/lib.rs | 11 ++++ arro3-core/python/arro3/core/_core.pyi | 16 ++++++ pyo3-arrow/src/ffi/from_python/input.rs | 16 +++++- pyo3-arrow/src/input.rs | 33 ++++++++++- pyo3-arrow/src/scalar.rs | 17 +++++- tests/compute/test_arith.py | 18 ++++++ 9 files changed, 212 insertions(+), 5 deletions(-) create mode 100644 arro3-compute/python/arro3/compute/_arith.pyi create mode 100644 arro3-compute/src/arith.rs create mode 100644 tests/compute/test_arith.py diff --git a/arro3-compute/python/arro3/compute/_arith.pyi b/arro3-compute/python/arro3/compute/_arith.pyi new file mode 100644 index 0000000..7e2c958 --- /dev/null +++ b/arro3-compute/python/arro3/compute/_arith.pyi @@ -0,0 +1,41 @@ +# Note: importing with +# `from arro3.core import Array` +# will cause Array to be included in the generated docs in this module. 
+import arro3.core as core +import arro3.core.types as types + +def add(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array: + """Perform `lhs + rhs`, returning an error on overflow""" + +def add_wrapping( + lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable +) -> core.Array: + """Perform `lhs + rhs`, wrapping on overflow for integer data types""" + +def div(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array: + """Perform `lhs / rhs`""" + +def mul(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array: + """Perform `lhs * rhs`, returning an error on overflow""" + +def mul_wrapping( + lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable +) -> core.Array: + """Perform `lhs * rhs`, wrapping on overflow for integer data types""" + +def neg(array: types.ArrowArrayExportable) -> core.Array: + """Negates each element of array, returning an error on overflow""" + +def neg_wrapping(array: types.ArrowArrayExportable) -> core.Array: + """Negates each element of array, wrapping on overflow for integer data types""" + +def rem(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array: + """Perform `lhs % rhs`""" + +def sub(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array: + """Perform `lhs - rhs`, returning an error on overflow""" + +def sub_wrapping( + lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable +) -> core.Array: + """Perform `lhs - rhs`, wrapping on overflow for integer data types""" diff --git a/arro3-compute/python/arro3/compute/_compute.pyi b/arro3-compute/python/arro3/compute/_compute.pyi index 6b1d083..1721b3f 100644 --- a/arro3-compute/python/arro3/compute/_compute.pyi +++ b/arro3-compute/python/arro3/compute/_compute.pyi @@ -5,6 +5,16 @@ from typing import overload # will cause Array to be included in the generated docs in this module. 
import arro3.core as core import arro3.core.types as types +from arro3.compute._arith import add as add +from arro3.compute._arith import add_wrapping as add_wrapping +from arro3.compute._arith import div as div +from arro3.compute._arith import mul as mul +from arro3.compute._arith import mul_wrapping as mul_wrapping +from arro3.compute._arith import neg as neg +from arro3.compute._arith import neg_wrapping as neg_wrapping +from arro3.compute._arith import rem as rem +from arro3.compute._arith import sub as sub +from arro3.compute._arith import sub_wrapping as sub_wrapping @overload def cast( diff --git a/arro3-compute/src/arith.rs b/arro3-compute/src/arith.rs new file mode 100644 index 0000000..c0a4c01 --- /dev/null +++ b/arro3-compute/src/arith.rs @@ -0,0 +1,55 @@ +use arrow::compute::kernels::numeric; +use pyo3::prelude::*; +use pyo3_arrow::error::PyArrowResult; +use pyo3_arrow::input::AnyDatum; +use pyo3_arrow::PyArray; + +#[pyfunction] +pub fn add(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::add(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn add_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::add_wrapping(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn div(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::div(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn mul(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::mul(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn mul_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::mul_wrapping(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn neg(py: Python, array: PyArray) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::neg(array.as_ref())?).to_arro3(py)?) 
+} + +#[pyfunction] +pub fn neg_wrapping(py: Python, array: PyArray) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::neg_wrapping(array.as_ref())?).to_arro3(py)?) +} + +#[pyfunction] +pub fn rem(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::rem(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn sub(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::sub(&lhs, &rhs)?).to_arro3(py)?) +} + +#[pyfunction] +pub fn sub_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult { + Ok(PyArray::from_array_ref(numeric::sub_wrapping(&lhs, &rhs)?).to_arro3(py)?) +} diff --git a/arro3-compute/src/lib.rs b/arro3-compute/src/lib.rs index a3a61fd..0532411 100644 --- a/arro3-compute/src/lib.rs +++ b/arro3-compute/src/lib.rs @@ -1,6 +1,7 @@ use pyo3::prelude::*; mod aggregate; +mod arith; mod cast; mod concat; mod dictionary; @@ -20,6 +21,16 @@ fn _compute(_py: Python, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(aggregate::max))?; m.add_wrapped(wrap_pyfunction!(aggregate::min))?; m.add_wrapped(wrap_pyfunction!(aggregate::sum))?; + m.add_wrapped(wrap_pyfunction!(arith::add_wrapping))?; + m.add_wrapped(wrap_pyfunction!(arith::add))?; + m.add_wrapped(wrap_pyfunction!(arith::div))?; + m.add_wrapped(wrap_pyfunction!(arith::mul_wrapping))?; + m.add_wrapped(wrap_pyfunction!(arith::mul))?; + m.add_wrapped(wrap_pyfunction!(arith::neg_wrapping))?; + m.add_wrapped(wrap_pyfunction!(arith::neg))?; + m.add_wrapped(wrap_pyfunction!(arith::rem))?; + m.add_wrapped(wrap_pyfunction!(arith::sub_wrapping))?; + m.add_wrapped(wrap_pyfunction!(arith::sub))?; m.add_wrapped(wrap_pyfunction!(cast::cast))?; m.add_wrapped(wrap_pyfunction!(concat::concat))?; m.add_wrapped(wrap_pyfunction!(dictionary::dictionary_dictionary))?; diff --git a/arro3-core/python/arro3/core/_core.pyi b/arro3-core/python/arro3/core/_core.pyi index 06a2eb2..d3c61e8 100644 --- 
a/arro3-core/python/arro3/core/_core.pyi +++ b/arro3-core/python/arro3/core/_core.pyi @@ -1176,7 +1176,23 @@ class RecordBatchReader: class Scalar: """An arrow Scalar.""" + def __arrow_c_array__( + self, requested_schema: object | None = None + ) -> tuple[object, object]: + """ + An implementation of the [Arrow PyCapsule + Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html). + This dunder method should not be called directly, but enables zero-copy data + transfer to other Python libraries that understand Arrow memory. + + For example, you can call [`pyarrow.array()`][pyarrow.array] to + convert this Scalar into a pyarrow Array, without copying memory. The generated + array is guaranteed to have length 1. + """ def __repr__(self) -> str: ... + @classmethod + def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> Scalar: + """Construct this object from bare Arrow PyCapsules""" def as_py(self) -> Any: ... @property def is_valid(self) -> bool: ... 
diff --git a/pyo3-arrow/src/ffi/from_python/input.rs b/pyo3-arrow/src/ffi/from_python/input.rs index f9e2007..f9db8ff 100644 --- a/pyo3-arrow/src/ffi/from_python/input.rs +++ b/pyo3-arrow/src/ffi/from_python/input.rs @@ -1,6 +1,6 @@ use crate::array_reader::PyArrayReader; -use crate::input::{AnyArray, AnyRecordBatch}; -use crate::{PyArray, PyRecordBatch, PyRecordBatchReader}; +use crate::input::{AnyArray, AnyDatum, AnyRecordBatch}; +use crate::{PyArray, PyRecordBatch, PyRecordBatchReader, PyScalar}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::{PyAny, PyResult}; @@ -32,3 +32,15 @@ impl<'a> FromPyObject<'a> for AnyArray { } } } + +impl<'a> FromPyObject<'a> for AnyDatum { + fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult { + let array = PyArray::extract_bound(ob)?; + if array.as_ref().len() == 1 { + let (array, field) = array.into_inner(); + Ok(Self::Scalar(PyScalar::try_new(array, field)?)) + } else { + Ok(Self::Array(array)) + } + } +} diff --git a/pyo3-arrow/src/input.rs b/pyo3-arrow/src/input.rs index 047e565..1006dd7 100644 --- a/pyo3-arrow/src/input.rs +++ b/pyo3-arrow/src/input.rs @@ -7,7 +7,7 @@ use std::collections::HashMap; use std::string::FromUtf8Error; use std::sync::Arc; -use arrow_array::{RecordBatchIterator, RecordBatchReader}; +use arrow_array::{Datum, RecordBatchIterator, RecordBatchReader}; use arrow_schema::{ArrowError, Field, FieldRef, Fields, Schema, SchemaRef}; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -15,7 +15,9 @@ use pyo3::prelude::*; use crate::array_reader::PyArrayReader; use crate::error::PyArrowResult; use crate::ffi::{ArrayIterator, ArrayReader}; -use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyTable}; +use crate::{ + PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyScalar, PyTable, +}; /// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either /// Arrow object as input. 
@@ -98,6 +100,33 @@ impl AnyArray { } } +/// An enum over [PyArray] and [PyScalar], used for functions that accept either as input. +pub enum AnyDatum { + /// A single Array, held in a [PyArray]. + Array(PyArray), + /// An Arrow Scalar, held in a [PyScalar]. + Scalar(PyScalar), +} + +impl AnyDatum { + /// Access the field of this object. + pub fn field(&self) -> &FieldRef { + match self { + Self::Array(inner) => inner.field(), + Self::Scalar(inner) => inner.field(), + } + } +} + +impl Datum for AnyDatum { + fn get(&self) -> (&dyn arrow_array::Array, bool) { + match self { + Self::Array(inner) => inner.get(), + Self::Scalar(inner) => inner.get(), + } + } +} + #[derive(FromPyObject)] pub(crate) enum MetadataInput { String(HashMap), diff --git a/pyo3-arrow/src/scalar.rs b/pyo3-arrow/src/scalar.rs index c6ca56f..ea90019 100644 --- a/pyo3-arrow/src/scalar.rs +++ b/pyo3-arrow/src/scalar.rs @@ -67,7 +67,22 @@ impl PyScalar { Self::try_new(array, field) } - /// Export to an arro3.core.Array. + /// Access the underlying [ArrayRef]. + pub fn array(&self) -> &ArrayRef { + &self.array + } + + /// Access the underlying [FieldRef]. + pub fn field(&self) -> &FieldRef { + &self.field + } + + /// Consume self to access the underlying [ArrayRef] and [FieldRef]. + pub fn into_inner(self) -> (ArrayRef, FieldRef) { + (self.array, self.field) + } + + /// Export to an arro3.core.Scalar. + /// + /// This requires that you depend on arro3-core from your Python package. 
pub fn to_arro3(&self, py: Python) -> PyResult { diff --git a/tests/compute/test_arith.py b/tests/compute/test_arith.py new file mode 100644 index 0000000..df3d9d4 --- /dev/null +++ b/tests/compute/test_arith.py @@ -0,0 +1,18 @@ +import arro3.compute as ac +import pyarrow as pa +from arro3.core import Array, DataType + + +def test_add(): + arr1 = Array([1, 2, 3], DataType.int16()) + assert ac.min(arr1).as_py() == 1 + + arr2 = Array([3, 2, 0], DataType.int16()) + assert ac.min(arr2).as_py() == 0 + + add1 = ac.add(arr1, arr2) + assert pa.array(add1) == pa.array(Array([4, 4, 3], DataType.int16())) + + s = arr1[0] + add2 = ac.add(arr1, s) + assert pa.array(add2) == pa.array(Array([2, 3, 4], DataType.int16()))