Skip to content

Commit

Permalink
refactor: move _parameters_XXX functions into _parameters submodu…
Browse files Browse the repository at this point in the history
…le (#2325)
  • Loading branch information
agoose77 authored Mar 20, 2023
1 parent 381f3db commit 6adf9a4
Show file tree
Hide file tree
Showing 36 changed files with 346 additions and 343 deletions.
11 changes: 8 additions & 3 deletions src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._parameters import (
parameters_are_empty,
parameters_are_equal,
parameters_intersect,
)
from awkward._typing import Any, Callable, Dict, List, TypeAlias, Union
from awkward._util import unset
from awkward.contents.bitmaskedarray import BitMaskedArray
Expand Down Expand Up @@ -240,7 +245,7 @@ def all_or_nothing_parameters_factory(
first_parameters = input_parameters[0]
# Ensure all parameters match, or set parameters to None
for other_parameters in input_parameters[1:]:
if not ak.forms.form._parameters_equal(first_parameters, other_parameters):
if not parameters_are_equal(first_parameters, other_parameters):
break
else:
parameters = first_parameters
Expand Down Expand Up @@ -278,14 +283,14 @@ def intersection_parameters_factory(
# If we encounter None-parameters, then we stop early
# as there can be no intersection.
for parameters in input_parameters:
if ak.forms.form._parameters_is_empty(parameters):
if parameters_are_empty(parameters):
break
else:
parameters_to_intersect.append(parameters)
# Otherwise, build the intersected parameter dict
else:
intersected_parameters = functools.reduce(
ak.forms.form._parameters_intersect, parameters_to_intersect
parameters_intersect, parameters_to_intersect
)

def apply(n_outputs: int) -> list[dict[str, Any] | None]:
Expand Down
6 changes: 3 additions & 3 deletions src/awkward/_connect/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import awkward as ak
from awkward._nplikes.numpy import Numpy
from awkward._nplikes.numpylike import NumpyMetadata
from awkward.forms.form import _parameters_union
from awkward._parameters import parameters_union

np = NumpyMetadata.instance()
numpy = Numpy.instance()
Expand Down Expand Up @@ -407,7 +407,7 @@ def popbuffers(paarray, awkwardarrow_type, storage_type, buffers, generate_bitma

content = handle_arrow(paarray.dictionary, generate_bitmasks)

parameters = ak.forms.form._parameters_union(
parameters = parameters_union(
mask_parameters(awkwardarrow_type), node_parameters(awkwardarrow_type)
)
if parameters is None:
Expand Down Expand Up @@ -702,7 +702,7 @@ def form_popbuffers(awkwardarrow_type, storage_type):
a, b = to_awkwardarrow_storage_types(storage_type.value_type)
content = form_popbuffers(a, b)

parameters = _parameters_union(
parameters = parameters_union(
mask_parameters(awkwardarrow_type), node_parameters(awkwardarrow_type)
)
if parameters is None:
Expand Down
185 changes: 185 additions & 0 deletions src/awkward/_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
from __future__ import annotations

from collections.abc import Collection

from awkward._typing import JSONMapping, JSONSerializable


def type_parameters_equal(
one: JSONMapping | None, two: JSONMapping | None, *, allow_missing: bool = False
) -> bool:
if one is None and two is None:
return True

elif one is None:
# NB: __categorical__ is currently a type-only parameter, but
# we check it here as types check this too.
for key in ("__array__", "__record__", "__categorical__"):
if two.get(key) is not None:
return allow_missing
return True

elif two is None:
for key in ("__array__", "__record__", "__categorical__"):
if one.get(key) is not None:
return allow_missing
return True

else:
for key in ("__array__", "__record__", "__categorical__"):
if one.get(key) != two.get(key):
return False
return True


def parameters_are_equal(
one: JSONMapping, two: JSONMapping, only_array_record=False
) -> bool:
if one is None and two is None:
return True
elif one is None:
if only_array_record:
# NB: __categorical__ is currently a type-only parameter, but
# we check it here as types check this too.
for key in ("__array__", "__record__", "__categorical__"):
if two.get(key) is not None:
return False
return True
else:
for value in two.values():
if value is not None:
return False
return True

elif two is None:
if only_array_record:
for key in ("__array__", "__record__", "__categorical__"):
if one.get(key) is not None:
return False
return True
else:
for value in one.values():
if value is not None:
return False
return True

else:
if only_array_record:
keys = ("__array__", "__record__", "__categorical__")
else:
keys = set(one.keys()).union(two.keys())
for key in keys:
if one.get(key) != two.get(key):
return False
return True


def parameters_intersect(
left: JSONMapping | None,
right: JSONMapping | None,
*,
exclude: Collection[tuple[str, JSONSerializable]] = (),
) -> JSONMapping | None:
"""
Args:
left: first parameters mapping
right: second parameters mapping
exclude: collection of (key, value) items to exclude
Returns the intersected key-value pairs of `left` and `right` as a dictionary.
"""
if left is None or right is None:
return None

common_keys = iter(left.keys() & right.keys())
has_no_exclusions = len(exclude) == 0

# Avoid creating `result` unless we have to
for key in common_keys:
left_value = left[key]
# Do our keys match?
if (
left_value is not None
and left_value == right[key]
and (has_no_exclusions or (key, left_value) not in exclude)
):
# Exit, indicating that we want to create `result`
break
else:
return None

# We found a meaningful key, so create a result dict
result = {key: left_value}
for key in common_keys:
left_value = left[key]
if (
left_value is not None
and left_value == right[key]
and (has_no_exclusions or (key, left_value) not in exclude)
):
result[key] = left_value

return result


def parameters_union(
left: JSONMapping | None,
right: JSONMapping | None,
*,
exclude: Collection[tuple[str, JSONSerializable]] = (),
) -> JSONMapping | None:
"""
Args:
left: first parameters mapping
right: second parameters mapping
exclude: collection of (key, value) items to exclude
Returns the merged key-value pairs of `left` and `right` as a dictionary.
"""
has_no_exclusions = len(exclude) == 0
if left is None:
if right is None:
return None
else:
return {
k: v
for k, v in right.items()
if v is not None and (has_no_exclusions or (k, v) not in exclude)
}
else:
result = {
k: v
for k, v in left.items()
if v is not None and (has_no_exclusions or (k, v) not in exclude)
}
if right is None:
return result
else:
for key in right:
right_value = right[key]
if right_value is not None and (
has_no_exclusions or (key, right_value) not in exclude
):
result[key] = right_value

return result


def parameters_are_empty(parameters: JSONMapping | None) -> bool:
"""
Args:
parameters (dict or None): parameters dictionary, or None
Return True if the parameters dictionary is considered empty, either because it is
None, or because it does not have any meaningful (non-None) values; otherwise,
return False.
"""
if parameters is None:
return True

for item in parameters.values():
if item is not None:
return False

return True
6 changes: 6 additions & 0 deletions src/awkward/_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,9 @@
final,
runtime_checkable,
)


JSONSerializable: TypeAlias = (
"str | int | float | bool | None | list | tuple | JSONMapping"
)
JSONMapping: TypeAlias = "dict[str, JSONSerializable]"
6 changes: 4 additions & 2 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,16 @@
from awkward._nplikes.numpylike import IndexType, NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.typetracer import MaybeNone, TypeTracer
from awkward._parameters import (
type_parameters_equal,
)
from awkward._regularize import is_integer, is_integer_like
from awkward._slicing import NO_HEAD
from awkward._typing import TYPE_CHECKING, Final, Self, SupportsIndex, final
from awkward._util import unset
from awkward.contents.bytemaskedarray import ByteMaskedArray
from awkward.contents.content import Content
from awkward.forms.bitmaskedform import BitMaskedForm
from awkward.forms.form import _type_parameters_equal
from awkward.index import Index

if TYPE_CHECKING:
Expand Down Expand Up @@ -570,7 +572,7 @@ def _mergeable_next(self, other, mergebool):
elif other.is_option or other.is_indexed:
return self._content._mergeable_next(
other.content, mergebool
) and _type_parameters_equal(self._parameters, other._parameters)
) and type_parameters_equal(self._parameters, other._parameters)
else:
return self._content._mergeable_next(other, mergebool)

Expand Down
11 changes: 6 additions & 5 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@
from awkward._nplikes.numpylike import IndexType, NumpyMetadata
from awkward._nplikes.shape import unknown_length
from awkward._nplikes.typetracer import MaybeNone, TypeTracer
from awkward._parameters import (
parameters_intersect,
type_parameters_equal,
)
from awkward._regularize import is_integer_like
from awkward._slicing import NO_HEAD
from awkward._typing import TYPE_CHECKING, Final, Self, SupportsIndex, final
from awkward._util import unset
from awkward.contents.content import Content
from awkward.forms.bytemaskedform import ByteMaskedForm
from awkward.forms.form import _type_parameters_equal
from awkward.index import Index

if TYPE_CHECKING:
Expand Down Expand Up @@ -707,7 +710,7 @@ def _mergeable_next(self, other, mergebool):
elif other.is_option or other.is_indexed:
return self._content._mergeable_next(
other.content, mergebool
) and _type_parameters_equal(self._parameters, other._parameters)
) and type_parameters_equal(self._parameters, other._parameters)
else:
return self._content._mergeable_next(other, mergebool)

Expand All @@ -731,9 +734,7 @@ def _mergemany(self, others):
length = 0
for x in others:
length_scalar = self._backend.index_nplike.shape_item_as_index(x.length)
parameters = ak.forms.form._parameters_intersect(
parameters, x._parameters
)
parameters = parameters_intersect(parameters, x._parameters)
masks.append(x._mask.data[:length_scalar])
tail_contents.append(x._content[:length_scalar])
length += x.length
Expand Down
8 changes: 6 additions & 2 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,24 @@
from awkward._nplikes.numpylike import IndexType, NumpyMetadata
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._nplikes.typetracer import TypeTracer
from awkward._parameters import (
type_parameters_equal,
)
from awkward._regularize import is_integer_like, is_sized_iterable
from awkward._slicing import normalize_slice
from awkward._typing import (
TYPE_CHECKING,
Any,
AxisMaybeNone,
JSONMapping,
Literal,
Self,
SupportsIndex,
TypeAlias,
TypedDict,
)
from awkward._util import unset
from awkward.forms.form import Form, JSONMapping, _type_parameters_equal
from awkward.forms.form import Form
from awkward.index import Index, Index64

if TYPE_CHECKING:
Expand Down Expand Up @@ -1333,7 +1337,7 @@ def is_equal_to(
return (
self.__class__ is other.__class__
and len(self) == len(other)
and _type_parameters_equal(self.parameters, other.parameters)
and type_parameters_equal(self.parameters, other.parameters)
and self._is_equal_to(other, index_dtype, numpyarray)
)

Expand Down
Loading

0 comments on commit 6adf9a4

Please sign in to comment.