-
Notifications
You must be signed in to change notification settings - Fork 1k
/
Copy pathminimal_ssz.py
331 lines (279 loc) · 11 KB
/
minimal_ssz.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
from typing import Any
from .hash_function import hash
# Size in bytes of one Merkle leaf / tree node chunk.
BYTES_PER_CHUNK = 32
# Width in bytes of the little-endian offsets written for variable-size elements.
BYTES_PER_LENGTH_OFFSET = 4
# All-zero chunk used to pad leaf lists up to a power of two.
ZERO_CHUNK = bytes(BYTES_PER_CHUNK)
def SSZType(fields):
    """Create a new SSZ container class from a field-name -> SSZ-type mapping.

    :param fields: dict of field name to SSZ type descriptor. Its iteration
        order is the order used for serialization and hash-tree-rooting.
    :return: a new class. Instances accept the fields as keyword arguments;
        omitted fields default to ``get_zero_value`` of their type, and
        unknown keyword arguments are ignored. The mapping is exposed as the
        class attribute ``fields``.
    """
    class SSZObject():
        def __init__(self, **kwargs):
            for f, t in fields.items():
                if f not in kwargs:
                    setattr(self, f, get_zero_value(t))
                else:
                    setattr(self, f, kwargs[f])

        def __eq__(self, other):
            # Two containers are equal when they have the same schema and the
            # same serialization. For anything that is not an SSZ container,
            # return NotImplemented so Python falls back to its default
            # comparison instead of raising AttributeError on `other.fields`.
            if not (hasattr(other, 'fields') and hasattr(other, 'serialize')):
                return NotImplemented
            return self.fields == other.fields and self.serialize() == other.serialize()

        def __hash__(self):
            # Derived from the hash tree root, consistent with __eq__ for
            # containers of the same schema.
            return int.from_bytes(self.hash_tree_root(), byteorder="little")

        def __str__(self):
            output = []
            for field in self.fields:
                output.append(f'{field}: {getattr(self, field)}')
            return "\n".join(output)

        def serialize(self):
            return serialize_value(self, self.__class__)

        def hash_tree_root(self):
            return hash_tree_root(self, self.__class__)

    SSZObject.fields = fields
    return SSZObject
class Vector:
    """Thin wrapper marking a Python sequence as an SSZ *vector* (fixed length).

    ``infer_type`` maps a Vector to a ``[element_type, length]`` vector type,
    whereas a plain Python list is inferred as a variable-length list type.
    The wrapped sequence is stored as-is (not copied) in ``items``.
    """

    def __init__(self, items):
        self.items, self.length = items, len(items)
        # NOTE: ``length`` is a snapshot taken here; mutating ``items`` later
        # (e.g. appending) does not update it, and ``len()`` reports it.

    def __getitem__(self, key):
        return self.items[key]

    def __setitem__(self, key, value):
        self.items[key] = value

    def __iter__(self):
        return iter(self.items)

    def __len__(self):
        return self.length
def is_basic(typ):
    """Return True if `typ` names a basic (fixed-width scalar) SSZ type.

    Basic types are the string descriptors "uintN" for N in
    {8, 16, 32, 64, 128, 256}, "bool", and the alias "byte". Any non-string
    descriptor (list/vector types, container classes) is composite.
    """
    if not isinstance(typ, str):
        return False
    if typ in ('bool', 'byte'):
        return True
    return typ.startswith('uint') and typ[4:] in ('8', '16', '32', '64', '128', '256')
def is_constant_sized(typ):
    """Return True if every value of SSZ type `typ` serializes to the same length.

    - basic types: always fixed size
    - list ``[elem_type]``: never fixed size
    - vector ``[elem_type, length]``: fixed size iff ``elem_type`` is
    - "bytes": variable size; "bytesN": fixed size
    - container class (has ``fields``): fixed size iff all field types are

    :raises Exception: "Type not recognized" for any other descriptor.
    """
    if is_basic(typ):
        return True
    if isinstance(typ, list):
        if len(typ) == 1:
            return False
        if len(typ) == 2:
            return is_constant_sized(typ[0])
    if isinstance(typ, str) and typ.startswith('bytes'):
        # Only the bare "bytes" lacks a length suffix.
        return typ != 'bytes'
    if hasattr(typ, 'fields'):
        return all(is_constant_sized(subtype) for subtype in typ.fields.values())
    raise Exception("Type not recognized")
