Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Validate field type on serializing #208

Merged
merged 2 commits into from
Apr 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 59 additions & 10 deletions pycardano/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,18 @@
from decimal import Decimal
from functools import wraps
from inspect import isclass
from typing import Any, Callable, List, Optional, Type, TypeVar, Union, get_type_hints
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Type,
TypeVar,
Union,
get_type_hints,
)

from cbor2 import CBOREncoder, CBORSimpleValue, CBORTag, dumps, loads, undefined
from pprintpp import pformat
Expand Down Expand Up @@ -254,7 +265,45 @@ def validate(self):
Raises:
InvalidDataException: When the data is invalid.
"""
pass
type_hints = get_type_hints(self.__class__)

def _check_recursive(value, type_hint):
if type_hint is Any:
return True
origin = getattr(type_hint, "__origin__", None)
if origin is None:
if isinstance(value, CBORSerializable):
value.validate()
return isinstance(value, type_hint)
elif origin is ClassVar:
return _check_recursive(value, type_hint.__args__[0])
elif origin is Union:
return any(_check_recursive(value, arg) for arg in type_hint.__args__)
elif origin is Dict or isinstance(value, dict):
key_type, value_type = type_hint.__args__
return all(
_check_recursive(k, key_type) and _check_recursive(v, value_type)
for k, v in value.items()
)
elif origin in (list, set, tuple):
if value is None:
return True
args = type_hint.__args__
if len(args) == 1:
return all(_check_recursive(item, args[0]) for item in value)
elif len(args) > 1:
return all(
_check_recursive(item, arg) for item, arg in zip(value, args)
)
return True # We don't know how to check this type

for field_name, field_type in type_hints.items():
field_value = getattr(self, field_name)
if not _check_recursive(field_value, field_type):
raise TypeError(
f"Field '{field_name}' should be of type {field_type}, "
f"got {repr(field_value)} instead."
)

def to_validated_primitive(self) -> Primitive:
"""Convert the instance and its elements to CBOR primitives recursively with data validated by :meth:`validate`
Expand Down Expand Up @@ -505,8 +554,8 @@ class ArrayCBORSerializable(CBORSerializable):
>>> t = Test2(c="c", test1=Test1(a="a"))
>>> t
Test2(c='c', test1=Test1(a='a', b=None))
>>> cbor_hex = t.to_cbor()
>>> cbor_hex
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
>>> cbor_hex # doctest: +SKIP
'826163826161f6'
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
Test2(c='c', test1=Test1(a='a', b=None))
Expand Down Expand Up @@ -534,8 +583,8 @@ class ArrayCBORSerializable(CBORSerializable):
Test2(c='c', test1=Test1(a='a', b=None))
>>> t.to_primitive() # Notice below that attribute "b" is not included in converted primitive.
['c', ['a']]
>>> cbor_hex = t.to_cbor()
>>> cbor_hex
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
>>> cbor_hex # doctest: +SKIP
'826163816161'
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
Test2(c='c', test1=Test1(a='a', b=None))
Expand Down Expand Up @@ -621,8 +670,8 @@ class MapCBORSerializable(CBORSerializable):
Test2(c=None, test1=Test1(a='a', b=''))
>>> t.to_primitive()
{'c': None, 'test1': {'a': 'a', 'b': ''}}
>>> cbor_hex = t.to_cbor()
>>> cbor_hex
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
>>> cbor_hex # doctest: +SKIP
'a26163f6657465737431a261616161616260'
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
Test2(c=None, test1=Test1(a='a', b=''))
Expand All @@ -645,8 +694,8 @@ class MapCBORSerializable(CBORSerializable):
Test2(c=None, test1=Test1(a='a', b=''))
>>> t.to_primitive()
{'1': {'0': 'a', '1': ''}}
>>> cbor_hex = t.to_cbor()
>>> cbor_hex
>>> cbor_hex = t.to_cbor() # doctest: +SKIP
>>> cbor_hex # doctest: +SKIP
'a16131a261306161613160'
>>> Test2.from_cbor(cbor_hex) # doctest: +SKIP
Test2(c=None, test1=Test1(a='a', b=''))
Expand Down
1 change: 1 addition & 0 deletions pycardano/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def __post_init__(self):
self.amount = Value(self.amount)

