Skip to content

Commit

Permalink
GH-45380: [Python] Expose RankQuantileOptions to Python
Browse files Browse the repository at this point in the history
  • Loading branch information
raulcd committed Jan 30, 2025
1 parent aaa88e9 commit ff3a84d
Show file tree
Hide file tree
Showing 5 changed files with 78 additions and 0 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/api_vector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,7 @@ void RegisterVectorOptions(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunctionOptionsType(kSelectKOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kCumulativeOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kRankQuantileOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kPairwiseOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kListFlattenOptionsType));
DCHECK_OK(registry->AddFunctionOptionsType(kInversePermutationOptionsType));
Expand Down
41 changes: 41 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -2370,6 +2370,47 @@ class RankOptions(_RankOptions):
self._set_options(sort_keys, null_placement, tiebreaker)


cdef class _RankQuantileOptions(FunctionOptions):

def _set_options(self, sort_keys, null_placement):
cdef vector[CSortKey] c_sort_keys
if isinstance(sort_keys, str):
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys))
)
else:
for name, order in sort_keys:
c_sort_keys.push_back(
CSortKey(_ensure_field_ref(name), unwrap_sort_order(order))
)
self.wrapped.reset(
new CRankQuantileOptions(c_sort_keys,
unwrap_null_placement(null_placement))
)


class RankQuantileOptions(_RankQuantileOptions):
"""
Options for the `rank_quantile` function.
Parameters
----------
sort_keys : sequence of (name, order) tuples or str, default "ascending"
Names of field/column keys to sort the input on,
along with the order each field/column is sorted in.
Accepted values for `order` are "ascending", "descending".
The field name can be a string column name or expression.
Alternatively, one can simply pass "ascending" or "descending" as a string
if the input is array-like.
null_placement : str, default "at_end"
Where nulls in input should be sorted.
Accepted values are "at_start", "at_end".
"""

def __init__(self, sort_keys="ascending", *, null_placement="at_end"):
self._set_options(sort_keys, null_placement)


cdef class Expression(_Weakrefable):
"""
A logical expression to be evaluated against some input.
Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
QuantileOptions,
RandomOptions,
RankOptions,
RankQuantileOptions,
ReplaceSliceOptions,
ReplaceSubstringOptions,
RoundBinaryOptions,
Expand Down
6 changes: 6 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2788,6 +2788,12 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CNullPlacement null_placement
CRankOptionsTiebreaker tiebreaker

cdef cppclass CRankQuantileOptions \
"arrow::compute::RankQuantileOptions"(CFunctionOptions):
CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement)
vector[CSortKey] sort_keys
CNullPlacement null_placement

cdef enum DatumType" arrow::Datum::type":
DatumType_NONE" arrow::Datum::NONE"
DatumType_SCALAR" arrow::Datum::SCALAR"
Expand Down
29 changes: 29 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def test_option_class_equality(request):
pc.RandomOptions(),
pc.RankOptions(sort_keys="ascending",
null_placement="at_start", tiebreaker="max"),
pc.RankQuantileOptions(sort_keys="ascending",
null_placement="at_start"),
pc.ReplaceSliceOptions(0, 1, "a"),
pc.ReplaceSubstringOptions("a", "b"),
pc.RoundOptions(2, "towards_infinity"),
Expand Down Expand Up @@ -3360,6 +3362,33 @@ def test_rank_options():
tiebreaker="NonExisting")


def test_rank_quantile_options():
arr = pa.array([None, 1, None, 2, None])
expected = pa.array([0.7, 0.1, 0.7, 0.3, 0.7], type=pa.float64())

# Ensure rank_quantile can be called without specifying options
result = pc.rank_quantile(arr)
assert result.equals(expected)

# Ensure default RankOptions
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions())
assert result.equals(expected)

# Ensure sort_keys tuple usage
result = pc.rank_quantile(arr, options=pc.RankQuantileOptions(
sort_keys=[("b", "ascending")])
)
assert result.equals(expected)

result = pc.rank_quantile(arr, null_placement="at_start")
expected_at_start = pa.array([0.3, 0.7, 0.3, 0.9, 0.3], type=pa.float64())
assert result.equals(expected_at_start)

result = pc.rank_quantile(arr, sort_keys="descending")
expected_descending = pa.array([0.7, 0.3, 0.7, 0.1, 0.7], type=pa.float64())
assert result.equals(expected_descending)


def create_sample_expressions():
# We need a schema for substrait conversion
schema = pa.schema([pa.field("i64", pa.int64()), pa.field(
Expand Down

0 comments on commit ff3a84d

Please sign in to comment.