Skip to content

Commit

Permalink
fix: support revertable for concatenate in pyarrow logic (#2889)
Browse files Browse the repository at this point in the history
* wip: fix revertable for concatenate

* test: ensure behavior!

* test: fix importorskip arrow
  • Loading branch information
agoose77 authored Dec 11, 2023
1 parent bf7e37f commit ae5923e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 2 deletions.
34 changes: 32 additions & 2 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,8 +919,15 @@ def direct_Content_subclass_name(node):
return out.__name__


def is_revertable(akarray):
    """Report whether *akarray* carries a ``__pyarrow_original`` attribute.

    Returns:
        True if the attribute can be looked up on *akarray*, False otherwise.
    """
    # A private sentinel distinguishes "attribute missing" from an attribute
    # whose value happens to be None (or any other falsy object).
    missing = object()
    return getattr(akarray, "__pyarrow_original", missing) is not missing


def remove_optiontype(akarray):
    """Return the object stored in ``akarray.__pyarrow_original``.

    The stored value may be either the original object itself or a
    zero-argument callable that produces it on demand; a callable is
    invoked and its result returned.

    Raises:
        AttributeError: if *akarray* has no ``__pyarrow_original`` attribute.
    """
    # Read the attribute once, then decide whether it must be evaluated.
    # NOTE(review): the concatenation paths in handle_arrow pass a lambda to
    # `revertable`; presumably that lambda ends up stored here — confirm
    # against the definition of `revertable`.
    original = akarray.__pyarrow_original
    if callable(original):
        return original()
    return original


def form_remove_optiontype(akform):
Expand All @@ -944,6 +951,17 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):

if len(layouts) == 1:
return layouts[0]
elif any(is_revertable(arr) for arr in layouts):
assert all(is_revertable(arr) for arr in layouts)
# TODO: the callable argument to revertable is a premature(?) optimisation.
# it would be better to obviate the need to compute both revertable and non revertable branches
# e.g. by requesting a particular layout kind from the next `frombuffers` operation
return revertable(
ak.operations.concatenate(layouts, highlevel=False),
lambda: ak.operations.concatenate(
[remove_optiontype(x) for x in layouts], highlevel=False
),
)
else:
return ak.operations.concatenate(layouts, highlevel=False)

Expand Down Expand Up @@ -1044,7 +1062,19 @@ def handle_arrow(obj, generate_bitmasks=False, pass_empty_field=False):
for batch in batches
if len(batch) > 0
]
return ak.operations.concatenate(arrays, highlevel=False)
if any(is_revertable(arr) for arr in arrays):
assert all(is_revertable(arr) for arr in arrays)
# TODO: the callable argument to revertable is a premature(?) optimisation.
# it would be better to obviate the need to compute both revertable and non revertable branches
# e.g. by requesting a particular layout kind from the next `frombuffers` operation
return revertable(
ak.operations.concatenate(arrays, highlevel=False),
lambda: ak.operations.concatenate(
[remove_optiontype(x) for x in arrays], highlevel=False
),
)
else:
return ak.operations.concatenate(arrays, highlevel=False)

elif (
isinstance(obj, Iterable)
Expand Down
49 changes: 49 additions & 0 deletions tests/test_2889_test_chunked_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import pytest

import awkward as ak

pa = pytest.importorskip("pyarrow")


def test_strings():
    """A chunked string array without nulls converts to a plain string type."""
    chunked = pa.chunked_array([["foo", "bar"], ["blah", "bleh"]])
    converted = ak.from_arrow(chunked)

    # Build the expected type from the inside out: char -> string -> array.
    char_type = ak.types.NumpyType("uint8", parameters={"__array__": "char"})
    string_type = ak.types.ListType(char_type, parameters={"__array__": "string"})
    expected = ak.types.ArrayType(string_type, 4)

    assert converted.type == expected


def test_strings_option():
    """A chunked string array containing a null converts to an option-of-string type."""
    chunked = pa.chunked_array([["foo", "bar"], ["blah", "bleh", None]])
    converted = ak.from_arrow(chunked)

    # Build the expected type from the inside out: char -> string -> option -> array.
    char_type = ak.types.NumpyType("uint8", parameters={"__array__": "char"})
    string_type = ak.types.ListType(char_type, parameters={"__array__": "string"})
    expected = ak.types.ArrayType(ak.types.OptionType(string_type), 5)

    assert converted.type == expected


def test_numbers():
    """A chunked integer array without nulls converts to a plain int64 type."""
    chunked = pa.chunked_array([[1, 2, 3], [4, 5]])
    converted = ak.from_arrow(chunked)

    expected = ak.types.ArrayType(ak.types.NumpyType("int64"), 5)

    assert converted.type == expected


def test_numbers_option():
    """A chunked integer array containing a null converts to an option-of-int64 type."""
    chunked = pa.chunked_array([[1, 2, 3], [4, 5, None]])
    converted = ak.from_arrow(chunked)

    optional_int = ak.types.OptionType(ak.types.NumpyType("int64"))
    expected = ak.types.ArrayType(optional_int, 6)

    assert converted.type == expected

0 comments on commit ae5923e

Please sign in to comment.