Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sparse merkle tree part1 #58

Closed
wants to merge 11 commits into from
72 changes: 72 additions & 0 deletions tests/test_sparse_merkle_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest

from hypothesis import (
given,
strategies as st,
settings,
)

from eth_hash.auto import (
keccak,
)

from trie.sparse_merkle_tree import (
SparseMerkleTree,
)
from trie.constants import (
EMPTY_NODE_HASHES,
)


@given(
    k=st.lists(st.binary(min_size=20, max_size=20), min_size=100, max_size=100, unique=True),
    v=st.lists(st.binary(min_size=1), min_size=100, max_size=100),
    chosen_numbers=st.lists(
        st.integers(min_value=1, max_value=99),
        min_size=50,
        max_size=100,
        unique=True,
    ),
    random=st.randoms(),
)
@settings(max_examples=10)
def test_sparse_merkle_tree(k, v, chosen_numbers, random):
    """
    Exercise get/set/delete/exists on a SparseMerkleTree and check that the
    root hash depends only on the stored key/value pairs, not on the order
    in which they were written.
    """
    kv_pairs = list(zip(k, v))

    def updated_value(index):
        # The replacement value must differ from the stored one, otherwise the
        # root hash legitimately stays unchanged and the "root changed"
        # assertion below fails spuriously (e.g. index 1 used to map to
        # b'\x01', which the value strategy can also generate).  Appending at
        # least one byte (index >= 1) makes the new value strictly longer.
        return kv_pairs[index][1] + index.to_bytes(index, byteorder='big')

    # Test basic get/set
    trie = SparseMerkleTree(db={})
    for key, value in kv_pairs:
        assert not trie.exists(key)
        trie.set(key, value)
    for key, value in kv_pairs:
        assert trie.get(key) == value
        trie.delete(key)
    for key, _ in kv_pairs:
        assert not trie.exists(key)
    # With every key deleted the tree must be back at the empty-tree root
    assert trie.root_hash == keccak(EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0])

    # Test single update
    random.shuffle(kv_pairs)
    for key, value in kv_pairs:
        trie.set(key, value)
    prior_to_update_root = trie.root_hash
    for i in chosen_numbers:
        key, original_value = kv_pairs[i]
        # Update
        trie.set(key, updated_value(i))
        assert trie.get(key) == updated_value(i)
        assert trie.root_hash != prior_to_update_root
        # Un-update
        trie.set(key, original_value)
        assert trie.root_hash == prior_to_update_root

    # Test batch update with different update order
    # First batch update
    for i in chosen_numbers:
        trie.set(kv_pairs[i][0], updated_value(i))
    batch_updated_root = trie.root_hash
    # Un-update
    for i in chosen_numbers:
        trie.set(kv_pairs[i][0], kv_pairs[i][1])
    assert trie.root_hash == prior_to_update_root
    # Second batch update, applied in a different order
    random.shuffle(chosen_numbers)
    for i in chosen_numbers:
        trie.set(kv_pairs[i][0], updated_value(i))
    assert trie.root_hash == batch_updated_root
7 changes: 7 additions & 0 deletions trie/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,10 @@

BYTE_1 = bytes([1])
BYTE_0 = bytes([0])

# Constants for Sparse Merkle Tree
from eth_hash.auto import keccak
# Hash used for an empty leaf; trie/sparse_merkle_tree.py asserts it equals keccak(b'').
EMPTY_LEAF_NODE_HASH = BLANK_HASH
# Hashes of empty nodes, one per tree level.  The list starts with only the
# empty-leaf hash; the loop below prepends one entry per higher level, so the
# leaf hash ends up at the last index and the topmost level at index 0.
EMPTY_NODE_HASHES = [EMPTY_LEAF_NODE_HASH]
# 159 prepends on top of the initial leaf entry give 160 entries in total.
for _ in range(159):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As 160 (and 159) are used a couple of times, it would make sense to make it a constant. Maybe NODE_BIT_LENGTH?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed. What do you think about something more straightforward, like TREE_HEIGHT?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TREE_HEIGHT sounds good

# Prepend the hash of a node whose two children are both the current first entry.
EMPTY_NODE_HASHES.insert(0, keccak(EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0]))
108 changes: 108 additions & 0 deletions trie/sparse_merkle_tree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
from eth_hash.auto import (
keccak,
)

from trie.constants import (
EMPTY_LEAF_NODE_HASH,
EMPTY_NODE_HASHES,
)
from trie.validation import (
validate_is_bytes,
validate_length,
)


# Sanity check: the empty-leaf hash must be the keccak hash of the empty byte
# string, because SparseMerkleTree stores b'' under that hash in its db.
assert EMPTY_LEAF_NODE_HASH == keccak(b'')


