use chia_rs provided KeyId, ValueId, TreeIndex (#19214)
altendky authored Feb 7, 2025
2 parents f83410e + 7d9005d commit 6a86a3d
Showing 5 changed files with 89 additions and 96 deletions.
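At a glance, this commit swaps the locally defined `KeyId`, `ValueId`, and `TreeIndex` newtypes from `chia.data_layer.util.merkle_blob` for the ones exported by `chia_rs.datalayer`. A minimal sketch of the resulting idiom, assuming (as the diffs below suggest) that `TreeIndex` accepts a plain `int` and exposes the wrapped integer via `.raw`:

```python
# Before: the newtypes came from the local pure-Python module.
# from chia.data_layer.util.merkle_blob import KeyId, TreeIndex, ValueId

# After: the same names are provided by the Rust bindings.
from chia_rs.datalayer import KeyId, TreeIndex, ValueId

# No uint32 wrapper is needed any more; a plain int is accepted.
root_index = TreeIndex(0)

# The wrapped integer is reachable via .raw, e.g. for SQL parameters
# (see the data_store.py changes below).
assert isinstance(root_index.raw, int)
```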
3 changes: 2 additions & 1 deletion chia/_tests/core/data_layer/test_data_store.py
@@ -15,6 +15,7 @@

import aiohttp
import pytest
from chia_rs.datalayer import TreeIndex

from chia._tests.core.data_layer.util import Example, add_0123_example, add_01234567_example
from chia._tests.util.misc import BenchmarkRunner, Marks, boolean_datacases, datacases
@@ -41,7 +42,7 @@
from chia.data_layer.data_store import DataStore
from chia.data_layer.download_data import insert_from_delta_file, write_files_for_root
from chia.data_layer.util.benchmark import generate_datastore
from chia.data_layer.util.merkle_blob import MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode, TreeIndex
from chia.data_layer.util.merkle_blob import MerkleBlob, RawInternalMerkleNode, RawLeafMerkleNode
from chia.types.blockchain_format.program import Program
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.byte_types import hexstr_to_bytes
60 changes: 30 additions & 30 deletions chia/_tests/core/data_layer/test_merkle_blob.py
@@ -2,6 +2,7 @@

import hashlib
import itertools
import re
from dataclasses import dataclass
from random import Random
from typing import Generic, Protocol, TypeVar, final
@@ -11,31 +12,30 @@

# TODO: update after resolution in https://github.com/pytest-dev/pytest/issues/7469
from _pytest.fixtures import SubRequest
from chia_rs.datalayer import KeyId, TreeIndex, ValueId

from chia._tests.util.misc import DataCase, Marks, datacases
from chia.data_layer.data_layer_util import InternalNode, Side, internal_hash
from chia.data_layer.util.merkle_blob import (
InvalidIndexError,
KeyId,
KeyOrValueId,
MerkleBlob,
NodeMetadata,
NodeType,
RawInternalMerkleNode,
RawLeafMerkleNode,
RawMerkleNodeProtocol,
TreeIndex,
ValueId,
data_size,
metadata_size,
pack_raw_node,
raw_node_classes,
raw_node_type_to_class,
spacing,
undefined_index,
unpack_raw_node,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.ints import int64, uint32
from chia.util.ints import int64

pytestmark = pytest.mark.data_layer

@@ -125,16 +125,16 @@ def id(self) -> str:
RawNodeFromBlobCase(
raw=RawInternalMerkleNode(
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
left=TreeIndex(uint32(0x24252627)),
right=TreeIndex(uint32(0x28292A2B)),
parent=TreeIndex(0x20212223),
left=TreeIndex(0x24252627),
right=TreeIndex(0x28292A2B),
),
packed=internal_reference_blob,
),
RawNodeFromBlobCase(
raw=RawLeafMerkleNode(
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0x20212223)),
parent=TreeIndex(0x20212223),
key=KeyId(KeyOrValueId(int64(0x2425262728292A2B))),
value=ValueId(KeyOrValueId(int64(0x2C2D2E2F30313233))),
),
@@ -146,7 +146,7 @@ def id(self) -> str:
@datacases(*reference_raw_nodes)
def test_raw_node_from_blob(case: RawNodeFromBlobCase[RawMerkleNodeProtocol]) -> None:
node = unpack_raw_node(
index=TreeIndex(uint32(0)),
index=TreeIndex(0),
metadata=NodeMetadata(type=case.raw.type, dirty=False),
data=case.packed,
)
@@ -171,7 +171,7 @@ def test_merkle_blob_one_leaf_loads() -> None:
blob = bytearray(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == leaf
assert merkle_blob.get_raw_node(TreeIndex(0)) == leaf


def test_merkle_blob_two_leafs_loads() -> None:
@@ -180,18 +180,18 @@ def test_merkle_blob_two_leafs_loads() -> None:
root = RawInternalMerkleNode(
hash=bytes32(range(32)),
parent=None,
left=TreeIndex(uint32(1)),
right=TreeIndex(uint32(2)),
left=TreeIndex(1),
right=TreeIndex(2),
)
left_leaf = RawLeafMerkleNode(
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(int64(0x0405060708090A0B))),
value=ValueId(KeyOrValueId(int64(0x0405060708090A1B))),
)
right_leaf = RawLeafMerkleNode(
hash=bytes32(range(32)),
parent=TreeIndex(uint32(0)),
parent=TreeIndex(0),
key=KeyId(KeyOrValueId(int64(0x1415161718191A1B))),
value=ValueId(KeyOrValueId(int64(0x1415161718191A2B))),
)
@@ -201,20 +201,20 @@ def test_merkle_blob_two_leafs_loads() -> None:
blob.extend(bytes(NodeMetadata(type=NodeType.leaf, dirty=False)) + pack_raw_node(right_leaf))

merkle_blob = MerkleBlob(blob=blob)
assert merkle_blob.get_raw_node(TreeIndex(uint32(0))) == root
assert merkle_blob.get_raw_node(TreeIndex(root.left)) == left_leaf
assert merkle_blob.get_raw_node(TreeIndex(root.right)) == right_leaf
assert merkle_blob.get_raw_node(TreeIndex(0)) == root
assert merkle_blob.get_raw_node(root.left) == left_leaf
assert merkle_blob.get_raw_node(root.right) == right_leaf
assert left_leaf.parent is not None
assert merkle_blob.get_raw_node(TreeIndex(left_leaf.parent)) == root
assert merkle_blob.get_raw_node(left_leaf.parent) == root
assert right_leaf.parent is not None
assert merkle_blob.get_raw_node(TreeIndex(right_leaf.parent)) == root
assert merkle_blob.get_raw_node(right_leaf.parent) == root

assert merkle_blob.get_lineage_with_indexes(TreeIndex(uint32(0))) == [(0, root)]
assert merkle_blob.get_lineage_with_indexes(TreeIndex(0)) == [(TreeIndex(0), root)]
expected: list[tuple[TreeIndex, RawMerkleNodeProtocol]] = [
(TreeIndex(uint32(1)), left_leaf),
(TreeIndex(uint32(0)), root),
(TreeIndex(1), left_leaf),
(TreeIndex(0), root),
]
assert merkle_blob.get_lineage_with_indexes(TreeIndex(root.left)) == expected
assert merkle_blob.get_lineage_with_indexes(root.left) == expected

merkle_blob.calculate_lazy_hashes()
son_hash = bytes32(range(32))
@@ -270,7 +270,7 @@ def test_insert_delete_loads_all_keys() -> None:
hash = generate_hash(seed)
merkle_blob.insert(key, value, hash)
key_index = merkle_blob.key_to_index[key]
lineage = merkle_blob.get_lineage_with_indexes(TreeIndex(key_index))
lineage = merkle_blob.get_lineage_with_indexes(key_index)
assert len(lineage) <= max_height
keys_values[key] = value
if current_num_entries == 0:
@@ -289,7 +289,7 @@ def test_insert_delete_loads_all_keys() -> None:
hash = generate_hash(seed)
merkle_blob_2.upsert(key, value, hash)
key_index = merkle_blob_2.key_to_index[key]
lineage = merkle_blob_2.get_lineage_with_indexes(TreeIndex(key_index))
lineage = merkle_blob_2.get_lineage_with_indexes(key_index)
assert len(lineage) <= max_height
keys_values[key] = value
assert merkle_blob_2.get_keys_values() == keys_values
@@ -362,7 +362,7 @@ def test_proof_of_inclusion_merkle_blob() -> None:
del keys_values[kv_id]

for kv_id in delete_ordering:
with pytest.raises(Exception, match=f"Key {kv_id} not present in the store"):
with pytest.raises(Exception, match=f"Key {re.escape(str(kv_id))} not present in the store"):
merkle_blob.get_proof_of_inclusion(kv_id)

new_keys_values: dict[KeyId, ValueId] = {}
@@ -381,7 +381,7 @@ def test_proof_of_inclusion_merkle_blob() -> None:
assert proof_of_inclusion.valid()


@pytest.mark.parametrize(argnames="index", argvalues=[-1, 1, None])
@pytest.mark.parametrize(argnames="index", argvalues=[TreeIndex(1), undefined_index])
def test_get_raw_node_raises_for_invalid_indexes(index: TreeIndex) -> None:
merkle_blob = MerkleBlob(blob=bytearray())
merkle_blob.insert(
@@ -412,7 +412,7 @@ def test_helper_methods(merkle_blob_type: MerkleBlobCallable) -> None:
merkle_blob.insert(key, value, hash)
assert not merkle_blob.empty()
assert merkle_blob.get_root_hash() is not None
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(uint32(0)))
assert merkle_blob.get_root_hash() == merkle_blob.get_hash_at_index(TreeIndex(0))

merkle_blob.delete(key)
assert merkle_blob.empty()
@@ -469,8 +469,8 @@ def test_get_nodes(merkle_blob_type: MerkleBlobCallable) -> None:
all_nodes = merkle_blob.get_nodes_with_indexes()
for index, node in all_nodes:
if isinstance(node, (RawInternalMerkleNode, chia_rs.datalayer.InternalNode)):
left = merkle_blob.get_raw_node(TreeIndex(node.left))
right = merkle_blob.get_raw_node(TreeIndex(node.right))
left = merkle_blob.get_raw_node(node.left)
right = merkle_blob.get_raw_node(node.right)
assert left.parent == index
assert right.parent == index
assert bytes32(node.hash) == internal_hash(bytes32(left.hash), bytes32(right.hash))
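One consequence visible in this test file is that error-message matching now goes through `re.escape`, presumably because the `chia_rs` key type's string form contains regex metacharacters. A small illustrative sketch; the `expect_missing_key` helper is hypothetical and not part of the diff:

```python
import re

import pytest
from chia_rs.datalayer import KeyId

from chia.data_layer.util.merkle_blob import MerkleBlob


def expect_missing_key(merkle_blob: MerkleBlob, kv_id: KeyId) -> None:
    # str(kv_id) may render as something like "KeyId(5)", whose parentheses are
    # regex metacharacters, so the pattern is escaped before being passed to pytest.
    with pytest.raises(Exception, match=f"Key {re.escape(str(kv_id))} not present in the store"):
        merkle_blob.get_proof_of_inclusion(kv_id)
```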
22 changes: 10 additions & 12 deletions chia/data_layer/data_store.py
@@ -11,6 +11,7 @@
from typing import Any, BinaryIO, Callable, Optional, Union

import aiosqlite
from chia_rs.datalayer import KeyId, TreeIndex, ValueId

from chia.data_layer.data_layer_errors import KeyNotFoundError, MerkleBlobNotFoundError, TreeGenerationIncrementingError
from chia.data_layer.data_layer_util import (
@@ -42,13 +43,10 @@
unspecified,
)
from chia.data_layer.util.merkle_blob import (
KeyId,
KeyOrValueId,
MerkleBlob,
RawInternalMerkleNode,
RawLeafMerkleNode,
TreeIndex,
ValueId,
)
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.batches import to_batches
@@ -232,7 +230,7 @@ async def insert_into_data_store_from_file(
terminal_nodes[node_hash] = (kid, vid)

missing_hashes: list[bytes32] = []
merkle_blob_queries: dict[bytes32, list[int]] = defaultdict(list)
merkle_blob_queries: dict[bytes32, list[TreeIndex]] = defaultdict(list)

for _, (left, right) in internal_nodes.items():
for node_hash in (left, right):
@@ -287,7 +285,7 @@ async def insert_into_data_store_from_file(
for row in rows:
node_hash = bytes32(row["hash"])
root_hash_blob = bytes32(row["root_hash"])
index = row["idx"]
index = TreeIndex(row["idx"])
if node_hash in found_hashes:
raise Exception("Internal error: duplicate node_hash found in nodes table")
merkle_blob_queries[root_hash_blob].append(index)
@@ -491,8 +489,8 @@ async def get_blob_from_kvid(self, kv_id: KeyOrValueId, store_id: bytes32) -> Op
return bytes(row[0])

async def get_terminal_node(self, kid: KeyId, vid: ValueId, store_id: bytes32) -> TerminalNode:
key = await self.get_blob_from_kvid(kid, store_id)
value = await self.get_blob_from_kvid(vid, store_id)
key = await self.get_blob_from_kvid(kid.raw, store_id)
value = await self.get_blob_from_kvid(vid.raw, store_id)
if key is None or value is None:
raise Exception("Cannot find the key/value pair")

@@ -526,8 +524,8 @@ async def add_key_value(self, key: bytes, value: bytes, store_id: bytes32) -> tu
"INSERT OR REPLACE INTO hashes (hash, kid, vid, store_id) VALUES (?, ?, ?, ?)",
(
hash,
kid,
vid,
kid.raw,
vid.raw,
store_id,
),
)
@@ -589,15 +587,15 @@ async def get_existing_hashes(self, node_hashes: list[bytes32], store_id: bytes3
return result

async def add_node_hash(
self, store_id: bytes32, hash: bytes32, root_hash: bytes32, generation: int, index: int
self, store_id: bytes32, hash: bytes32, root_hash: bytes32, generation: int, index: TreeIndex
) -> None:
async with self.db_wrapper.writer() as writer:
await writer.execute(
"""
INSERT INTO nodes(store_id, hash, root_hash, generation, idx)
VALUES (?, ?, ?, ?, ?)
""",
(store_id, hash, root_hash, generation, index),
(store_id, hash, root_hash, generation, index.raw),
)

async def add_node_hashes(self, store_id: bytes32, generation: Optional[int] = None) -> None:
@@ -1102,7 +1100,7 @@ async def get_keys(
kv_ids = merkle_blob.get_keys_values()
keys: list[bytes] = []
for kid in kv_ids.keys():
key = await self.get_blob_from_kvid(kid, store_id)
key = await self.get_blob_from_kvid(kid.raw, store_id)
if key is None:
raise Exception(f"Unknown key corresponding to KeyId: {kid}")
keys.append(key)
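The data_store.py changes establish a consistent boundary: the typed wrappers travel through Python signatures, and only the underlying integers (via `.raw`) are bound as SQLite parameters or reconstructed from rows. A hedged sketch of that pattern; the `store_hash_row` and `index_from_row` helpers below are hypothetical, mirroring the `add_key_value` and node-query hunks above:

```python
from chia_rs.datalayer import KeyId, TreeIndex, ValueId


async def store_hash_row(writer, hash: bytes, kid: KeyId, vid: ValueId, store_id: bytes) -> None:
    # Typed wrappers stay in the Python-facing signature; only the raw integers
    # cross the SQLite boundary as bound parameters.
    await writer.execute(
        "INSERT OR REPLACE INTO hashes (hash, kid, vid, store_id) VALUES (?, ?, ?, ?)",
        (hash, kid.raw, vid.raw, store_id),
    )


def index_from_row(row) -> TreeIndex:
    # Reading back, the raw integer column is re-wrapped into the typed index.
    return TreeIndex(row["idx"])
```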