Skip to content

Commit

Permalink
fix(fw): Fix FixedSizeBytes != comparison (#477)
Browse files Browse the repository at this point in the history
* fix(fw): fixed size bytes != comparison

* changelog

* fix: simplify changes
  • Loading branch information
marioevz authored Mar 15, 2024
1 parent ba1efef commit f8c435e
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
2 changes: 2 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ Test fixtures for use by clients are available for each release on the [Github r

### 🛠️ Framework

- 🐞 Fix incorrect `!=` operator for `FixedSizeBytes` ([#477](https://github.com/ethereum/execution-spec-tests/pull/477)).

### 🔧 EVM Tools

### 📋 Misc
Expand Down
9 changes: 7 additions & 2 deletions src/ethereum_test_tools/common/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
Basic type primitives used to define other types.
"""


from typing import ClassVar, SupportsBytes, Type, TypeVar

from .conversions import (
Expand Down Expand Up @@ -174,7 +173,7 @@ def or_none(cls: Type[T], input: T | FixedSizeBytesConvertible | None) -> T | No

def __eq__(self, other: object) -> bool:
"""
Compares two FixedSizeBytes objects.
Compares two FixedSizeBytes objects to be equal.
"""
if not isinstance(other, FixedSizeBytes):
assert (
Expand All @@ -186,6 +185,12 @@ def __eq__(self, other: object) -> bool:
other = self.__class__(other)
return super().__eq__(other)

def __ne__(self, other: object) -> bool:
"""
Compares two FixedSizeBytes objects to be not equal.
"""
return not self.__eq__(other)


class Address(FixedSizeBytes[20]): # type: ignore
"""
Expand Down
54 changes: 54 additions & 0 deletions src/ethereum_test_tools/tests/test_base_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
"""
Test suite for `ethereum_test` module base types.
"""

from typing import Any

import pytest

from ..common.base_types import Address, Hash


@pytest.mark.parametrize(
"a, b, equal",
[
(Address("0x0"), Address("0x0"), True),
(Address("0x0"), Address("0x1"), False),
(Address("0x1"), Address("0x0"), False),
(Address("0x1"), "0x1", True),
(Address("0x1"), "0x2", False),
(Address("0x1"), 1, True),
(Address("0x1"), 2, False),
(Address("0x1"), b"\x01", True),
(Address("0x1"), b"\x02", False),
("0x1", Address("0x1"), True),
("0x2", Address("0x1"), False),
(1, Address("0x1"), True),
(2, Address("0x1"), False),
(b"\x01", Address("0x1"), True),
(b"\x02", Address("0x1"), False),
(Hash("0x0"), Hash("0x0"), True),
(Hash("0x0"), Hash("0x1"), False),
(Hash("0x1"), Hash("0x0"), False),
(Hash("0x1"), "0x1", True),
(Hash("0x1"), "0x2", False),
(Hash("0x1"), 1, True),
(Hash("0x1"), 2, False),
(Hash("0x1"), b"\x01", True),
(Hash("0x1"), b"\x02", False),
("0x1", Hash("0x1"), True),
("0x2", Hash("0x1"), False),
(1, Hash("0x1"), True),
(2, Hash("0x1"), False),
],
)
def test_comparisons(a: Any, b: Any, equal: bool):
"""
Test the comparison methods of the base types.
"""
if equal:
assert a == b
assert not a != b
else:
assert a != b
assert not a == b

0 comments on commit f8c435e

Please sign in to comment.