class SparseMerkleTree:
def __init__(self, db):
    """
    Create an empty tree backed by ``db``.

    ``db`` is used as a mapping from node hash to node body.  It is
    pre-populated with the root of the empty tree, every empty node
    level below it, and the empty leaf.
    """
    self.db = db

    # Root of the empty tree: two copies of the topmost empty node hash
    empty_root = EMPTY_NODE_HASHES[0] + EMPTY_NODE_HASHES[0]
    self.root_hash = keccak(empty_root)
    self.db[self.root_hash] = empty_root

    # Every empty node consists of two copies of the empty node one level down
    for level in range(159):
        parent_hash = EMPTY_NODE_HASHES[level]
        child_hash = EMPTY_NODE_HASHES[level + 1]
        self.db[parent_hash] = child_hash + child_hash

    # The empty leaf holds the empty byte string
    self.db[EMPTY_LEAF_NODE_HASH] = b''

def get(self, key):
    """
    Look up the value stored under ``key``.

    ``key`` must be a 20-byte ``bytes`` value (checked by the validators
    below).  Its 160 bits are consumed from the most significant bit
    downwards to pick a child at each level of the tree.

    Returns ``None`` when the leaf reached holds the empty byte string,
    otherwise the bytes stored at that leaf.
    """
    validate_is_bytes(key)
    validate_length(key, 20)

    target_bit = 1 << 159
    path = int.from_bytes(key, byteorder='big')
    node_hash = self.root_hash
    for _ in range(160):
        # A set bit selects the bytes after the first 32 of the node body,
        # a cleared bit selects the first 32 bytes.
        if path & target_bit:
            node_hash = self.db[node_hash][32:]
        else:
            node_hash = self.db[node_hash][:32]
        target_bit >>= 1

    # Compare by value: an identity test against a literal depends on the
    # interpreter reusing one b'' object and is a SyntaxWarning on newer Pythons.
    if self.db[node_hash] == b'':
        return None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reason for returning None instead of raising a KeyError? We've found that return value checking is way easier to forget than exception handling.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No particular reason. It's just that in the hexary trie, BLANK_NODE is returned for a non-existent key. Raising an exception does make more sense.

else:
return self.db[node_hash]

def set(self, key, value):
    """
    Store ``value`` under ``key`` and update ``self.root_hash``.

    ``key`` must be 20 bytes long and ``value`` must be bytes; both are
    checked by the validators before the tree is touched.
    """
    validate_is_bytes(key)
    validate_length(key, 20)
    validate_is_bytes(value)

    key_bits = int.from_bytes(key, byteorder='big')
    # Rewrite the path from the root down to the leaf and adopt the new root
    new_root_hash = self._set(value, key_bits, 0, self.root_hash)
    self.root_hash = new_root_hash

def _set(self, value, path, depth, node_hash):
    """
    Recursively write ``value`` at the leaf addressed by ``path``.

    ``node_hash`` is the hash of the node at ``depth``.  Returns the hash
    of the replacement for that node; every node along the path is saved
    through ``_hash_and_save``.
    """
    # Depth 160 is the leaf level: save the value itself
    if depth == 160:
        return self._hash_and_save(value)

    branch = self.db[node_hash]
    left_hash, right_hash = branch[:32], branch[32:]

    # Bit (159 - depth) of the path decides which child to descend into
    goes_right = path & (1 << (159 - depth))
    if goes_right:
        right_hash = self._set(value, path, depth + 1, right_hash)
    else:
        left_hash = self._set(value, path, depth + 1, left_hash)

    return self._hash_and_save(left_hash + right_hash)

def exists(self, key):
    """
    Return True when ``get(key)`` yields something other than ``None``.
    """
    validate_is_bytes(key)
    validate_length(key, 20)

    stored = self.get(key)
    return stored is not None

def delete(self, key):
    """
    Remove ``key`` by storing the empty byte string ``b''`` as its value.
    """
    validate_is_bytes(key)
    validate_length(key, 20)

    empty_value = b''
    self.set(key, empty_value)

#
# Utils
#
def _hash_and_save(self, node):
    """
    Store ``node`` in the database under its keccak hash and return that hash.
    """
    digest = keccak(node)
    self.db[digest] = node
    return digest

#
# Dictionary API
#
def __getitem__(self, key):
    """Support ``tree[key]`` by delegating to ``get``."""
    value = self.get(key)
    return value

def __setitem__(self, key, value):
    """Support ``tree[key] = value`` by delegating to ``set``."""
    result = self.set(key, value)
    return result

def __delitem__(self, key):
    """Support ``del tree[key]`` by delegating to ``delete``."""
    result = self.delete(key)
    return result

def __contains__(self, key):
    """Support ``key in tree`` by delegating to ``exists``."""
    present = self.exists(key)
    return present