def coerce_to_bytes(x):
    """Return `x` as a ``bytes`` object.

    - ``str``: UTF-8 encoded; asserted to be one byte per character (i.e.
      ASCII-only), so lengths measured on the str and on the result agree.
    - ``bytes``: returned unchanged.
    - ``bytearray`` / ``memoryview``: copied into an immutable ``bytes``.

    :raises Exception: "Expecting bytes" for any other input type.
    """
    if isinstance(x, str):
        o = x.encode('utf-8')
        assert len(o) == len(x), "only ASCII strings can be coerced to bytes"
        return o
    elif isinstance(x, bytes):
        return x
    elif isinstance(x, (bytearray, memoryview)):
        # Other bytes-like buffers are accepted too; copy so callers get
        # immutable bytes like the other branches.
        return bytes(x)
    else:
        raise Exception("Expecting bytes")
def encode_series(values, types):
    """Serialize a heterogeneous series of `values`, typed by the parallel `types`.

    Fixed-size elements are written inline in the leading "fixed" region.
    Each variable-size element is represented there by a
    BYTES_PER_LENGTH_OFFSET-byte little-endian offset, measured from the
    start of the output, and its payload is appended after the fixed region
    in order.
    """
    # Serialize every element once, remembering whether its type is fixed-size.
    parts = []
    for index in range(len(values)):
        element_type = types[index]
        parts.append((is_constant_sized(element_type), serialize_value(values[index], element_type)))

    fixed_region_size = sum(len(data) if fixed else BYTES_PER_LENGTH_OFFSET for fixed, data in parts)
    variable_region_size = sum(0 if fixed else len(data) for fixed, data in parts)
    # Every offset must fit in BYTES_PER_LENGTH_OFFSET bytes.
    assert fixed_region_size + variable_region_size < 2 ** (BYTES_PER_LENGTH_OFFSET * 8)

    # Single pass with a running offset keeps this linear in the number of parts.
    head, tail = [], []
    next_offset = fixed_region_size
    for fixed, data in parts:
        if fixed:
            head.append(data)
            continue
        head.append(next_offset.to_bytes(BYTES_PER_LENGTH_OFFSET, 'little'))
        tail.append(data)
        next_offset += len(data)
    return b"".join(head) + b"".join(tail)
