Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add ak.approx_equal #2198

Merged
merged 15 commits into from
Feb 3, 2023
5 changes: 5 additions & 0 deletions docs/reference/toctree.txt
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,11 @@
generated/ak.isclose
generated/ak.ones_like
generated/ak.zeros_like

.. toctree::
:caption: Array comparison

generated/ak.almost_equal

.. toctree::
:caption: Third-party integration
Expand Down
103 changes: 0 additions & 103 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -715,109 +715,6 @@ def maybe_posaxis(layout, axis, depth):
return None


def arrays_approx_equal(
left,
right,
rtol: float = 1e-5,
atol: float = 1e-8,
dtype_exact: bool = True,
check_parameters=True,
check_regular=True,
) -> bool:
# TODO: this should not be needed after refactoring nplike mechanism
import awkward.forms.form

left_behavior = ak._util.behavior_of(left)
right_behavior = ak._util.behavior_of(right)

left = ak.to_packed(ak.to_layout(left, allow_record=False), highlevel=False)
right = ak.to_packed(ak.to_layout(right, allow_record=False), highlevel=False)

nplike = nplike_of(left, right)

def is_approx_dtype(left, right) -> bool:
if not dtype_exact:
for family in np.integer, np.floating:
if np.issubdtype(left, family):
return np.issubdtype(right, family)
return left == right

def visitor(left, right) -> bool:
# Enforce super-canonicalisation rules
if left.is_option:
left = left.to_IndexedOptionArray64()
if right.is_option:
right = right.to_IndexedOptionArray64()

if type(left) is not type(right):
if not check_regular and (
left.is_list and right.is_regular or left.is_regular and right.is_list
):
left = left.to_ListOffsetArray64()
right = right.to_ListOffsetArray64()
else:
return False

if left.length != right.length:
return False

if check_parameters and not awkward.forms.form._parameters_equal(
left.parameters, right.parameters
):
return False

# Require that the arrays have the same evaluated types
if not (
arrayclass(left, left_behavior) is arrayclass(right, right_behavior)
or not check_parameters
):
return False

if left.is_list:
return (
nplike.array_equal(left.starts, right.starts)
and nplike.array_equal(left.stops, right.stops)
and visitor(left.content, right.content)
)
elif left.is_regular:
return (left.size == right.size) and visitor(left.content, right.content)
elif left.is_numpy:
return is_approx_dtype(left.dtype, right.dtype) and nplike.all(
nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
)
)
elif left.is_option:
return nplike.array_equal(
left.index.data < 0, right.index.data < 0
) and visitor(left.project(), right.project())
elif left.is_union:
return (len(left.contents) == len(right.contents)) and all(
[
visitor(left.project(i).to_packed(), right.project(i).to_packed())
for i, _ in enumerate(left.contents)
]
)
elif left.is_record:
return (
(
recordclass(left, left_behavior)
is recordclass(right, right_behavior)
or not check_parameters
)
and (left.fields == right.fields)
and (left.is_tuple == right.is_tuple)
and all([visitor(x, y) for x, y in zip(left.contents, right.contents)])
)
elif left.is_unknown:
return True

else:
raise ak._errors.wrap_error(AssertionError)

return visitor(left, right)


try:
import numpy # noqa: TID251

Expand Down
1 change: 1 addition & 0 deletions src/awkward/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE

from awkward.operations.ak_all import all
from awkward.operations.ak_almost_equal import almost_equal
from awkward.operations.ak_any import any
from awkward.operations.ak_argcartesian import argcartesian
from awkward.operations.ak_argcombinations import argcombinations
Expand Down
134 changes: 134 additions & 0 deletions src/awkward/operations/ak_almost_equal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
from __future__ import annotations

__all__ = ["almost_equal"]

from awkward._backends import backend_of
from awkward._errors import wrap_error
from awkward._nplikes.numpylike import NumpyMetadata
from awkward._util import arrayclass, behavior_of, recordclass
from awkward.forms.form import _parameters_equal
from awkward.operations.ak_to_layout import to_layout

np = NumpyMetadata.instance()


def almost_equal(
left,
right,
*,
rtol: float = 1e-5,
atol: float = 1e-8,
dtype_exact: bool = True,
check_parameters: bool = True,
check_regular: bool = True,
) -> bool:
"""
Args:
left: Array-like data (anything #ak.to_layout recognizes).
right: Array-like data (anything #ak.to_layout recognizes).
rtol: the relative tolerance parameter (see below).
atol: the absolute tolerance parameter (see below).
dtype_exact: whether the dtypes must be exactly the same, or just the
same family.
check_parameters: whether to compare parameters.
check_regular: whether to consider ragged and regular dimensions as
unequal.

Return True if the two array-like arguments are considered equal for the
given options. Otherwise, return False.

The relative difference (`rtol * abs(b)`) and the absolute difference `atol`
are added together to compare against the absolute difference between `left`
and `right`.
"""
left_behavior = behavior_of(left)
right_behavior = behavior_of(right)

left = to_layout(left, allow_record=False).to_packed()
right = to_layout(right, allow_record=False).to_packed()

backend = backend_of(left, right)

