Skip to content

Commit

Permalink
Arith functions (#194)
Browse files Browse the repository at this point in the history
* Arith functions

* add simple test
  • Loading branch information
kylebarron authored Sep 25, 2024
1 parent 6798e60 commit a81ac6e
Show file tree
Hide file tree
Showing 9 changed files with 212 additions and 5 deletions.
41 changes: 41 additions & 0 deletions arro3-compute/python/arro3/compute/_arith.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Note: importing with
# `from arro3.core import Array`
# will cause Array to be included in the generated docs in this module.
import arro3.core as core
import arro3.core.types as types

def add(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array:
"""Perform `lhs + rhs`, returning an error on overflow"""

def add_wrapping(
lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable
) -> core.Array:
"""Perform `lhs + rhs`, wrapping on overflow for DataType::is_integer"""

def div(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array:
"""Perform `lhs / rhs`"""

def mul(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array:
"""Perform `lhs * rhs`, returning an error on overflow"""

def mul_wrapping(
lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable
) -> core.Array:
"""Perform `lhs * rhs`, wrapping on overflow for DataType::is_integer"""

def neg(array: types.ArrowArrayExportable) -> core.Array:
"""Negates each element of array, returning an error on overflow"""

def neg_wrapping(array: types.ArrowArrayExportable) -> core.Array:
"""Negates each element of array, wrapping on overflow for DataType::is_integer"""

def rem(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array:
"""Perform `lhs % rhs`"""

def sub(lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable) -> core.Array:
"""Perform `lhs - rhs`, returning an error on overflow"""

def sub_wrapping(
lhs: types.ArrowArrayExportable, rhs: types.ArrowArrayExportable
) -> core.Array:
"""Perform `lhs - rhs`, wrapping on overflow for DataType::is_integer"""
10 changes: 10 additions & 0 deletions arro3-compute/python/arro3/compute/_compute.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,16 @@ from typing import overload
# will cause Array to be included in the generated docs in this module.
import arro3.core as core
import arro3.core.types as types
from arro3.compute._arith import add as add
from arro3.compute._arith import add_wrapping as add_wrapping
from arro3.compute._arith import div as div
from arro3.compute._arith import mul as mul
from arro3.compute._arith import mul_wrapping as mul_wrapping
from arro3.compute._arith import neg as neg
from arro3.compute._arith import neg_wrapping as neg_wrapping
from arro3.compute._arith import rem as rem
from arro3.compute._arith import sub as sub
from arro3.compute._arith import sub_wrapping as sub_wrapping

@overload
def cast(
Expand Down
55 changes: 55 additions & 0 deletions arro3-compute/src/arith.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use arrow::compute::kernels::numeric;
use pyo3::prelude::*;
use pyo3_arrow::error::PyArrowResult;
use pyo3_arrow::input::AnyDatum;
use pyo3_arrow::PyArray;

#[pyfunction]
pub fn add(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::add(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn add_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::add_wrapping(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn div(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::div(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn mul(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::mul(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn mul_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::mul_wrapping(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn neg(py: Python, array: PyArray) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::neg(array.as_ref())?).to_arro3(py)?)
}

#[pyfunction]
pub fn neg_wrapping(py: Python, array: PyArray) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::neg_wrapping(array.as_ref())?).to_arro3(py)?)
}

#[pyfunction]
pub fn rem(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::rem(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn sub(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::sub(&lhs, &rhs)?).to_arro3(py)?)
}

#[pyfunction]
pub fn sub_wrapping(py: Python, lhs: AnyDatum, rhs: AnyDatum) -> PyArrowResult<PyObject> {
Ok(PyArray::from_array_ref(numeric::sub_wrapping(&lhs, &rhs)?).to_arro3(py)?)
}
11 changes: 11 additions & 0 deletions arro3-compute/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use pyo3::prelude::*;

mod aggregate;
mod arith;
mod cast;
mod concat;
mod dictionary;
Expand All @@ -20,6 +21,16 @@ fn _compute(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(aggregate::max))?;
m.add_wrapped(wrap_pyfunction!(aggregate::min))?;
m.add_wrapped(wrap_pyfunction!(aggregate::sum))?;
m.add_wrapped(wrap_pyfunction!(arith::add_wrapping))?;
m.add_wrapped(wrap_pyfunction!(arith::add))?;
m.add_wrapped(wrap_pyfunction!(arith::div))?;
m.add_wrapped(wrap_pyfunction!(arith::mul_wrapping))?;
m.add_wrapped(wrap_pyfunction!(arith::mul))?;
m.add_wrapped(wrap_pyfunction!(arith::neg_wrapping))?;
m.add_wrapped(wrap_pyfunction!(arith::neg))?;
m.add_wrapped(wrap_pyfunction!(arith::rem))?;
m.add_wrapped(wrap_pyfunction!(arith::sub_wrapping))?;
m.add_wrapped(wrap_pyfunction!(arith::sub))?;
m.add_wrapped(wrap_pyfunction!(cast::cast))?;
m.add_wrapped(wrap_pyfunction!(concat::concat))?;
m.add_wrapped(wrap_pyfunction!(dictionary::dictionary_dictionary))?;
Expand Down
16 changes: 16 additions & 0 deletions arro3-core/python/arro3/core/_core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1176,7 +1176,23 @@ class RecordBatchReader:

class Scalar:
"""An arrow Scalar."""
def __arrow_c_array__(
self, requested_schema: object | None = None
) -> tuple[object, object]:
"""
An implementation of the [Arrow PyCapsule
Interface](https://arrow.apache.org/docs/format/CDataInterface/PyCapsuleInterface.html).
This dunder method should not be called directly, but enables zero-copy data
transfer to other Python libraries that understand Arrow memory.
For example, you can call [`pyarrow.array()`][pyarrow.array] to
convert this Scalar into a pyarrow Array, without copying memory. The generated
array is guaranteed to have length 1.
"""
def __repr__(self) -> str: ...
@classmethod
def from_arrow_pycapsule(cls, schema_capsule, array_capsule) -> Scalar:
"""Construct this object from bare Arrow PyCapsules"""
def as_py(self) -> Any: ...
@property
def is_valid(self) -> bool: ...
Expand Down
16 changes: 14 additions & 2 deletions pyo3-arrow/src/ffi/from_python/input.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::array_reader::PyArrayReader;
use crate::input::{AnyArray, AnyRecordBatch};
use crate::{PyArray, PyRecordBatch, PyRecordBatchReader};
use crate::input::{AnyArray, AnyDatum, AnyRecordBatch};
use crate::{PyArray, PyRecordBatch, PyRecordBatchReader, PyScalar};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::{PyAny, PyResult};
Expand Down Expand Up @@ -32,3 +32,15 @@ impl<'a> FromPyObject<'a> for AnyArray {
}
}
}

impl<'a> FromPyObject<'a> for AnyDatum {
fn extract_bound(ob: &Bound<'a, PyAny>) -> PyResult<Self> {
let array = PyArray::extract_bound(ob)?;
if array.as_ref().len() == 1 {
let (array, field) = array.into_inner();
Ok(Self::Scalar(PyScalar::try_new(array, field)?))
} else {
Ok(Self::Array(array))
}
}
}
33 changes: 31 additions & 2 deletions pyo3-arrow/src/input.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ use std::collections::HashMap;
use std::string::FromUtf8Error;
use std::sync::Arc;

use arrow_array::{RecordBatchIterator, RecordBatchReader};
use arrow_array::{Datum, RecordBatchIterator, RecordBatchReader};
use arrow_schema::{ArrowError, Field, FieldRef, Fields, Schema, SchemaRef};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;

use crate::array_reader::PyArrayReader;
use crate::error::PyArrowResult;
use crate::ffi::{ArrayIterator, ArrayReader};
use crate::{PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyTable};
use crate::{
PyArray, PyChunkedArray, PyField, PyRecordBatch, PyRecordBatchReader, PyScalar, PyTable,
};

/// An enum over [PyRecordBatch] and [PyRecordBatchReader], used when a function accepts either
/// Arrow object as input.
Expand Down Expand Up @@ -98,6 +100,33 @@ impl AnyArray {
}
}

/// An enum over [PyArray] and [PyScalar], used for functions that accept
pub enum AnyDatum {
/// A single Array, held in a [PyArray].
Array(PyArray),
/// An Arrow Scalar, held in a [pyScalar]
Scalar(PyScalar),
}

impl AnyDatum {
/// Access the field of this object.
pub fn field(&self) -> &FieldRef {
match self {
Self::Array(inner) => inner.field(),
Self::Scalar(inner) => inner.field(),
}
}
}

impl Datum for AnyDatum {
fn get(&self) -> (&dyn arrow_array::Array, bool) {
match self {
Self::Array(inner) => inner.get(),
Self::Scalar(inner) => inner.get(),
}
}
}

#[derive(FromPyObject)]
pub(crate) enum MetadataInput {
String(HashMap<String, String>),
Expand Down
17 changes: 16 additions & 1 deletion pyo3-arrow/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,22 @@ impl PyScalar {
Self::try_new(array, field)
}

/// Export to an arro3.core.Array.
/// Access the underlying [ArrayRef].
pub fn array(&self) -> &ArrayRef {
&self.array
}

/// Access the underlying [FieldRef].
pub fn field(&self) -> &FieldRef {
&self.field
}

/// Consume self to access the underlying [ArrayRef] and [FieldRef].
pub fn into_inner(self) -> (ArrayRef, FieldRef) {
(self.array, self.field)
}

/// Export to an arro3.core.Scalar.
///
/// This requires that you depend on arro3-core from your Python package.
pub fn to_arro3(&self, py: Python) -> PyResult<PyObject> {
Expand Down
18 changes: 18 additions & 0 deletions tests/compute/test_arith.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import arro3.compute as ac
import pyarrow as pa
from arro3.core import Array, DataType


def test_add():
arr1 = Array([1, 2, 3], DataType.int16())
assert ac.min(arr1).as_py() == 1

arr2 = Array([3, 2, 0], DataType.int16())
assert ac.min(arr2).as_py() == 0

add1 = ac.add(arr1, arr2)
assert pa.array(add1) == pa.array(Array([4, 4, 3], DataType.int16()))

s = arr1[0]
add2 = ac.add(arr1, s)
assert pa.array(add2) == pa.array(Array([2, 3, 4], DataType.int16()))

0 comments on commit a81ac6e

Please sign in to comment.