def validate(self):
super().validate()
if isinstance(self.amount, int) and self.amount < 0:
raise InvalidDataException(
f"Transaction output cannot have negative amount of ADA or "
Expand Down
5 changes: 4 additions & 1 deletion test/pycardano/test_certificate.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pycardano.address import Address
from pycardano.certificate import (
PoolKeyHash,
StakeCredential,
StakeDelegation,
StakeDeregistration,
Expand Down Expand Up @@ -43,7 +44,9 @@ def test_stake_deregistration():

def test_stake_delegation():
stake_credential = StakeCredential(TEST_ADDR.staking_part)
stake_delegation = StakeDelegation(stake_credential, b"1" * POOL_KEY_HASH_SIZE)
stake_delegation = StakeDelegation(
stake_credential, PoolKeyHash(b"1" * POOL_KEY_HASH_SIZE)
)

assert (
stake_delegation.to_cbor()
Expand Down
86 changes: 83 additions & 3 deletions test/pycardano/test_serialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from test.pycardano.util import check_two_way_cbor
from typing import Any, Union
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import pytest

Expand Down Expand Up @@ -67,7 +67,7 @@ def test_array_cbor_serializable_optional_field():
@dataclass
class Test1(ArrayCBORSerializable):
a: str
b: str = field(default=None, metadata={"optional": True})
b: Optional[str] = field(default=None, metadata={"optional": True})

@dataclass
class Test2(ArrayCBORSerializable):
Expand Down Expand Up @@ -104,7 +104,7 @@ class Test1(MapCBORSerializable):

@dataclass
class Test2(MapCBORSerializable):
c: str = field(default=None, metadata={"key": "0", "optional": True})
c: Optional[str] = field(default=None, metadata={"key": "0", "optional": True})
test1: Test1 = field(default_factory=Test1, metadata={"key": "1"})

t = Test2(test1=Test1(a="a"))
Expand Down Expand Up @@ -172,3 +172,83 @@ class Test1(MapCBORSerializable):
t = Test1(a="a", b=1)

check_two_way_cbor(t)


def test_wrong_primitive_type():
@dataclass
class Test1(MapCBORSerializable):
a: str = ""

with pytest.raises(TypeError):
Test1(a=1).to_cbor()


def test_wrong_union_type():
@dataclass
class Test1(MapCBORSerializable):
a: Union[str, int] = ""

with pytest.raises(TypeError):
Test1(a=1.0).to_cbor()


def test_wrong_optional_type():
@dataclass
class Test1(MapCBORSerializable):
a: Optional[str] = ""

with pytest.raises(TypeError):
Test1(a=1.0).to_cbor()


def test_wrong_list_type():
@dataclass
class Test1(MapCBORSerializable):
a: List[str] = ""

with pytest.raises(TypeError):
Test1(a=[1]).to_cbor()


def test_wrong_dict_type():
@dataclass
class Test1(MapCBORSerializable):
a: Dict[str, int] = ""

with pytest.raises(TypeError):
Test1(a={1: 1}).to_cbor()


def test_wrong_tuple_type():
@dataclass
class Test1(MapCBORSerializable):
a: Tuple[str, int] = ""

with pytest.raises(TypeError):
Test1(a=(1, 1)).to_cbor()


def test_wrong_set_type():
@dataclass
class Test1(MapCBORSerializable):
a: Set[str] = ""

with pytest.raises(TypeError):
Test1(a={1}).to_cbor()


def test_wrong_nested_type():
@dataclass
class Test1(MapCBORSerializable):
a: str = ""

@dataclass
class Test2(MapCBORSerializable):
a: Test1 = ""
b: Optional[Test1] = None

with pytest.raises(TypeError):
Test2(a=1).to_cbor()

with pytest.raises(TypeError):
Test2(a=Test1(a=1)).to_cbor()