Skip to content

Commit

Permalink
approx: use exact comparison for bool
Browse files Browse the repository at this point in the history
Fixes #9353
  • Loading branch information
jvansanten authored Nov 29, 2024
1 parent b938e70 commit a16e8ea
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 19 deletions.
1 change: 1 addition & 0 deletions changelog/9353.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
:func:`pytest.approx` now uses strict equality when given booleans.
44 changes: 26 additions & 18 deletions src/_pytest/python_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,19 +259,22 @@ def _repr_compare(self, other_side: Mapping[object, float]) -> list[str]:
):
if approx_value != other_value:
if approx_value.expected is not None and other_value is not None:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
if approx_value.expected == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
try:
max_abs_diff = max(
max_abs_diff, abs(approx_value.expected - other_value)
)
if approx_value.expected == 0.0:
max_rel_diff = math.inf
else:
max_rel_diff = max(
max_rel_diff,
abs(
(approx_value.expected - other_value)
/ approx_value.expected
),
)
except ZeroDivisionError:
pass
different_ids.append(approx_key)

message_data = [
Expand Down Expand Up @@ -395,8 +398,10 @@ def __repr__(self) -> str:
# Don't show a tolerance for values that aren't compared using
# tolerances, i.e. non-numerics and infinities. Need to call abs to
# handle complex numbers, e.g. (inf + 1j).
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
abs(self.expected)
if (
isinstance(self.expected, bool)
or (not isinstance(self.expected, (Complex, Decimal)))
or math.isinf(abs(self.expected) or isinstance(self.expected, bool))
):
return str(self.expected)

Expand Down Expand Up @@ -428,14 +433,17 @@ def __eq__(self, actual) -> bool:
# numpy<1.13. See #3748.
return all(self.__eq__(a) for a in asarray.flat)

# Short-circuit exact equality.
if actual == self.expected:
# Short-circuit exact equality, except for bool
if isinstance(self.expected, bool) and not isinstance(actual, bool):
return False
elif actual == self.expected:
return True

# If either type is non-numeric, fall back to strict equality.
# NB: we need Complex, rather than just Number, to ensure that __abs__,
# __sub__, and __float__ are defined.
if not (
# __sub__, and __float__ are defined. Also, consider bool to be
# nonnumeric, even though it has the required arithmetic.
if isinstance(self.expected, bool) or not (
isinstance(self.expected, (Complex, Decimal))
and isinstance(actual, (Complex, Decimal))
):
Expand Down
23 changes: 22 additions & 1 deletion testing/python/approx.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,13 +90,26 @@ def do_assert(lhs, rhs, expected_message, verbosity_level=0):
return do_assert


SOME_FLOAT = r"[+-]?([0-9]*[.])?[0-9]+\s*"
SOME_FLOAT = r"[+-]?((?:([0-9]*[.])?[0-9]+(e-?[0-9]+)?)|inf|nan)\s*"
SOME_INT = r"[0-9]+\s*"
SOME_TOLERANCE = rf"({SOME_FLOAT}|[+-]?[0-9]+(\.[0-9]+)?[eE][+-]?[0-9]+\s*)"


class TestApprox:
def test_error_messages_native_dtypes(self, assert_approx_raises_regex):
# Treat bool exactly.
assert_approx_raises_regex(
{"a": 1.0, "b": True},
{"a": 1.0, "b": False},
[
"",
" comparison failed. Mismatched elements: 1 / 2:",
f" Max absolute difference: {SOME_FLOAT}",
f" Max relative difference: {SOME_FLOAT}",
r" Index\s+\| Obtained\s+\| Expected",
r".*(True|False)\s+",
],
)
assert_approx_raises_regex(
2.0,
1.0,
Expand Down Expand Up @@ -596,6 +609,13 @@ def test_complex(self):
assert approx(x, rel=5e-6, abs=0) == a
assert approx(x, rel=5e-7, abs=0) != a

def test_expecting_bool(self) -> None:
assert True == approx(True) # noqa: E712
assert False == approx(False) # noqa: E712
assert True != approx(False) # noqa: E712
assert True != approx(False, abs=2) # noqa: E712
assert 1 != approx(True)

def test_list(self):
actual = [1 + 1e-7, 2 + 1e-8]
expected = [1, 2]
Expand Down Expand Up @@ -661,6 +681,7 @@ def test_dict_wrong_len(self):
def test_dict_nonnumeric(self):
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
assert {"a": 1.0, "b": True} != pytest.approx({"a": 1.0, "b": False}, abs=2)

def test_dict_vs_other(self):
assert 1 != approx({"a": 0})
Expand Down

0 comments on commit a16e8ea

Please sign in to comment.