def is_approx_dtype(left, right) -> bool:
if not dtype_exact:
for family in np.integer, np.floating:
if np.issubdtype(left, family):
return np.issubdtype(right, family)
return left == right

def visitor(left, right) -> bool:
# Enforce super-canonicalisation rules
if left.is_option:
left = left.to_IndexedOptionArray64()
if right.is_option:
right = right.to_IndexedOptionArray64()

if type(left) is not type(right):
if not check_regular and (
left.is_list and right.is_regular or left.is_regular and right.is_list
):
left = left.to_ListOffsetArray64()
right = right.to_ListOffsetArray64()
else:
return False

if left.length != right.length:
return False

if check_parameters and not _parameters_equal(
left.parameters, right.parameters
):
return False

# Require that the arrays have the same evaluated types
if not (
arrayclass(left, left_behavior) is arrayclass(right, right_behavior)
or not check_parameters
):
return False

if left.is_list:
return (
backend.index_nplike.array_equal(left.starts, right.starts)
and backend.index_nplike.array_equal(left.stops, right.stops)
and visitor(
left.content[: left.stops[-1]], right.content[: right.stops[-1]]
)
)
elif left.is_regular:
return (left.size == right.size) and visitor(left.content, right.content)
elif left.is_numpy:
return is_approx_dtype(left.dtype, right.dtype) and backend.nplike.all(
backend.nplike.isclose(
left.data, right.data, rtol=rtol, atol=atol, equal_nan=False
)
)
elif left.is_option:
return backend.index_nplike.array_equal(
left.index.data < 0, right.index.data < 0
) and visitor(left.project(), right.project())
elif left.is_union:
return (len(left.contents) == len(right.contents)) and all(
[
visitor(left.project(i).to_packed(), right.project(i).to_packed())
for i, _ in enumerate(left.contents)
]
)
elif left.is_record:
return (
(
recordclass(left, left_behavior)
is recordclass(right, right_behavior)
or not check_parameters
)
and (left.fields == right.fields)
and (left.is_tuple == right.is_tuple)
and all([visitor(x, y) for x, y in zip(left.contents, right.contents)])
)
elif left.is_unknown:
return True

else:
raise wrap_error(AssertionError)

return visitor(left, right)
8 changes: 4 additions & 4 deletions tests/test_0355_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def weighted_add(self, other):
with_name="WeightedPoint",
behavior=behavior,
)
assert ak._util.arrays_approx_equal(
assert ak.almost_equal(
one + wone,
ak.Array(
[
Expand All @@ -90,7 +90,7 @@ def weighted_add(self, other):
),
dtype_exact=False,
)
assert ak._util.arrays_approx_equal(
assert ak.almost_equal(
wone + wtwo,
ak.Array(
[
Expand Down Expand Up @@ -122,7 +122,7 @@ def weighted_add(self, other):
),
dtype_exact=False,
)
assert ak._util.arrays_approx_equal(
assert ak.almost_equal(
abs(one),
ak.Array(
[
Expand All @@ -135,7 +135,7 @@ def weighted_add(self, other):
),
dtype_exact=False,
)
assert ak._util.arrays_approx_equal(
assert ak.almost_equal(
one.distance(wtwo),
[
[0.14142135623730953, 0.0, 0.31622776601683783],
Expand Down
10 changes: 5 additions & 5 deletions tests/test_1318_array_function_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_tuple():
(ak.from_numpy(data[0]), ak.from_numpy(data[1])), (2, 2)
)
assert isinstance(result, ak.Array)
assert ak._util.arrays_approx_equal(result, np.ravel_multi_index(data, (2, 2)))
assert ak.almost_equal(result, np.ravel_multi_index(data, (2, 2)))


def test_list():
Expand All @@ -23,7 +23,7 @@ def test_list():
]
)
assert isinstance(result, ak.Array)
assert ak._util.arrays_approx_equal(
assert ak.almost_equal(
result, np.block([[A, np.zeros((2, 3))], [np.ones((3, 2)), B]])
)

Expand All @@ -33,14 +33,14 @@ def test_array():
needle = np.array([5, 0, 2], dtype=np.int64)
result = np.searchsorted(ak.from_numpy(haystack), ak.from_numpy(needle))
assert isinstance(result, ak.Array)
assert ak._util.arrays_approx_equal(result, np.searchsorted(haystack, needle))
assert ak.almost_equal(result, np.searchsorted(haystack, needle))


def test_scalar():
data = np.array([1, 2, 3, 4, 3, 2, 1, 2], dtype=np.int64)
result = np.partition(ak.from_numpy(data), 4)
assert isinstance(result, ak.Array)
assert ak._util.arrays_approx_equal(result, np.partition(data, 4))
assert ak.almost_equal(result, np.partition(data, 4))


def test_tuple_of_array():
Expand All @@ -50,4 +50,4 @@ def test_tuple_of_array():
)
result = np.lexsort((ak.from_numpy(data[0]), ak.from_numpy(data[1])))
assert isinstance(result, ak.Array)
assert ak._util.arrays_approx_equal(result, np.lexsort(data))
assert ak.almost_equal(result, np.lexsort(data))
Loading