Skip to content

Commit

Permalink
Check bytes length with strict registry
Browse files Browse the repository at this point in the history
  • Loading branch information
kclowes committed Aug 14, 2019
1 parent e9284cd commit 602db5f
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 57 deletions.
7 changes: 6 additions & 1 deletion tests/core/contracts/test_contract_call_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,12 @@ def test_set_byte_array(arrays_contract, call, transact, args, expected):
assert result == expected


@pytest.mark.parametrize('args,expected', [([b''], [b'\x00']), (['0x'], [b'\x00'])])
@pytest.mark.parametrize(
'args,expected', [
([b'1'], [b'1']),
(['0xDe'], [b'\xDe'])
]
)
def test_set_strict_byte_array(strict_arrays_contract, call, transact, args, expected):
transact(
contract=strict_arrays_contract,
Expand Down
29 changes: 18 additions & 11 deletions tests/core/contracts/test_contract_method_to_argument_matching.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,20 +152,27 @@ def test_error_when_duplicate_match(web3):
Contract._find_matching_fn_abi('a', [100])


@pytest.mark.parametrize('arguments', (['0xf00b47'], [b''], [''], ['00' * 16]))
def test_strict_errors_if_type_is_wrong(web3_strict_types, arguments):
Contract = web3_strict_types.eth.contract(abi=MULTIPLE_FUNCTIONS)

with pytest.raises(ValidationError):
Contract._find_matching_fn_abi('a', arguments)


@pytest.mark.parametrize(
'arguments,expected_types,expected_error',
'arguments,expected_types',
(
# TODO (['0xf00b47'], ['bytes12'], TypeError),
([''], ['bytes'], ValidationError),
([], []),
([1234567890], ['uint256']),
([-1], ['int8']),
([[(-1, True), (2, False)]], ['(int256,bool)[]']),
)
)
def test_errors_if_type_is_wrong(
web3_strict_types,
arguments,
expected_types,
expected_error):

def test_strict_finds_function_with_matching_args(web3_strict_types, arguments, expected_types):
Contract = web3_strict_types.eth.contract(abi=MULTIPLE_FUNCTIONS)

with pytest.raises(expected_error):
Contract._find_matching_fn_abi('a', arguments)
abi = Contract._find_matching_fn_abi('a', arguments)
assert abi['name'] == 'a'
assert len(abi['inputs']) == len(expected_types)
assert set(get_abi_input_types(abi)) == set(expected_types)
6 changes: 3 additions & 3 deletions tests/core/utilities/test_abi_is_encodable.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,15 +84,15 @@ def test_is_encodable(value, _type, expected):
(
# Special bytes<M> behavior
('12', 'bytes2', False), # no hex strings without leading 0x
('0x12', 'bytes2', True), # with 0x OK
('0x12', 'bytes1', True), # with 0x OK
('0123', 'bytes2', False), # needs a 0x
# TODO (b'\x12', 'bytes2', False), # no undersize bytes value
(b'\x12', 'bytes2', False), # no undersize bytes value
('0123', 'bytes1', False), # no oversize hex strings
('1', 'bytes2', False), # no odd length
('0x1', 'bytes2', False), # no odd length
# Special bytes behavior
('12', 'bytes', False),
('12', 'bytes', False), # has to have 0x if string
('0x12', 'bytes', True),
('1', 'bytes', False),
('0x1', 'bytes', False),
Expand Down
17 changes: 0 additions & 17 deletions tests/core/utilities/test_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,20 +147,3 @@ def test_validate_abi_value(abi_type, value, expected):
return

validate_abi_value(abi_type, value)


@pytest.mark.parametrize(
'abi_type,value,expected',
(
# ('bytes3', b'T\x02', TypeError), TODO - has to exactly match length
('bytes2', b'T\x02', None),
)
)
def test_validate_abi_value_strict(abi_type, value, expected):

if isinstance(expected, type) and issubclass(expected, Exception):
with pytest.raises(expected):
validate_abi_value(abi_type, value)
return

validate_abi_value(abi_type, value)
104 changes: 84 additions & 20 deletions web3/_utils/abi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,12 @@
from eth_abi import (
encoding,
)
from eth_abi.base import (
parse_type_str,
)
from eth_abi.exceptions import (
ValueOutOfBounds,
)
from eth_abi.grammar import (
ABIType,
TupleType,
Expand Down Expand Up @@ -129,19 +135,6 @@ def filter_by_argument_name(argument_names, contract_abi):
]


def length_of_bytes_type(abi_type):
if not is_bytes_type(abi_type):
raise ValueError(
f"Cannot parse length of nonbytes abi-type: {abi_type}"
)

byte_length = re.search('(\d{1,2})', abi_type)
if not byte_length:
return None
else:
return int(byte_length.group(0))


class AddressEncoder(encoding.AddressEncoder):
@classmethod
def validate_value(cls, value):
Expand Down Expand Up @@ -173,7 +166,6 @@ def validate_value(self, value):
class StrictAcceptsHexStrMixin:
def validate_value(self, value):
if is_text(value) and value[:2] != '0x':
print('in strict accepts hexstr mixin')
self.invalidate_value(
value,
msg='hex string must be prefixed with 0x'
Expand All @@ -194,15 +186,11 @@ class BytesEncoder(AcceptsHexStrMixin, encoding.BytesEncoder):
pass


class StrictBytesEncoder(StrictAcceptsHexStrMixin, encoding.PackedBytesEncoder):
pass


class ByteStringEncoder(AcceptsHexStrMixin, encoding.ByteStringEncoder):
pass


class StrictByteStringEncoder(StrictAcceptsHexStrMixin, encoding.PackedByteStringEncoder):
class StrictByteStringEncoder(StrictAcceptsHexStrMixin, encoding.ByteStringEncoder):
pass


Expand All @@ -221,6 +209,82 @@ def validate_value(cls, value):
super().validate_value(value)


class ExactLengthBytesEncoder(encoding.BaseEncoder):
is_big_endian = False
value_bit_size = None
data_byte_size = None
encode_fn = None

def validate(self):
super().validate()

if self.value_bit_size is None:
raise ValueError("`value_bit_size` may not be none")
if self.data_byte_size is None:
raise ValueError("`data_byte_size` may not be none")
if self.encode_fn is None:
raise ValueError("`encode_fn` may not be none")
if self.is_big_endian is None:
raise ValueError("`is_big_endian` may not be none")

if self.value_bit_size % 8 != 0:
raise ValueError(
"Invalid value bit size: {0}. Must be a multiple of 8".format(
self.value_bit_size,
)
)

if self.value_bit_size > self.data_byte_size * 8:
raise ValueError("Value byte size exceeds data size")

def encode(self, value):
self.validate_value(value)
return self.encode_fn(value)

def validate_value(self, value):
if not is_bytes(value) and not is_text(value):
self.invalidate_value(value)

if is_text(value) and value[:2] != '0x':
self.invalidate_value(
value,
msg='hex string must be prefixed with 0x'
)
elif is_text(value):
try:
value = decode_hex(value)
except binascii.Error:
self.invalidate_value(
value,
msg='invalid hex string',
)

byte_size = self.value_bit_size // 8
if len(value) > byte_size:
self.invalidate_value(
value,
exc=ValueOutOfBounds,
msg="exceeds total byte size for bytes{} encoding".format(byte_size),
)
elif len(value) < byte_size:
self.invalidate_value(
value,
exc=ValueOutOfBounds,
msg="less than total byte size for bytes{} encoding".format(byte_size),
)

@staticmethod
def encode_fn(value):
return value

@parse_type_str('bytes')
def from_type_str(cls, abi_type, registry):
return cls(
value_bit_size=abi_type.sub * 8,
data_byte_size=abi_type.sub,
)


def filter_by_encodability(w3, args, kwargs, contract_abi):
return [
function_abi
Expand All @@ -245,7 +309,7 @@ def check_if_arguments_can_be_encoded(function_abi, w3, args, kwargs):
return False

return all(
w3.codec.is_encodable(_type, arg)
w3.is_encodable(_type, arg)
for _type, arg in zip(types, aligned_args)
)

Expand Down
3 changes: 1 addition & 2 deletions web3/_utils/filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from eth_abi import (
decode_abi,
is_encodable,
)
from eth_abi.grammar import (
parse as parse_type_string,
Expand Down Expand Up @@ -218,7 +217,7 @@ def match_fn(match_values_and_abi, data):
continue
normalized_data = normalize_data_values(abi_type, data_value)
for value in match_values:
if not is_encodable(abi_type, value):
if not w3.is_encodable(abi_type, value):
raise ValueError(
"Value {0} is of the wrong abi type. "
"Expected {1} typed value.".format(value, abi_type))
Expand Down
5 changes: 2 additions & 3 deletions web3/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
AddressEncoder,
BytesEncoder,
ByteStringEncoder,
StrictBytesEncoder,
ExactLengthBytesEncoder,
StrictByteStringEncoder,
TextStringEncoder,
)
Expand Down Expand Up @@ -161,7 +161,6 @@ def __init__(self, provider=None, middlewares=None, modules=None, ens=empty):
self.ens = ens

def is_encodable(self, _type, value):
# TODO - there is probably a better place to put this
return self.codec.is_encodable(_type, value)

def build_default_registry(self):
Expand Down Expand Up @@ -315,7 +314,7 @@ def build_strict_registry(self):
)
registry.register(
BaseEquals('bytes', with_sub=True),
StrictBytesEncoder, decoding.BytesDecoder,
ExactLengthBytesEncoder, decoding.BytesDecoder,
label='bytes<M>',
)
registry.register(
Expand Down

0 comments on commit 602db5f

Please sign in to comment.