Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: preserve dimensions for keepdims=True, axis=None reductions #2177

Merged
merged 6 commits into from
Jan 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/awkward/_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,31 +217,34 @@ def pad_none(
return layout._pad_none(length, axis, 1, clip)


def completely_flatten(
def remove_structure(
layout: Content | Record,
backend: Backend | None = None,
flatten_records: bool = True,
function_name: str | None = None,
drop_nones: bool = True,
keepdims: bool = False,
):
if isinstance(layout, Record):
return completely_flatten(
return remove_structure(
layout._array[layout._at : layout._at + 1],
backend,
flatten_records,
function_name,
drop_nones,
keepdims,
)

else:
if backend is None:
backend = layout._backend
arrays = layout._completely_flatten(
arrays = layout._remove_structure(
backend,
{
"flatten_records": flatten_records,
"function_name": function_name,
"drop_nones": drop_nones,
"keepdims": keepdims,
},
)
return tuple(arrays)
Expand Down Expand Up @@ -314,15 +317,16 @@ def reduce(
behavior: dict | None = None,
):
if axis is None:
parts = completely_flatten(layout, flatten_records=False, drop_nones=False)
parts = remove_structure(
layout, flatten_records=False, drop_nones=False, keepdims=keepdims
)

if len(parts) > 1:
# We know that `flatten_records` must fail, so the only other type
# that can return multiple parts here is the union array
raise ak._errors.wrap_error(
ValueError(
"cannot use axis=None with keepdims=True on an array containing "
"irreducible unions"
"cannot use axis=None on an array containing irreducible unions"
)
)
elif len(parts) == 0:
Expand Down
9 changes: 8 additions & 1 deletion src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ def arrays_approx_equal(
atol: float = 1e-8,
dtype_exact: bool = True,
check_parameters=True,
check_regular=True,
) -> bool:
# TODO: this should not be needed after refactoring nplike mechanism
import awkward.forms.form
Expand Down Expand Up @@ -798,7 +799,13 @@ def visitor(left, right) -> bool:
right = right.to_IndexedOptionArray64()

if type(left) is not type(right):
return False
if not check_regular and (
left.is_list and right.is_regular or left.is_regular and right.is_list
):
left = left.to_ListOffsetArray64()
right = right.to_ListOffsetArray64()
else:
return False
Comment on lines +802 to +808
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `check_regular` branch isn't used anywhere in this PR: I wrote it, and then removed its only usage. I think it's still worth keeping for future tests.


if left.length != right.length:
return False
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,10 +585,10 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_ByteMaskedArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,10 +972,10 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_IndexedOptionArray64()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,7 +1095,7 @@ def drop_none(self):
def _drop_none(self) -> Content:
raise ak._errors.wrap_error(NotImplementedError)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
raise ak._errors.wrap_error(NotImplementedError)

def _recursively_apply(
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return backend.nplike.empty(0, dtype=np.float64)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
return []

def _recursively_apply(
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -958,8 +958,8 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.project()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
return self.project()._completely_flatten(backend, options)
def _remove_structure(self, backend, options):
return self.project()._remove_structure(backend, options)

def _recursively_apply(
self, action, behavior, depth, depth_context, lateral_context, options
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/contents/indexedoptionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,7 @@ def _reduce_next(
"reduce_next with unbranching depth > negaxis is only "
"expected to return RegularArray or ListOffsetArray or "
"IndexedOptionArray; "
"instead, it returned " + out
"instead, it returned {}".format(type(out).__name__)
)
)

Expand Down Expand Up @@ -1526,10 +1526,10 @@ def _to_backend_array(self, allow_missing, backend):
else:
return content

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
12 changes: 2 additions & 10 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1379,16 +1379,8 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return self.to_RegularArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
):
return [self]
else:
next = self.to_ListOffsetArray64(False)
flat = next.content[next.offsets[0] : next.offsets[-1]]
return flat._completely_flatten(backend, options)
def _remove_structure(self, backend, options):
return self.to_ListOffsetArray64(False)._remove_structure(backend, options)

def _drop_none(self):
return self.to_ListOffsetArray64()._drop_none()
Expand Down
21 changes: 18 additions & 3 deletions src/awkward/contents/listoffsetarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1962,15 +1962,30 @@ def _to_backend_array(self, allow_missing, backend):

return self.to_RegularArray()._to_backend_array(allow_missing, backend)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
):
return [self]
else:
flat = self._content[self._offsets[0] : self._offsets[-1]]
return flat._completely_flatten(backend, options)
content = self._content[self._offsets[0] : self._offsets[-1]]
contents = content._remove_structure(backend, options)
if options["keepdims"]:
return [
ListOffsetArray(
ak.index.Index64(
backend.index_nplike.asarray(
[0, backend.index_nplike.shape_item_as_scalar(c.length)]
)
),
c,
parameters=self._parameters,
)
for c in contents
]
else:
return contents

def _drop_none(self):
if self._content.is_option:
Expand Down
8 changes: 6 additions & 2 deletions src/awkward/contents/numpyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1204,10 +1204,14 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
def _to_backend_array(self, allow_missing, backend):
return to_nplike(self.data, backend.nplike, from_nplike=self._backend.nplike)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if options["keepdims"]:
shape = (1,) * (self._data.ndim - 1) + (-1,)
else:
shape = (-1,)
return [
ak.contents.NumpyArray(
backend.nplike.reshape(self._raw(backend.nplike), (-1,)),
backend.nplike.reshape(self._raw(backend.nplike), shape),
backend=backend,
)
]
Expand Down
6 changes: 2 additions & 4 deletions src/awkward/contents/recordarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,13 +922,11 @@ def _to_backend_array(self, allow_missing, backend):

return out

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if options["flatten_records"]:
out = []
for content in self._contents:
out.extend(
content[: self._length]._completely_flatten(backend, options)
)
out.extend(content[: self._length]._remove_structure(backend, options))
return out
else:
in_function = ""
Expand Down
13 changes: 10 additions & 3 deletions src/awkward/contents/regulararray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1218,7 +1218,7 @@ def _to_arrow(self, pyarrow, mask_node, validbytes, length, options):
),
)

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
if (
self.parameter("__array__") == "string"
or self.parameter("__array__") == "bytestring"
Expand All @@ -1227,8 +1227,15 @@ def _completely_flatten(self, backend, options):
else:
index_nplike = self._backend.index_nplike
length = index_nplike.mul_shape_item(self._length, self._size)
flat = self._content[: index_nplike.shape_item_as_scalar(length)]
return flat._completely_flatten(backend, options)
content = self._content[: index_nplike.shape_item_as_scalar(length)]
contents = content._remove_structure(backend, options)
if options["keepdims"]:
return [
RegularArray(c, size=c.length, parameters=self._parameters)
for c in contents
]
else:
return contents

def _drop_none(self):
return self.to_ListOffsetArray64()._drop_none()
Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1469,14 +1469,14 @@ def _to_backend_array(self, allow_missing, backend):

return out

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
out = []
for i in range(len(self._contents)):
index = self._index[self._tags.data == i]
out.extend(
self._contents[i]
._carry(index, False)
._completely_flatten(backend, options)
._remove_structure(backend, options)
)
return out

Expand Down
4 changes: 2 additions & 2 deletions src/awkward/contents/unmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,10 +424,10 @@ def _to_backend_array(self, allow_missing, backend):
else:
return content

def _completely_flatten(self, backend, options):
def _remove_structure(self, backend, options):
branch, depth = self.branch_depth
if branch or options["drop_nones"] or depth > 1:
return self.project()._completely_flatten(backend, options)
return self.project()._remove_structure(backend, options)
else:
return [self]

Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_flatten.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _impl(array, axis, highlevel, behavior):
layout = ak.operations.to_layout(array, allow_record=False, allow_other=False)

if axis is None:
out = ak._do.completely_flatten(layout, function_name="ak.flatten")
out = ak._do.remove_structure(layout, function_name="ak.flatten")
assert isinstance(out, tuple) and all(
isinstance(x, ak.contents.NumpyArray) for x in out
)
Expand Down
2 changes: 1 addition & 1 deletion src/awkward/operations/ak_ravel.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def ravel(array, *, highlevel=True, behavior=None):
def _impl(array, highlevel, behavior):
layout = ak.operations.to_layout(array, allow_record=False, allow_other=False)

out = ak._do.completely_flatten(layout, function_name="ak.ravel", drop_nones=False)
out = ak._do.remove_structure(layout, function_name="ak.ravel", drop_nones=False)
assert isinstance(out, tuple) and all(
isinstance(x, ak.contents.Content) for x in out
)
Expand Down
Loading