Skip to content

Commit

Permalink
Merge pull request #1141 from Phala-Network/ordmap
Browse files Browse the repository at this point in the history
TrieDB: use im::OrdMap to reduce memory consumption
  • Loading branch information
kvinwang authored Mar 20, 2023
2 parents 42ffb56 + fabcda2 commit ac65c00
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 65 deletions.
27 changes: 16 additions & 11 deletions crates/phala-trie-storage/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@ pub type StorageCollection = Vec<(StorageKey, Option<StorageValue>)>;
pub type ChildStorageCollection = Vec<(StorageKey, StorageCollection)>;

pub type InMemoryBackend<H> = TrieBackend<MemoryDB<H>, H>;
pub struct TrieStorage<H: Hasher>(InMemoryBackend<H>);
pub struct TrieStorage<H: Hasher>(InMemoryBackend<H>)
where
H::Out: Ord;

impl<H: Hasher> Default for TrieStorage<H>
where
H::Out: Codec,
H::Out: Codec + Ord,
{
fn default() -> Self {
Self(TrieBackendBuilder::new(Default::default(), Default::default()).build())
Expand All @@ -44,7 +46,7 @@ pub fn load_trie_backend<H: Hasher>(
pairs: impl Iterator<Item = (impl AsRef<[u8]>, impl AsRef<[u8]>)>,
) -> TrieBackend<MemoryDB<H>, H>
where
H::Out: Codec,
H::Out: Codec + Ord,
{
let mut root = Default::default();
let mut mdb = Default::default();
Expand All @@ -65,25 +67,28 @@ pub fn serialize_trie_backend<H: Hasher, S>(
serializer: S,
) -> Result<S::Ok, S::Error>
where
H::Out: Codec + Serialize,
H::Out: Codec + Serialize + Ord,
S: Serializer,
{
let root = trie.root();
let kvs: im::HashMap<_, _> = trie.backend_storage().clone().drain();
(root, kvs).serialize(serializer)
let kvs = trie.backend_storage();
(root, ser::SerAsSeq(kvs)).serialize(serializer)
}

#[cfg(feature = "serde")]
pub fn deserialize_trie_backend<'de, H: Hasher, De>(
deserializer: De,
) -> Result<TrieBackend<MemoryDB<H>, H>, De::Error>
where
H::Out: Codec + Deserialize<'de>,
H::Out: Codec + Deserialize<'de> + Ord,
De: Deserializer<'de>,
{
let (root, kvs): (H::Out, im::HashMap<_, (Vec<u8>, i32)>) =
Deserialize::deserialize(deserializer)?;
let mdb = MemoryDB::from_inner(kvs);
let (root, kvs): (H::Out, Vec<(Vec<u8>, i32)>) = Deserialize::deserialize(deserializer)?;
let mdb = MemoryDB::from_inner(
kvs.into_iter()
.map(|(data, rc)| (H::hash(data.as_ref()), (data, rc)))
.collect(),
);
let backend = TrieBackendBuilder::new(mdb, root).build();
Ok(backend)
}
Expand All @@ -92,7 +97,7 @@ pub fn clone_trie_backend<H: Hasher>(
trie: &TrieBackend<MemoryDB<H>, H>,
) -> TrieBackend<MemoryDB<H>, H>
where
H::Out: Codec,
H::Out: Codec + Ord,
{
let root = trie.root();
let mdb = trie.backend_storage().clone();
Expand Down
99 changes: 47 additions & 52 deletions crates/phala-trie-storage/src/memdb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
use hash_db::{
AsHashDB, AsPlainDB, HashDB, HashDBRef, Hasher as KeyHasher, PlainDB, PlainDBRef, Prefix,
};
use im::{hashmap::Entry, HashMap};
use std::{borrow::Borrow, cmp::Eq, hash, marker::PhantomData, mem};
pub(crate) use im::ordmap::{Entry, OrdMap as Map};
use std::{borrow::Borrow, cmp::Eq, hash, marker::PhantomData};

use sp_state_machine::{backend::Consolidate, DefaultError, TrieBackendStorage};
use trie_db::DBValue;
Expand All @@ -28,13 +28,19 @@ impl<T: std::fmt::Debug> MaybeDebug for T {}

pub type GenericMemoryDB<H> = MemoryDB<H, HashKey<H>, trie_db::DBValue>;

impl<H: KeyHasher> Consolidate for GenericMemoryDB<H> {
impl<H: KeyHasher> Consolidate for GenericMemoryDB<H>
where
H::Out: Ord,
{
fn consolidate(&mut self, other: Self) {
MemoryDB::consolidate(self, other)
}
}

impl<H: KeyHasher> TrieBackendStorage<H> for GenericMemoryDB<H> {
impl<H: KeyHasher> TrieBackendStorage<H> for GenericMemoryDB<H>
where
H::Out: Ord,
{
type Overlay = Self;

fn get(
Expand All @@ -58,7 +64,7 @@ where
H: KeyHasher,
KF: KeyFunction<H>,
{
data: HashMap<KF::Key, (T, i32)>,
data: Map<KF::Key, (T, i32)>,
hashed_null_node: H::Out,
null_node_data: T,
_kf: PhantomData<KF>,
Expand All @@ -80,34 +86,6 @@ where
}
}

/// Structural equality for `MemoryDB`: two databases are equal when they hold
/// exactly the same keys and every key maps to the same
/// `(value, reference count)` pair.
///
/// Only the `data` map is compared; `hashed_null_node` and `null_node_data`
/// do not take part in the comparison.
impl<H, KF, T> PartialEq<MemoryDB<H, KF, T>> for MemoryDB<H, KF, T>
where
    H: KeyHasher,
    KF: KeyFunction<H>,
    <KF as KeyFunction<H>>::Key: Eq + MaybeDebug,
    T: Eq + MaybeDebug,
{
    fn eq(&self, other: &MemoryDB<H, KF, T>) -> bool {
        // Comparing sizes first makes the relation symmetric. Without it, a
        // database compared equal to any superset of itself (`a == b` could
        // hold while `b == a` did not), which violates the `PartialEq`/`Eq`
        // contract.
        self.data.len() == other.data.len()
            && self
                .data
                .iter()
                .all(|(key, entry)| other.data.get(key) == Some(entry))
    }
}

/// Marker impl declaring `MemoryDB` equality a full equivalence relation.
///
/// It adds no logic of its own; the comparison itself is whatever the
/// `PartialEq` implementation for `MemoryDB` does, under the same bounds.
impl<H, KF, T> Eq for MemoryDB<H, KF, T>
where
H: KeyHasher,
KF: KeyFunction<H>,
<KF as KeyFunction<H>>::Key: Eq + MaybeDebug,
T: Eq + MaybeDebug,
{
}

pub trait KeyFunction<H: KeyHasher> {
type Key: Send + Sync + Clone + hash::Hash + Eq;

Expand Down Expand Up @@ -181,6 +159,7 @@ where
H: KeyHasher,
T: for<'a> From<&'a [u8]> + Clone,
KF: KeyFunction<H>,
KF::Key: Ord,
{
fn default() -> Self {
Self::from_null_node(&[0u8][..], [0u8][..].into())
Expand All @@ -193,6 +172,7 @@ where
H: KeyHasher,
T: Default + Clone,
KF: KeyFunction<H>,
KF::Key: Ord,
{
/// Remove an element and delete it from storage if reference count reaches zero.
/// If the value was purged, return the old value.
Expand Down Expand Up @@ -233,19 +213,20 @@ where
H: KeyHasher,
T: for<'a> From<&'a [u8]> + Clone,
KF: KeyFunction<H>,
KF::Key: Ord,
{
/// Create a new `MemoryDB` from a given null key/data
pub fn from_null_node(null_key: &[u8], null_node_data: T) -> Self {
MemoryDB {
data: HashMap::default(),
data: Map::default(),
hashed_null_node: H::hash(null_key),
null_node_data,
_kf: Default::default(),
}
}

/// Create a new `MemoryDB` from a given inner map of `key -> (value, reference count)`.
pub fn from_inner(data: HashMap<KF::Key, (T, i32)>) -> Self {
pub fn from_inner(data: Map<KF::Key, (T, i32)>) -> Self {
MemoryDB {
data,
..Default::default()
Expand All @@ -271,16 +252,6 @@ where
self.data.clear();
}

/// Drop every entry whose reference count is exactly zero.
///
/// Entries with a non-zero count (including negative ones) are kept.
pub fn purge(&mut self) {
    self.data.retain(|_key, entry| entry.1 != 0);
}

/// Take the internal key-value map out of the database and return it,
/// leaving the database's `data` empty.
///
/// Only `data` is touched; `hashed_null_node` and `null_node_data` keep
/// their values.
pub fn drain(&mut self) -> HashMap<KF::Key, (T, i32)> {
// `mem::take` moves the map out and leaves `Default::default()` in its place.
mem::take(&mut self.data)
}

/// Grab the raw information associated with a key. Returns None if the key
/// doesn't exist.
///
Expand All @@ -296,8 +267,8 @@ where
}

/// Consolidate all the entries of `other` into `self`.
pub fn consolidate(&mut self, mut other: Self) {
for (key, (value, rc)) in other.drain() {
pub fn consolidate(&mut self, other: Self) {
for (key, (value, rc)) in other.data {
match self.data.entry(key) {
Entry::Occupied(mut entry) => {
if entry.get().1 < 0 {
Expand All @@ -317,7 +288,7 @@ where
}

/// Get the keys in the database together with number of underlying references.
pub fn keys(&self) -> HashMap<KF::Key, i32> {
pub fn keys(&self) -> Map<KF::Key, i32> {
self.data
.iter()
.filter_map(|(k, v)| {
Expand All @@ -331,12 +302,35 @@ where
}
}

impl<H, KF, T> MemoryDB<H, KF, T>
where
    H: KeyHasher,
    KF: KeyFunction<H>,
    KF::Key: Ord,
    T: for<'a> From<&'a [u8]> + Clone + serde::Serialize,
{
    /// Serialize the stored entries as a flat sequence of
    /// `(value, reference count)` pairs, in the map's iteration order.
    ///
    /// The sequence length is announced up front as `self.data.len()`.
    /// Keys are not written.
    ///
    /// NOTE(review): because keys are dropped, a reader must rebuild each key
    /// from the value alone (`deserialize_trie_backend` uses `H::hash(value)`).
    /// Confirm this round-trips for every entry that can be present here,
    /// e.g. entries left with a negative count by `remove`.
    pub fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeSeq;

        let mut seq = serializer.serialize_seq(Some(self.data.len()))?;
        self.data
            .iter()
            .try_for_each(|(_key, entry)| seq.serialize_element(entry))?;
        seq.end()
    }
}

impl<H, KF, T> PlainDB<H::Out, T> for MemoryDB<H, KF, T>
where
H: KeyHasher,
T: Default + PartialEq<T> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: Send + Sync + KeyFunction<H>,
KF::Key: Borrow<[u8]> + for<'a> From<&'a [u8]>,
KF::Key: Ord,
{
fn get(&self, key: &H::Out) -> Option<T> {
match self.data.get(key.as_ref()) {
Expand Down Expand Up @@ -390,6 +384,7 @@ where
T: Default + PartialEq<T> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: Send + Sync + KeyFunction<H>,
KF::Key: Borrow<[u8]> + for<'a> From<&'a [u8]>,
KF::Key: Ord,
{
fn get(&self, key: &H::Out) -> Option<T> {
PlainDB::get(self, key)
Expand All @@ -404,6 +399,7 @@ where
H: KeyHasher,
T: Default + PartialEq<T> + AsRef<[u8]> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: KeyFunction<H> + Send + Sync,
KF::Key: Ord,
{
fn get(&self, key: &H::Out, prefix: Prefix) -> Option<T> {
if key == &self.hashed_null_node {
Expand Down Expand Up @@ -486,6 +482,7 @@ where
H: KeyHasher,
T: Default + PartialEq<T> + AsRef<[u8]> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: KeyFunction<H> + Send + Sync,
KF::Key: Ord,
{
fn get(&self, key: &H::Out, prefix: Prefix) -> Option<T> {
HashDB::get(self, key, prefix)
Expand All @@ -501,6 +498,7 @@ where
T: Default + PartialEq<T> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: KeyFunction<H> + Send + Sync,
KF::Key: Borrow<[u8]> + for<'a> From<&'a [u8]>,
KF::Key: Ord,
{
fn as_plain_db(&self) -> &dyn PlainDB<H::Out, T> {
self
Expand All @@ -515,6 +513,7 @@ where
H: KeyHasher,
T: Default + PartialEq<T> + AsRef<[u8]> + for<'a> From<&'a [u8]> + Clone + Send + Sync,
KF: KeyFunction<H> + Send + Sync,
KF::Key: Ord,
{
fn as_hash_db(&self) -> &dyn HashDB<H, T> {
self
Expand All @@ -538,12 +537,8 @@ mod tests {
let mut m = MemoryDB::<KeccakHasher, HashKey<_>, Vec<u8>>::default();
m.remove(&hello_key, EMPTY_PREFIX);
assert_eq!(m.raw(&hello_key, EMPTY_PREFIX).unwrap().1, -1);
m.purge();
assert_eq!(m.raw(&hello_key, EMPTY_PREFIX).unwrap().1, -1);
m.insert(EMPTY_PREFIX, hello_bytes);
assert_eq!(m.raw(&hello_key, EMPTY_PREFIX), None);
m.purge();
assert_eq!(m.raw(&hello_key, EMPTY_PREFIX), None);

let mut m = MemoryDB::<KeccakHasher, HashKey<_>, Vec<u8>>::default();
assert!(m.remove_and_purge(&hello_key, EMPTY_PREFIX).is_none());
Expand Down
21 changes: 19 additions & 2 deletions crates/phala-trie-storage/src/ser.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use serde::{Deserialize, Serialize};
use parity_scale_codec::{Encode, Decode};
use hash_db::Hasher;
use parity_scale_codec::{Decode, Encode};
use scale_info::TypeInfo;
use serde::{Deserialize, Serialize};

use super::{ChildStorageCollection, StorageCollection};

Expand All @@ -16,3 +17,19 @@ pub struct StorageChanges {
pub struct StorageData {
pub inner: Vec<(Vec<u8>, Vec<u8>)>,
}

/// Newtype around a borrowed `crate::MemoryDB<H>`.
///
/// Its `Serialize` implementation forwards to the wrapped database's own
/// `serialize` method, so the database can be embedded in a larger
/// serializable value (for example a tuple) without being cloned.
pub struct SerAsSeq<'a, H: Hasher>(pub &'a crate::MemoryDB<H>)
where
H::Out: Ord;

impl<H: Hasher> Serialize for SerAsSeq<'_, H>
where
H::Out: Ord,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.serialize(serializer)
}
}

0 comments on commit ac65c00

Please sign in to comment.