def serialize_value(value, typ=None):
    """Serialize `value` to SSZ bytes according to the type descriptor `typ`.

    :param value: the value to serialize.
    :param typ: one of "uintN", "byte" (alias of "uint8"), "bool", "bytes",
        "bytesN", ``[elem_type]`` (list), ``[elem_type, length]`` (vector),
        or an SSZType container class. Inferred with ``infer_type`` when None.
    :return: the serialized bytes.
    :raises AssertionError: if `value` does not fit `typ` (width, length, bool).
    :raises Exception: "Type not recognized" for an unknown descriptor.
    """
    if typ is None:
        typ = infer_type(value)
    # "byte" is an alias for "uint8" (is_basic and get_zero_value accept it);
    # normalize it so it takes the uintN path below.
    if typ == 'byte':
        typ = 'uint8'
    # "uintN"
    if isinstance(typ, str) and typ[:4] == 'uint':
        length = int(typ[4:])
        assert length in (8, 16, 32, 64, 128, 256)
        return value.to_bytes(length // 8, 'little')
    # "bool"
    elif isinstance(typ, str) and typ == 'bool':
        assert value in (True, False)
        # Test truthiness, not identity: 1 and 0 pass the check above, and
        # `1 is True` is False, which used to serialize 1 as b'\x00'.
        return b'\x01' if value else b'\x00'
    # Vector
    elif isinstance(typ, list) and len(typ) == 2:
        # Regardless of element type, the value length must match the declared vector length.
        assert len(value) == typ[1]
        return encode_series(value, [typ[0]] * len(value))
    # List
    elif isinstance(typ, list) and len(typ) == 1:
        return encode_series(value, [typ[0]] * len(value))
    # "bytes" (variable size)
    elif isinstance(typ, str) and typ == 'bytes':
        return coerce_to_bytes(value)
    # "bytesN" (fixed size)
    elif isinstance(typ, str) and len(typ) > 5 and typ[:5] == 'bytes':
        assert len(value) == int(typ[5:]), (value, int(typ[5:]))
        return coerce_to_bytes(value)
    # containers: fields serialized in declaration order
    elif hasattr(typ, 'fields'):
        values = [getattr(value, field) for field in typ.fields.keys()]
        types = list(typ.fields.values())
        return encode_series(values, types)
    else:
        # Put the diagnostics in the exception rather than printing to stdout.
        raise Exception(f"Type not recognized: {typ!r} (value of type {type(value).__name__})")
def get_zero_value(typ: Any) -> Any:
    """Return the default (zero) value for SSZ type descriptor `typ`.

    - "bytes" -> b''; "bytesN" -> N zero bytes
    - "bool" -> False; "uintN" / "byte" -> 0
    - vector ``[elem_type, length]`` -> a list of `length` zero elements
    - list ``[elem_type]`` -> []
    - container class (has ``fields``) -> instance built from zero-valued fields

    :raises ValueError: "Type not recognized" for any other descriptor. This is
        a subclass of Exception, so existing ``except Exception`` handlers
        still apply.
    """
    if isinstance(typ, str):
        # Bytes array
        if typ == 'bytes':
            return b''
        # bytesN
        elif typ[:5] == 'bytes' and len(typ) > 5:
            length = int(typ[5:])
            return b'\x00' * length
        # Basic types
        elif typ == 'bool':
            return False
        elif typ[:4] == 'uint':
            return 0
        elif typ == 'byte':
            return 0x00
        else:
            raise ValueError(f"Type not recognized: {typ!r}")
    # Vector:
    elif isinstance(typ, list) and len(typ) == 2:
        return [get_zero_value(typ[0]) for _ in range(typ[1])]
    # List:
    elif isinstance(typ, list) and len(typ) == 1:
        return []
    # Container:
    elif hasattr(typ, 'fields'):
        return typ(**{field: get_zero_value(subtype) for field, subtype in typ.fields.items()})
    else:
        # Report the bad descriptor in the exception instead of printing it.
        raise ValueError(f"Type not recognized: {typ!r}")
def chunkify(bytez):
    """Split `bytez` into BYTES_PER_CHUNK-sized chunks, zero-padding the last one.

    :param bytez: bytes-like input; it is not modified.
    :return: list of ``bytes`` chunks, each exactly BYTES_PER_CHUNK long;
        empty input yields an empty list.
    """
    # Build a new padded buffer instead of using `+=`, which would extend a
    # caller-supplied bytearray in place.
    padded = bytes(bytez) + b'\x00' * (-len(bytez) % BYTES_PER_CHUNK)
    return [padded[i:i + BYTES_PER_CHUNK] for i in range(0, len(padded), BYTES_PER_CHUNK)]
def pack(values, subtype):
    """Serialize each item of `values` as `subtype`, concatenate, and chunk.

    The concatenated bytes are split with ``chunkify`` (last chunk zero-padded).
    """
    serialized = b''.join(serialize_value(item, subtype) for item in values)
    return chunkify(serialized)
def is_power_of_two(x):
    """Return True if integer `x` is a positive power of two (1, 2, 4, ...)."""
    # For x > 0, `x & -x` isolates the lowest set bit; it equals x only when
    # that is the only set bit.
    return x > 0 and (x & -x) == x
def merkleize(chunks):
    """Return the Merkle root of the list `chunks`.

    The leaves are right-padded with ZERO_CHUNK up to the next power of two,
    so an empty list becomes a single ZERO_CHUNK leaf and a single chunk is its
    own root. Adjacent pairs are then combined with ``hash(left + right)``
    level by level until one node remains. The caller's list is not modified.
    """
    level = chunks[:]
    while not is_power_of_two(len(level)):
        level.append(ZERO_CHUNK)
    # Each pass halves the level: parent k = hash(child 2k + child 2k+1).
    while len(level) > 1:
        level = [hash(level[k] + level[k + 1]) for k in range(0, len(level), 2)]
    return level[0]
def mix_in_length(root, length):
    """Return ``hash(root + length_le32)``, where `length` is a 32-byte little-endian int.

    ``hash_tree_root`` applies this to lists and "bytes", so the root of a
    variable-length value also commits to its length.
    """
    length_chunk = length.to_bytes(32, 'little')
    return hash(root + length_chunk)
def infer_type(value):
    """
    Infer an SSZ type descriptor for `value` when none was given explicitly.

    Note: defaults to uint64 for integer type inference due to lack of information.
    Other integer sizes are still supported, see spec.

    Rules, checked in order:
      - instance of an SSZType container class -> that class
      - Vector -> ``[element_type, length]`` (element type taken from the first
        item; ``'uint64'`` placeholder when empty)
      - list -> ``[element_type]`` (element type from the first item;
        ``'uint64'`` placeholder when empty)
      - bytes or str -> ``'bytes'``
      - bool -> ``'bool'``
      - int -> ``'uint64'``

    :param value: The value to infer a SSZ type for.
    :return: The SSZ type.
    :raises Exception: "Failed to infer type" if no rule matches.
    """
    if hasattr(value.__class__, 'fields'):
        return value.__class__
    elif isinstance(value, Vector):
        if len(value) > 0:
            return [infer_type(value[0]), len(value)]
        else:
            # The element type does not matter much here: the vector is empty,
            # so no element bytes are produced. It must still be a *vector*
            # descriptor ([elem, 0]). A bare ['uint64'] is a list descriptor,
            # which would get its length mixed into the hash tree root.
            return ['uint64', 0]
    elif isinstance(value, list):
        if len(value) > 0:
            return [infer_type(value[0])]
        else:
            # Element type does not matter, list-content size will be encoded regardless, list is empty.
            return ['uint64']
    elif isinstance(value, (bytes, str)):
        return 'bytes'
    elif isinstance(value, bool):
        # bool is a subclass of int, so it must be checked before int;
        # otherwise True/False would be treated as 8-byte uint64 values.
        return 'bool'
    elif isinstance(value, int):
        return 'uint64'
    else:
        raise Exception("Failed to infer type")
def hash_tree_root(value, typ=None):
    """Compute the SSZ hash tree root of `value` interpreted as type `typ`.

    :param value: the value to hash.
    :param typ: SSZ type descriptor (see ``serialize_value``); inferred with
        ``infer_type`` when None.
    :return: the root as returned by ``merkleize`` / ``mix_in_length``.
    :raises AssertionError: if a vector or bytesN value has the wrong length.
    :raises Exception: "Type not recognized" for an unknown descriptor.
    """
    if typ is None:
        typ = infer_type(value)
    # -------------------------------------
    # merkleize(pack(value))
    # Basic object: merkleize its packed form (packing pads it to a full chunk).
    if is_basic(typ):
        return merkleize(pack([value], typ))
    # or a vector of basic objects
    elif isinstance(typ, list) and len(typ) == 2 and is_basic(typ[0]):
        assert len(value) == typ[1]
        return merkleize(pack(value, typ[0]))
    # -------------------------------------
    # mix_in_length(merkleize(pack(value)), len(value))
    # if value is a list of basic objects
    elif isinstance(typ, list) and len(typ) == 1 and is_basic(typ[0]):
        return mix_in_length(merkleize(pack(value, typ[0])), len(value))
    # Variable-size bytes array: chunk the raw bytes and mix in their length.
    elif typ == 'bytes':
        data = coerce_to_bytes(value)
        # Mix in the byte length of the coerced data, not len() of the
        # original object.
        return mix_in_length(merkleize(chunkify(data)), len(data))
    # -------------------------------------
    # merkleize([hash_tree_root(element) for element in value])
    # if value is a vector of composite objects
    elif isinstance(typ, list) and len(typ) == 2 and not is_basic(typ[0]):
        # Same declared-length check as the basic-vector branch and serialize_value.
        assert len(value) == typ[1]
        return merkleize([hash_tree_root(element, typ[0]) for element in value])
    # Fixed-size bytes array ("bytesN")
    elif isinstance(typ, str) and typ[:5] == 'bytes' and len(typ) > 5:
        assert len(value) == int(typ[5:])
        return merkleize(chunkify(coerce_to_bytes(value)))
    # or a container: one leaf per field root, in field order
    elif hasattr(typ, 'fields'):
        return merkleize([hash_tree_root(getattr(value, field), subtype) for field, subtype in typ.fields.items()])
    # -------------------------------------
    # mix_in_length(merkleize([hash_tree_root(element) for element in value]), len(value))
    # if value is a list of composite objects
    elif isinstance(typ, list) and len(typ) == 1 and not is_basic(typ[0]):
        return mix_in_length(merkleize([hash_tree_root(element, typ[0]) for element in value]), len(value))
    # -------------------------------------
    else:
        raise Exception(f"Type not recognized: {typ!r}")
def truncate(container):
    """Return a copy of `container` as an instance of a new SSZ type without its last field.

    The new type keeps all but the final entry of ``container.fields``, in
    order. Field values are copied by reference from `container`.
    """
    kept_names = list(container.fields)[:-1]
    reduced_type = SSZType({name: container.fields[name] for name in kept_names})
    return reduced_type(**{name: getattr(container, name) for name in kept_names})
def signing_root(container):
    """Return the hash tree root of `container` with its last field removed.

    NOTE(review): presumably the last field is the signature, so this is the
    root that gets signed; confirm against the container definitions.
    """
    without_last_field = truncate(container)
    return hash_tree_root(without_last_field)
def serialize(ssz_object):
    """Serialize an SSZ object by delegating to its own ``serialize()`` method."""
    return ssz_object.serialize()