diff --git a/assembly/__tests__/avlTreeContract.ts b/assembly/__tests__/avlTreeContract.ts new file mode 100644 index 00000000..edcc93a2 --- /dev/null +++ b/assembly/__tests__/avlTreeContract.ts @@ -0,0 +1,39 @@ +import { AVLTree } from "../runtime"; + +const tree = new AVLTree("tree"); + +export function insert(key: u32, value: u32): void { + tree.insert(key, value); +} + +export function remove(key: u32): void { + tree.remove(key); +} + +export function has(key: u32): bool { + return tree.has(key); +} + +export function getSome(key: u32): u32 { + return tree.getSome(key); +} + +export function keys(): u32[] { + return tree.keys(u32.MIN_VALUE, u32.MAX_VALUE); +} + +export function values(): u32[] { + return tree.values(u32.MIN_VALUE, u32.MAX_VALUE); +} + +export function isBalanced(): bool { + return tree.isBalanced(); +} + +export function height(): u32 { + return tree.height; +} + +export function size(): u32 { + return tree.size; +} \ No newline at end of file diff --git a/assembly/__tests__/runtime/avl-tree.spec.ts b/assembly/__tests__/runtime/avl-tree.spec.ts new file mode 100644 index 00000000..74f15b9b --- /dev/null +++ b/assembly/__tests__/runtime/avl-tree.spec.ts @@ -0,0 +1,607 @@ +import { AVLTree, MapEntry } from "../../runtime"; +import { RNG } from "../../runtime/math"; +import { Context } from "../../vm"; + +let tree: AVLTree; +let _closure_var1: u32; + +let _closure_rng: RNG; +function random(n: i32): u32[] { + const a = new Array(n); + _closure_rng = new RNG(2 * n); + return a.map((_): u32 => _closure_rng.next()); +} + +function range(start: u32, end: u32): u32[] { + return (new Array(end - start)).map((_, i) => i); +} + +function maxTreeHeight(n: f64): u32 { + // From near-sdk-rs TreeMap: + // h <= C * log2(n + D) + B + // where: + // C =~ 1.440, D =~ 1.065, B =~ 0.328 + // (source: https://en.wikipedia.org/wiki/AVL_tree) + const B: f64 = -0.328; + const C: f64 = 1.440; + const D: f64 = 1.065; + + const h = C * Math.log2(n + D) + B; + return Math.ceil(h) as u32; +} + +// Convenience method for tests that insert then remove some values +function insertThenRemove(t: AVLTree, keysToInsert: u32[], keysToRemove: u32[]): void { + const map = new Map(); + + insertKeys(t, keysToInsert, map); + removeKeys(t, keysToRemove, map); +}; + +function insertKeys(t: AVLTree, keysToInsert: u32[], map: Map | null = null): void { + for (let i = 0; i < keysToInsert.length; ++i) { + const key = keysToInsert[i]; + expect(t.has(key)).toBeFalsy("tree.has() should return false for key that has not been inserted yet. Are duplicate keys being inserted?"); + + t.insert(key, i); + expect(t.getSome(key)).toStrictEqual(i); + + if (map) { + map.set(key, i); + expect(t.getSome(key)).toStrictEqual(map.get(key)); + } + } +} + +function removeKeys(t: AVLTree, keysToRemove: u32[], map: Map | null = null): void { + for (let i = 0; i < keysToRemove.length; ++i) { + const key = keysToRemove[i]; + if (map && map.has(key)) { + expect(t.getSome(key)).toStrictEqual(map.get(key)); + map.delete(key); + } + + t.remove(key); + expect(t.has(key)).toBeFalsy("tree.has() should return false for removed key"); + } +} + +function generateRandomTree(t: AVLTree, n: u32): Map { + const map = new Map(); + const keysToInsert = random(2 * n); + const keysToRemove: u32[] = []; + for (let i = 0; i < i32(n); ++i) { + keysToRemove.push(keysToInsert[i]); + } + + insertKeys(t, keysToInsert, map); + removeKeys(t, keysToRemove, map); + + return map; +} + +describe("AVLTrees should handle", () => { + + beforeAll(() => { + tree = new AVLTree("tree1"); + }); + + afterEach(() => { + tree.clear(); + }); + + it("adds key-value pairs", () => { + const key = 1; + const value = 2; + + tree.set(key, value); + + expect(tree.has(key)).toBeTruthy("The tree should have the key"); + expect(tree.containsKey(key)).toBeTruthy("The tree should contain the key"); + expect(tree.getSome(key)).toStrictEqual(value); + }); + + it("checks for non-existent keys", () => { + const key = 1; + + expect(tree.has(key)).toBeFalsy("tree should not have the key"); + expect(tree.containsKey(key)).toBeFalsy("tree should not contain the key"); + }); + + throws("if attempting to get a non-existent key", () => { + const key = 1; + + tree.getSome(key); + }); + + it("is empty", () => { + const key = 42; + _closure_var1 = key; + + expect(tree.size).toStrictEqual(0); + expect(tree.height).toStrictEqual(0); + expect(tree.has(key)).toBeFalsy("empty tree should not have the key"); + expect(tree.containsKey(key)).toBeFalsy("empty tree should not have the key"); + expect(() => { tree.min() }).toThrow("min() should throw for empty tree"); + expect(() => { tree.max() }).toThrow("max() should throw for empty tree"); + expect(() => { tree.lower(_closure_var1) }).toThrow("lower() should throw for empty tree"); + expect(() => { tree.lower(_closure_var1) }).toThrow("higher() should throw for empty tree"); + }); + + it("rotates left twice when inserting 3 keys in decreasing order", () => { + expect(tree.height).toStrictEqual(0); + + tree.insert(3, 3); + expect(tree.height).toStrictEqual(1); + + tree.insert(2, 2); + expect(tree.height).toStrictEqual(2); + + tree.insert(1, 1); + expect(tree.height).toStrictEqual(2); + + expect(tree.rootKey).toStrictEqual(2); + }); + + it("rotates right twice when inserting 3 keys in increasing order", () => { + expect(tree.height).toStrictEqual(0); + + tree.insert(1, 1); + expect(tree.height).toStrictEqual(1); + + tree.insert(2, 2); + expect(tree.height).toStrictEqual(2); + + tree.insert(3, 3); + expect(tree.height).toStrictEqual(2); + + expect(tree.rootKey).toStrictEqual(2); + }); + + it("sets and gets n key-value pairs in ascending order", () => { + const n: u32 = 30; + const cases: u32[] = range(0, n*2); + + let counter = 0; + for (let i = 0; i < cases.length; ++i) { + const k = cases[i]; + if (k % 2 === 0) { + counter += 1; + tree.insert(k, counter); + } + } + + counter = 0; + for (let i = 0; i < cases.length; ++i) { + const k = cases[i]; + if (k % 2 === 0) { + counter += 1; + expect(tree.getSome(k)).toStrictEqual(counter); + } else { + expect(tree.has(k)).toBeFalsy(`tree should not contain key ${k}`); + } + } + + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(n)); + }); + + it("sets and gets n key-value pairs in descending order", () => { + const n: u32 = 30; + const cases: u32[] = range(0, n*2).reverse(); + + let counter = 0; + for (let i = 0; i < cases.length; ++i) { + const k = cases[i]; + if (k % 2 === 0) { + counter += 1; + tree.insert(k, counter); + } + } + + counter = 0; + for (let i = 0; i < cases.length; ++i) { + const k = cases[i]; + if (k % 2 === 0) { + counter += 1; + expect(tree.getSome(k)).toStrictEqual(counter); + } else { + expect(tree.has(k)).toBeFalsy(`tree should not contain key ${k}`); + } + } + + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(n)); + }); + + it("sets and gets n random key-value pairs", () => { + Context.setPrepaid_gas(u64.MAX_VALUE); + // TODO setup free gas env to prevent gas exceeded error, and test larger trees + range(1, 8).forEach(k => { // tree size is 2^(k-1) + const n = 1 << k; + const input: u32[] = random(n); + + input.forEach(x => { + tree.insert(x, 42); + }); + + input.forEach(x => { + expect(tree.getSome(x)).toStrictEqual(42); + }); + + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(n)); + + tree.clear(); + }); + }); + + it("gets the minimum key", () => { + const n: u32 = 30; + const keys = random(n); + + keys.forEach(x => { + tree.insert(x, 1); + }); + + const min = (keys.sort(), keys[0]); + expect(tree.min()).toStrictEqual(min); + }); + + it("gets the maximum key", () => { + const n: u32 = 30; + const keys = random(n); + + keys.forEach(x => { + tree.insert(x, 1); + }); + + const max = (keys.sort(), keys[keys.length-1]); + expect(tree.max()).toStrictEqual(max); + }); + + it("gets the key lower than the given key", () => { + const keys: u32[] = [10, 20, 30, 40, 50]; + + keys.forEach(x => { + tree.insert(x, 1); + }); + + expect(() => { tree.lower(5) }).toThrow("5 is lower than tree.min(), which is 10"); + expect(() => { tree.lower(10) }).toThrow("10 is equal to tree.min(), which is 10"); + expect(tree.lower(11)).toStrictEqual(10); + expect(tree.lower(20)).toStrictEqual(10); + expect(tree.lower(49)).toStrictEqual(40); + expect(tree.lower(50)).toStrictEqual(40); + expect(tree.lower(51)).toStrictEqual(50); + }); + + it("gets the key higher than the given key", () => { + const keys: u32[] = [10, 20, 30, 40, 50]; + + keys.forEach(x => { + tree.insert(x, 1); + }); + + expect(tree.higher(5)).toStrictEqual(10); + expect(tree.higher(10)).toStrictEqual(20); + expect(tree.higher(11)).toStrictEqual(20); + expect(tree.higher(20)).toStrictEqual(30); + expect(tree.higher(49)).toStrictEqual(50); + expect(() => { tree.higher(50) }).toThrow("50 is equal to tree.max(), which is 50"); + expect(() => { tree.higher(51) }).toThrow("51 is greater than tree.max(), which is 50"); + }); + + it("gets the key lower than or equal to the given key", () => { + const keys: u32[] = [10, 20, 30, 40, 50]; + + keys.forEach(x => { + tree.insert(x, 1); + }); + + expect(() => { tree.floorKey(5) }).toThrow("5 is lower than tree.min(), which is 10"); + expect(tree.floorKey(10)).toStrictEqual(10); + expect(tree.floorKey(11)).toStrictEqual(10); + expect(tree.floorKey(20)).toStrictEqual(20); + expect(tree.floorKey(49)).toStrictEqual(40); + expect(tree.floorKey(50)).toStrictEqual(50); + expect(tree.floorKey(51)).toStrictEqual(50); + }); + + it("gets the key greater than or equal to the given key", () => { + const keys: u32[] = [10, 20, 30, 40, 50]; + + keys.forEach(x => { + tree.insert(x, 1); + }); + + expect(tree.ceilKey(5)).toStrictEqual(10); + expect(tree.ceilKey(10)).toStrictEqual(10); + expect(tree.ceilKey(11)).toStrictEqual(20); + expect(tree.ceilKey(20)).toStrictEqual(20); + expect(tree.ceilKey(49)).toStrictEqual(50); + expect(tree.ceilKey(50)).toStrictEqual(50); + expect(() => { tree.ceilKey(51) }).toThrow("51 is greater than tree.max(), which is 50"); + }); + + it("removes 1 key", () => { + const key = 1; + const value = 2; + + tree.insert(key, value); + expect(tree.getSome(key)).toStrictEqual(value); + expect(tree.size).toStrictEqual(1); + + tree.remove(key); + expect(tree.has(key)).toBeFalsy(`tree should not contain key ${key}`); + expect(tree.size).toStrictEqual(0); + }); + + it("removes non-existent key", () => { + const key = 1; + const value = 2; + + tree.insert(key, value); + expect(tree.getSome(key)).toStrictEqual(value); + expect(tree.size).toStrictEqual(1); + + tree.remove(value); + expect(tree.getSome(key)).toStrictEqual(value); + expect(tree.size).toStrictEqual(1); + }); + + it("removes 3 keys in descending order", () => { + const keys: u32[] = [3, 2, 1]; + insertThenRemove(tree, keys, keys); + expect(tree.size).toStrictEqual(0); + }); + + it("removes 3 keys in ascending order", () => { + const keys: u32[] = [1, 2, 3]; + insertThenRemove(tree, keys, keys); + expect(tree.size).toStrictEqual(0); + }); + + it("removes 7 random keys", () => { + const keys: u32[] = [ + 2104297040, + 552624607, + 4269683389, + 3382615941, + 155419892, + 4102023417, + 1795725075 + ]; + insertThenRemove(tree, keys, keys); + expect(tree.size).toStrictEqual(0); + }); + + // test_remove_7_regression_2() + + it("removes 9 random keys", () => { + const keys: u32[] = [ + 1186903464, + 506371929, + 1738679820, + 1883936615, + 1815331350, + 1512669683, + 3581743264, + 1396738166, + 1902061760 + ]; + insertThenRemove(tree, keys, keys); + expect(tree.size).toStrictEqual(0); + }); + + it("removes 20 random keys", () => { + const keys: u32[] = [ + 552517392, 3638992158, 1015727752, 2500937532, 638716734, + 586360620, 2476692174, 1425948996, 3608478547, 757735878, + 2709959928, 2092169539, 3620770200, 783020918, 1986928932, + 200210441, 1972255302, 533239929, 497054557, 2137924638 + ]; + insertThenRemove(tree, keys, keys); + expect(tree.size).toStrictEqual(0); + }); + + // test_remove_7_regression() + + it("inserts 8 keys then removes 4 keys", () => { + const keysToInsert: u32[] = [882, 398, 161, 76]; + const keysToRemove: u32[] = [242, 687, 860, 811]; + + insertThenRemove(tree, keysToInsert.concat(keysToRemove), keysToRemove); + + expect(tree.size).toStrictEqual(keysToInsert.length); + + keysToInsert.forEach((key, i) => { + expect(tree.getSome(key)).toStrictEqual(i); + }); + }); + + it("removes n random keys", () => { + const n: u32 = 20; + const keys = random(n); + + const set = new Set(); + + for (let i = 0; i < keys.length; ++i) { + const key = keys[i]; + tree.insert(key, i); + set.add(key); + } + + expect(tree.size).toStrictEqual(set.size); + + keys.forEach((key, i) => { + expect(tree.getSome(key)).toStrictEqual(i); + tree.remove(key); + expect(tree.has(key)).toBeFalsy(`tree should not contain key ${key}`); + }); + + expect(tree.size).toStrictEqual(0); + }); + + it("removes the root of the tree", () => { + tree.insert(2, 1); + tree.insert(3, 1); + tree.insert(1, 1); + tree.insert(4, 1); + + expect(tree.rootKey).toStrictEqual(2); + tree.remove(2); + expect(tree.rootKey).toStrictEqual(3); + + expect(tree.getSome(1)).toStrictEqual(1); + expect(() => { tree.getSome(2) }).toThrow("tree should throw when getting removed key (root of the tree)"); + expect(tree.getSome(3)).toStrictEqual(1); + expect(tree.getSome(4)).toStrictEqual(1); + }); + + it("inserts 2 keys then removes 2 keys", () => { + const keysToInsert: u32[] = [11760225, 611327897]; + const keysToRemove: u32[] = [2982517385, 1833990072]; + + insertThenRemove(tree, keysToInsert, keysToRemove); + + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(tree.size)); + }); + + it("inserts n duplicate keys", () => { + range(0, 30).forEach((key, i) => { + tree.insert(key, i); + tree.insert(42, i); + }); + + expect(tree.getSome(42)).toStrictEqual(29); // most recent value inserted for key 42 + expect(tree.size).toStrictEqual(31); + }); + + it("inserts 2n keys then removes n random keys", () => { + range(1, 4).forEach(k => { + const set = new Set(); + + const n = 1 << k; + const keysToInsert = random(n); + const keysToRemove = random(n); + const allKeys = keysToInsert.concat(keysToRemove); + + insertThenRemove(tree, allKeys, keysToRemove); + for (let i = 0; i < allKeys.length; ++i) { + const key = allKeys[i]; + set.add(key); + } + + for (let i = 0; i < keysToRemove.length; ++i) { + const key = allKeys[i]; + set.delete(key); + } + + expect(tree.size).toStrictEqual(set.size); + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(n)); + + tree.clear(); + expect(tree.size).toStrictEqual(0); + }); + }); + + it("does nothing when removing while empty", () => { + expect(tree.size).toStrictEqual(0); + tree.remove(1); + expect(tree.size).toStrictEqual(0); + }); + + it("returns an equivalent array", () => { + tree.insert(1, 41); + tree.insert(2, 42); + tree.insert(3, 43); + + const a = [ + new MapEntry(1, 41), + new MapEntry(2, 42), + new MapEntry(3, 43) + ]; + expect(tree.entries(1, 4)).toStrictEqual(a); + expect(tree.entries(1, 3, true)).toStrictEqual(a); + }); + + it("returns an empty array when empty", () => { + expect(tree.entries(0, 0)).toStrictEqual([]); + }); + + it("returns a range of values for a given start key and end key", () => { + const keys = [10, 20, 30, 40, 50, 45, 35, 25, 15, 5]; + const values = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + for (let i = 0; i < keys.length; ++i) { + const key = keys[i]; + tree.insert(key, values[i]); + } + + expect(tree.values(20, 30)).toStrictEqual([2, 8]); + expect(tree.values(11, 41)).toStrictEqual([9, 2, 8, 3, 7, 4]); + expect(tree.values(20, 41)).toStrictEqual([2, 8, 3, 7, 4]); + expect(tree.values(21, 45)).toStrictEqual([8, 3, 7, 4]); + expect(tree.values(26, 30)).toStrictEqual([]); + expect(tree.values(25, 25)).toStrictEqual([8]); + expect(tree.values(26, 25)).toStrictEqual([]); + expect(tree.values(40, 50)).toStrictEqual([4, 6]); + expect(tree.values(40, 51)).toStrictEqual([4, 6, 5]); + expect(tree.values(4, 5)).toStrictEqual([]); + expect(tree.values(5, 51)).toStrictEqual([10, 1, 9, 2, 8, 3, 7, 4, 6, 5]); + }); + + it("returns a range of values for a given start key and inclusive end key", () => { + const keys = [10, 20, 30, 40, 50, 45, 35, 25, 15, 5]; + const values = [ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; + + for (let i = 0; i < keys.length; ++i) { + const key = keys[i]; + tree.insert(key, values[i]); + } + + expect(tree.values(20, 30, true)).toStrictEqual([2, 8, 3]); + expect(tree.values(11, 41, true)).toStrictEqual([9, 2, 8, 3, 7, 4]); + expect(tree.values(20, 41, true)).toStrictEqual([2, 8, 3, 7, 4]); + expect(tree.values(21, 45, true)).toStrictEqual([8, 3, 7, 4, 6]); + expect(tree.values(26, 30, true)).toStrictEqual([3]); + expect(tree.values(25, 25, true)).toStrictEqual([8]); + expect(tree.values(26, 25, true)).toStrictEqual([]); + expect(tree.values(40, 50, true)).toStrictEqual([4, 6, 5]); + expect(tree.values(40, 51, true)).toStrictEqual([4, 6, 5]); + expect(tree.values(4, 5, true)).toStrictEqual([10]); + expect(tree.values(5, 51, true)).toStrictEqual([10, 1, 9, 2, 8, 3, 7, 4, 6, 5]); + }); + + it("remains balanced after some insertions and deletions", () => { + const keysToInsert: u32[] = [2, 3, 4]; + const keysToRemove: u32[] = [0, 0, 0, 1]; + + insertThenRemove(tree, keysToInsert, keysToRemove); + expect(tree.isBalanced()).toBeTruthy(); + }); + + it("remains balanced after more insertions and deletions", () => { + const keysToInsert: u32[] = [1, 2, 0, 3, 5, 6]; + const keysToRemove: u32[] = [0, 0, 0, 3, 5, 6, 7, 4]; + + insertThenRemove(tree, keysToInsert, keysToRemove); + expect(tree.isBalanced()).toBeTruthy(); + }); + + it("remains balanced and sorted after 2n insertions and n deletions", () => { + Context.setPrepaid_gas(u64.MAX_VALUE); + // TODO setup free gas env to prevent gas exceeded error, and test larger trees + const n: u32 = 33; + const map = generateRandomTree(tree, n); + const sortedKeys: u32[] = map.keys().sort(); + const sortedValues: u32[] = []; + for (let i = 0; i < sortedKeys.length; ++i) { + sortedValues.push(map.get(sortedKeys[i])); + } + + expect(tree.size).toStrictEqual(n); + expect(tree.isBalanced()).toBeTruthy(); + expect(tree.height).toBeLessThanOrEqual(maxTreeHeight(n)); + expect(tree.keys(u32.MIN_VALUE, u32.MAX_VALUE)).toStrictEqual(sortedKeys); + expect(tree.values(u32.MIN_VALUE, u32.MAX_VALUE)).toStrictEqual(sortedValues); + }); +}); diff --git a/assembly/runtime/collections/avlTree.ts b/assembly/runtime/collections/avlTree.ts new file mode 100644 index 00000000..6c956ea4 --- /dev/null +++ b/assembly/runtime/collections/avlTree.ts @@ -0,0 +1,658 @@ +import { PersistentMap, PersistentVector, collections } from "../collections"; +import { MapEntry } from "./util"; +import { storage } from "../storage"; + +@nearBindgen +class Nullable { + constructor(public val: T) { + this.val = val; + } +} + +type NodeId = i32; + +class ChildParentPair { + constructor(public child: AVLTreeNode, public parent: AVLTreeNode) { + this.child = child; + this.parent = parent; + } +} + +@nearBindgen +class AVLTreeNode { + constructor( + public id: NodeId, + public key: K, + public left: Nullable | null = null, + public right: Nullable | null = null, + public height: u32 = 1 + ) { + this.id = id; + this.key = key; + this.left = left; + this.right = right; + this.height = height; + } +} + +@nearBindgen +export class AVLTree { + private _elementPrefix: string; + private _val: PersistentMap; + private _tree: PersistentVector>; + private _rootId: Nullable | null; + + /** + * A string name is used as a prefix for writing keys to storage. + */ + constructor(name: string) { + this._elementPrefix = name + collections._KEY_ELEMENT_SUFFIX; + this._val = new PersistentMap(this._elementPrefix + "val"); + this._tree = new PersistentVector>(this._elementPrefix + "tree"); + this._rootId = storage.get>(this._elementPrefix + "root"); + } + + /** + * @returns Number of elements in the tree. + */ + get size(): u32 { + return this._tree.length; + } + // alias to match rust sdk + get len(): u32 { + return this.size; + } + + /** + * @returns Height of the tree. + */ + get height(): u32 { + return this.nodeHeight(this.rootId); + } + + /** + * @returns Whether the key is present in the tree. + */ + has(key: K): bool { + return this._val.contains(key); + } + // alias to match rust sdk + containsKey(key: K): bool { + return this.has(key); + } + + /** + * If key is not present in the tree, a new node is added with the value. + * Otherwise update the node with the new value. + * + * @param key Key of the element. + * @param value The new value of the element. + */ + set(key: K, value: V): void { + if (!this.has(key)) { + this.rootId = new Nullable(this.insertAt(this.rootNode, key).id); + } + + this._val.set(key, value); + } + // alias to match rust sdk + insert(key: K, value: V): void { + this.set(key, value); + } + + /** + * Retrieves a related value for a given key or uses the `defaultValue` if not key is found + * + * @param key Key of the element. + * @returns Value for the given key or the default value. + */ + get(key: K, defaultValue: V | null = null): V | null { + return this._val.get(key, defaultValue); + } + + /** + * Retrieves the related value for a given key, or throws error "key not found" + * + * @param key Key of the element. + * @param defaultValue The default value if the key is not present. + * @returns Value for the given key or the default value. + */ + getSome(key: K): V { + return this._val.getSome(key); + } + + /** + * Delete element with key if present, otherwise do nothing. + * + * @param key Key to remove. + */ + delete(key: K): void { + if (this.has(key)) { + this.rootId = this.doRemove(key); + this._val.delete(key); + } + } + // alias to match rust sdk + remove(key: K): void { + this.delete(key); + } + + + /** + * Get a range of values from a start key (inclusive) to an end key (exclusive by default). + * If end is greater than max key, include start to max inclusive. + * + * @param start Key for lower bound (inclusive). + * @param end Key for upper bound (exclusive by default). + * @param inclusive Set to false if upper bound should be exclusive, true if upper bound should be inclusive + * @returns Range of values corresponding to keys within start and end bounds. + */ + values(start: K, end: K, inclusive: boolean = false): V[] { + const keys = this.keys(start, end, inclusive); + const values: V[] = []; + for (let i = 0; i < keys.length; ++i) { + const key = keys[i]; + values.push(this._val.getSome(key)); + } + return values; + } + + /** + * Get a range of keys from a start key (inclusive) to an end key (exclusive by default). + * If end is greater than max key, include start to max inclusive. + * + * @param start Key for lower bound (inclusive). + * @param end Key for upper bound (exclusive by default). + * @param inclusive Set to false if upper bound should be exclusive, true if upper bound should be inclusive + * @returns Range of keys within start and end bounds. + */ + keys(start: K, end: K, inclusive: boolean = false): K[] { + const rootNode = this.rootNode; + const sorted: K[] = []; + if (rootNode) { + const visited: AVLTreeNode[] = []; + + this.traverseLeft(start, rootNode, visited); + + while (visited.length) { + const node = visited.pop(); + + if ( + // key must always be gte start bound + node.key >= start && ( + // if start and end bound are equal, + // end bound becomes lte instead of strictly less than + start === end || inclusive ? + node.key <= end : + node.key < end + ) + ) { + sorted.push(node.key); + } + + if (node.key < end && node.right) { + this.traverseLeft(start, this.node(node.right)!, visited); + } + } + } + + return sorted; + } + + private traverseLeft(start: K, node: AVLTreeNode, visited: AVLTreeNode[]): void { + while (node.key >= start && node.left) { + visited.push(node); + node = this.node(node.left)!; + } + visited.push(node); + } + + /** + * Get a range of entries from a start key (inclusive) to an end key (exclusive by default). + * If end is greater than max key, include start to max inclusive. + * + * @param start Key for lower bound (inclusive). + * @param end Key for upper bound (exclusive by default). + * @param inclusive Set to false if upper bound should be exclusive, true if upper bound should be inclusive + * @returns Range of entries corresponding to keys within start and end bounds. + */ + entries(start: K, end: K, inclusive: boolean = false): MapEntry[] { + const keys = this.keys(start, end, inclusive); + const entries: MapEntry[] = []; + for (let i = 0; i < keys.length; ++i) { + const key = keys[i]; + entries.push(new MapEntry(key, this._val.getSome(key))); + } + return entries; + } + // alias to match rust sdk + range(start: K, end: K, inclusive: boolean = false): MapEntry[] { + return this.entries(start, end, inclusive); + } + + /** + * Returns minimum key. + * Throws if tree is empty. + * @returns Minimum key. + */ + min(): K { + return this.minAt(this.rootNode!).child.key; + } + + /** + * Returns maximum key. + * Throws if tree is empty. + * @returns Maximum key. + */ + max(): K { + return this.maxAt(this.rootNode!).child.key; + } + + /** + * Returns the maximum key that is strictly less than the key. + * Throws if empty or if key is lower than or equal to `this.min()`. + * @param key Key for lower bound (exclusive). + * @returns Maximum key that is strictly less than given key. + */ + lower(key: K): K { + let root = this.rootNode!; + + while (root.left || root.right) { + if (root.key >= key && root.left) { + root = this.node(root.left)!; + } else if (root.right) { + const rightNode = this.node(root.right)!; + if (rightNode.key < key) { + root = rightNode; + } else { + break; + } + } + } + + if (root.key >= key) throw new Error(`key is less than mininum key in tree`) + else return root.key; + } + + /** + * Returns the minimum key that is strictly greater than the key. + * Throws if empty or if key is higher than or equal to `this.max()`. + * @param key Key for upper bound (exclusive). + * @returns Minimum key that is strictly greater than given key. + */ + higher(key: K): K { + let root = this.rootNode!; + + while (root.left || root.right) { + if (root.key <= key && root.right) { + root = this.node(root.right)!; + } else if (root.right) { + const leftNode = this.node(root.left)!; + if (leftNode.key > key) { + root = leftNode; + } else { + break; + } + } + } + + if (root.key <= key) throw new Error(`key is greater than maximum key in tree`) + else return root.key; + } + + /** + * Returns the maximum key that is less or equal than the key. + * Throws if empty or if key is lower than `this.min()`. + * @param key Key for lower bound (inclusive). + * @returns Maximum key that is less than or equal to given key. + */ + lowerOrEqual(key: K): K { + return this.has(key) ? key : this.lower(key); + } + // alias to match rust sdk + floorKey(key: K): K { + return this.lowerOrEqual(key); + } + + /** + * Returns the minimum key that is greater or equal than the key. + * Throws if empty or if key is higher than `this.max()`. + * @param key Key for upper bound (inclusive). + * @returns Minimum key that is greater or equal to given key. + */ + higherOrEqual(key: K): K { + return this.has(key) ? key : this.higher(key); + } + // alias to match rust sdk + ceilKey(key: K): K { + return this.higherOrEqual(key); + } + + /** + * Removes all key-value pairs from the tree + */ + clear(): void { + while(this.size > 0) { + this._val.delete(this._tree.popBack().key); + } + this.rootId = null; + } + + // useful for debugging + private toString(): string { + const a: string[] = ["\n"]; + for (let i: i32 = 0; i < i32(this.size); ++i) { + const node = this._tree[i]; + const key = u32(node.key).toString(); + const index = node.id.toString(); + const leftKey = node.left ? u32(this.node(node.left)!.key).toString() : "null"; + const rightKey = node.right ? u32(this.node(node.right)!.key).toString() : "null"; + const isRoot = node.id === this.rootId!.val ? "true" : "false" + const childrenProperties: string[] = [leftKey, rightKey, isRoot]; + const nodeProperties: string[] = [key, ",", index, ":", childrenProperties.join()]; + a.push(nodeProperties.join(" ")); + } + return a.join("\n"); + } + + /** + * ********************************** + * AVL Tree core routines + * ********************************** + */ + + // returns root key of the tree. + get rootKey(): K { + assert(!isNull(this.rootNode), "rootNode must be defined"); + return this.rootNode!.key; + } + + private set rootId(rootId: Nullable | null) { + this._rootId = rootId; + storage.set(this._elementPrefix + "root", this._rootId); + } + + private get rootId(): Nullable | null { + return this._rootId; + } + + // returns the root node of the tree, if it exists. + // returns null otherwise + private get rootNode(): AVLTreeNode | null { + return this.node(this.rootId); + } + + // returns the height for a given node + private nodeHeight(id: Nullable | null): u32 { + return id ? this._tree[id.val].height : 0; + } + + // returns the difference in heights between a node's left and right subtrees + private balance(node: AVLTreeNode): i32 { + return this.nodeHeight(node.left) - this.nodeHeight(node.right); + } + + // updates the height for a given node based on the heights of its subtrees + private updateHeight(node: AVLTreeNode): void { + node.height = 1 + max(this.nodeHeight(node.left), this.nodeHeight(node.right)); + this._tree[node.id] = node; + } + + // returns the node for the given id (index into underlying array this._tree) + private node(id: Nullable | null): AVLTreeNode | null { + return id ? this._tree[id.val] : null; + } + + // inserts a new key into the tree and + // recursively updates the height of each node in the tree, + // performing rotations as needed from bottom to top + // to maintain the AVL tree balance invariant + private insertAt(parentNode: AVLTreeNode | null, key: K): AVLTreeNode { + if (!parentNode) { + const node = new AVLTreeNode(this.size, key); + this._tree.push(node); + return node; + } else { + if (key < parentNode.key) { + parentNode.left = new Nullable(this.insertAt(this.node(parentNode.left), key).id); + } else if (key > parentNode.key) { + parentNode.right = new Nullable(this.insertAt(this.node(parentNode.right), key).id); + } else { + throw new Error("Key already exists, but does not have an associated value"); + } + + this.updateHeight(parentNode); + + return this.enforceBalance(parentNode); + } + } + + // given a node + // performs a single set left and right rotations to maintain AVL tree balance invariant + private enforceBalance(node: AVLTreeNode): AVLTreeNode { + const balance = this.balance(node); + if (balance > 1) { + // implies left child must exist, since balance = left.height - right.height + const leftChildNode = this.node(node.left)!; + if (this.balance(leftChildNode) < 0) { + node.left = new Nullable(this.rotateRight(leftChildNode).id); + } + return this.rotateLeft(node); + } else if (balance < -1) { + // implies right child must exist + const rightChildNode = this.node(node.right)!; + if (this.balance(rightChildNode) > 0) { + node.right = new Nullable(this.rotateLeft(rightChildNode).id); + } + return this.rotateRight(node); + } else { + // node is already balanced + return node; + } + } + + // given a node + // performs a righthand rotation + // node child + // \ / + // child -> node + // / \ + // child.left child.right + private rotateRight(node: AVLTreeNode): AVLTreeNode { + const childNode = this.node(node.right)!; + node.right = childNode.left; + childNode.left = new Nullable(node.id); + + this.updateHeight(node); + this.updateHeight(childNode); + + return childNode; + } + + // given a node + // performs a lefthand rotation + // node child + // / \ + // child -> node + // \ / + // child.right child.right + private rotateLeft(node: AVLTreeNode): AVLTreeNode { + const childNode = this.node(node.left)!; + node.left = childNode.right; + childNode.right = new Nullable(node.id); + + this.updateHeight(node); + this.updateHeight(childNode); + + return childNode; + } + + // removes the given key from the tree, maintaining the AVL balance invariant + private doRemove(key: K): Nullable | null { + const nodeAndParent = this.lookupAt(this.rootNode!, key); + let node = nodeAndParent.child; + let parentNode = nodeAndParent.parent; + let successorId: Nullable | null; + + if (!node.left && !node.right) { + // node to remove is a leaf node + if (parentNode.key < node.key) { + parentNode.right = null; + } else { + parentNode.left = null; + } + successorId = null; + } else if (!node.left) { + // node to remove has 1 right child + // replace node to remove with its right child + if (parentNode.key < node.key) { + parentNode.right = node.right; + } else { + parentNode.left = node.right; + } + successorId = node.right; + } else if (!node.right) { + // node to remove has 1 left child + // replace node to remove with its left child + if (parentNode.key < node.key) { + parentNode.right = node.left; + } else { + parentNode.left = node.left; + } + successorId = node.left; + } else { + // node has 2 children, search for successor + const isLeftLeaning = this.balance(node) >= 0; + const nodes = isLeftLeaning ? + // node to remove is left leaning, so search left subtree + this.maxAt(this.node(node.left)!, node) : + // node to remove is right leaning, so search right subtree + this.minAt(this.node(node.right)!, node); + + const successor = nodes.child; + + // node to remove and parentNode can be the same node on small trees (2 levels, 2-3 nodes) + // if so, make parentNode point to node + parentNode = nodes.parent.id === node.id ? node : nodes.parent; + + const successorIsLeftChild = parentNode.left ? parentNode.left!.val === successor.id : false; + + // remove successor from its parent, and link the successor's child to its grandparent + if (successorIsLeftChild) { + parentNode.left = isLeftLeaning ? successor.left : successor.right; + } else { + parentNode.right = isLeftLeaning ? successor.left : successor.right; + } + + // take successor's key, and update the node to remove + node.key = successor.key; + this._tree[node.id] = node; + successorId = new Nullable(node.id); + + // set node to point to successor, so it is removed from the tree + node = successor; + } + + this.updateHeight(parentNode); + this.swapRemove(node); + + return this.size > 0 && this.rootNode ? + new Nullable(this.rebalanceAt(this.rootNode!, parentNode.key).id) : + successorId; + } + + // removes the given node from the tree, + // and replaces it with the last node in the underlying array (this._tree) + private swapRemove(node: AVLTreeNode): void { + if (node.id === this.size - 1) { + // node is last element in tree, so no swapping needed + if (node.id === this.rootId!.val) { + this.rootId = null; + } + } else { + const lastNode = this._tree[this.size - 1]; + const parentNode = this.lookupAt(this.rootNode!, lastNode.key).parent; + + if (lastNode.id === this.rootId!.val) { + this.rootId = new Nullable(node.id); + } + + // check to make sure that parentNode and lastNode do not overlap + if (parentNode.id !== lastNode.id) { + // make lastNode's parent point to new index (index of node that lastNode is replacing) + if (parentNode.left ? parentNode.left!.val === lastNode.id : false) { + parentNode.left = new Nullable(node.id); + } else { + parentNode.right = new Nullable(node.id); + } + + // update the parentNode + this._tree[parentNode.id] = parentNode; + } + + // update index of lastNode + lastNode.id = node.id; + this._tree[lastNode.id] = lastNode; + } + + this._tree.pop(); + } + + // given a starting node + // returns the leftmost (min) descendant node, and its parent + private minAt(root: AVLTreeNode, parentNode: AVLTreeNode | null = null): ChildParentPair { + return root.left ? + this.minAt(this.node(root.left)!, root) : + new ChildParentPair(root, parentNode ? parentNode : root); + } + + // given a starting node + // returns the rightmost (max) descendant node, and its parent + private maxAt(root: AVLTreeNode, parentNode: AVLTreeNode | null = null): ChildParentPair { + return root.right ? + this.maxAt(this.node(root.right)!, root) : + new ChildParentPair(root, parentNode ? parentNode : root); + } + + // given a key and a starting node + // returns the node with the associated key, as well as its parent (if it exists) + // caution: this method assumes the key exists in the tree, and will throw otherwise + private lookupAt(root: AVLTreeNode, key: K, parentNode: AVLTreeNode | null = null): ChildParentPair { + return root.key === key ? + new ChildParentPair(root, parentNode ? parentNode : root) : + key < root.key ? + this.lookupAt(this.node(root.left)!, key, root) : + this.lookupAt(this.node(root.right)!, key, root); + } + + // recursively updates the height of each node in the tree, + // and performs rotations as needed from bottom to top + // to maintain the AVL tree balance invariant + private rebalanceAt(root: AVLTreeNode, key: K): AVLTreeNode { + if (root.key > key) { + const leftChild = this.node(root.left); + if (leftChild) { + root.left = new Nullable(this.rebalanceAt(leftChild, key).id); + } + } else if (root.key < key) { + const rightChild = this.node(root.right); + if (rightChild) { + root.right = new Nullable(this.rebalanceAt(rightChild, key).id); + } + } + + this.updateHeight(root); + + return this.enforceBalance(root); + } + + // recursively checks that each node in the tree is balanced by checking that + // the AVL tree invariant holds: + // any node's left and right subtrees must have a difference in height <= 1 + private isBalanced(root: AVLTreeNode | null = this.rootNode): bool { + const b = root ? this.balance(root) : 0 + return b >= -1 && b <= 1 && + root ? this.isBalanced(this.node(root.left)) && this.isBalanced(this.node(root.right)) : + true + } +} \ No newline at end of file diff --git a/assembly/runtime/collections/index.ts b/assembly/runtime/collections/index.ts index 981be800..64653f71 100644 --- a/assembly/runtime/collections/index.ts +++ b/assembly/runtime/collections/index.ts @@ -44,4 +44,5 @@ export * from "./persistentSet"; /** @internal */ export * from "./persistentUnorderedMap"; /** @internal */ -export * from "./util"; \ No newline at end of file +export * from "./util"; +export * from "./avlTree"; diff --git a/runtime/__tests__/avl-tree-contract.spec.ts b/runtime/__tests__/avl-tree-contract.spec.ts new file mode 100644 index 00000000..1e43ebf2 --- /dev/null +++ b/runtime/__tests__/avl-tree-contract.spec.ts @@ -0,0 +1,147 @@ +import { Runtime, Account } from ".."; + +// copied and modified from https://gist.github.com/lsenta/15d7f6fcfc2987176b54 +class LittleRNG { + private seed: number; + + constructor(seed: number) { + this.seed = seed; + } + + next(): number { + this.seed = (this.seed * 9301 + 49297) % 233280; + return Math.floor((this.seed / 233281) * 4294967295); + } +} + + +let runtime: Runtime; +let avlTreeContract: Account, alice: Account; + +function has(key: number): boolean { + return alice.call_other("avlTreeContract", "has", { key }).return_data; +} + +function insert(key: number, value: number): void { + alice.call_other("avlTreeContract", "insert", { key, value }); +} + +function getSome(key: number): number { + return alice.call_other("avlTreeContract", "getSome", { key }).return_data; +} + +function remove(key: number): void { + alice.call_other("avlTreeContract", "remove", { key }); +} + +function size(): number { + return alice.call_other("avlTreeContract", "size").return_data; +} + +function isBalanced(): boolean { + return alice.call_other("avlTreeContract", "isBalanced").return_data; +} + +function height(): number { + return alice.call_other("avlTreeContract", "height").return_data; +} + +function keys(): number[] { + return alice.call_other("avlTreeContract", "keys").return_data; +} + +function values(): number[] { + return alice.call_other("avlTreeContract", "values").return_data; +} + +function random(n: number): number[] { + const rng = new LittleRNG(12345); + const keys = []; + + for (let i = 0; i < n; ++i) { + keys.push(rng.next()); + } + + return keys; +} + +function maxTreeHeight(n: number): number { + // From near-sdk-rs TreeMap: + // h <= C * log2(n + D) + B + // where: + // C =~ 1.440, D =~ 1.065, B =~ 0.328 + // (source: https://en.wikipedia.org/wiki/AVL_tree) + const B = -0.328; + const C = 1.440; + const D = 1.065; + + const h = C * Math.log2(n + D) + B; + return Math.ceil(h); +} + +function insertKeys(keysToInsert: number[], map: Map): void { + for (let i = 0; i < keysToInsert.length; ++i) { + const key = keysToInsert[i]; + expect(has(key)).toBeFalsy(); + + insert(key, i); + expect(getSome(key)).toStrictEqual(i); + + if (map) { + map.set(key, i); + expect(getSome(key)).toStrictEqual(map.get(key)); + } + } +} + +function removeKeys(keysToRemove: number[], map: Map): void { + for (let i = 0; i < keysToRemove.length; ++i) { + const key = keysToRemove[i]; + + if (map && map.has(key)) { + expect(getSome(key)).toStrictEqual(map.get(key)); + map.delete(key); + } + + remove(key); + expect(has(key)).toBeFalsy(); + } +} + +function generateRandomTree(n: number): Map { + const map = new Map(); + const keysToInsert = random(2 * n); + const keysToRemove = []; + for (let i = 0; i < n; ++i) { + keysToRemove.push(keysToInsert[i]); + } + + insertKeys(keysToInsert, map); + removeKeys(keysToRemove, map); + + return map; +} + +describe("avl tree contract calls", () => { + beforeEach(() => { + runtime = new Runtime(); + alice = runtime.newAccount("alice"); + avlTreeContract = runtime.newAccount("avlTreeContract", __dirname + "/../out/avlTreeContract.wasm"); + }); + + it("remains balanced and sorted after 2n insertions and n deletions when called in a contract", () => { + const n = 20; + const map = generateRandomTree(n); + const sortedKeys = Array.from(map.keys()).sort((a, b) => a - b); + const sortedValues = []; + for (let i = 0; i < sortedKeys.length; ++i) { + sortedValues.push(map.get(sortedKeys[i])); + } + + expect(size()).toStrictEqual(n); + expect(isBalanced()).toBeTruthy(); + expect(height()).toBeLessThanOrEqual(maxTreeHeight(n)); + expect(keys()).toStrictEqual(sortedKeys); + expect(values()).toStrictEqual(sortedValues); + }); +}); \ No newline at end of file diff --git a/runtime/asconfig.js b/runtime/asconfig.js index 053ea55c..f46a0288 100644 --- a/runtime/asconfig.js +++ b/runtime/asconfig.js @@ -17,3 +17,4 @@ function compileContract(input) { compileContract("words"); compileContract("sentences"); +compileContract("avlTreeContract"); \ No newline at end of file