-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathaugmented_ops.h
109 lines (95 loc) · 3.03 KB
/
augmented_ops.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#pragma once
#include "utils.h"
#include "map_ops.h"
// *******************************************
// AUGMENTED MAP OPERATIONS
// *******************************************
// Augmented-map operations layered on top of the base map operations in
// `Map`. `Map::Entry` must provide an `aug_t` type together with
// `get_empty()`, `from_entry(ET)` and `combine(aug_t, aug_t)`; `Map` supplies
// the node accessors (`get_key`, `get_entry`, `aug_val`, `size`), the key
// comparator `comp`, and the join primitives used below.
template<class Map>
struct augmented_ops : Map {
  using Entry = typename Map::Entry;
  using node = typename Map::node;
  using ET = typename Map::ET;
  using GC = typename Map::GC;
  using K = typename Map::K;
  using aug_t = typename Entry::aug_t;

  // Accumulator usable as the `aug` argument of aug_sum_left/right/range.
  // Starts at Entry::get_empty() and folds every supplied value in with
  // Entry::combine(result, value), in the order the values are supplied.
  struct aug_sum_t {
    aug_t result;
    aug_sum_t() : result(Entry::get_empty()) {}
    // Folds in the augmented value of a single entry.
    void add_entry(ET e) {
      result = Entry::combine(result, Entry::from_entry(e));
    }
    // Folds in an already-computed augmented value (e.g. of a whole subtree).
    void add_aug_val(aug_t av) {
      result = Entry::combine(result, av);
    }
  };

  // Folds into `a` every entry of the tree rooted at `b` whose key is right of
  // or at `key`, i.e. for which Map::comp(entry_key, key) is false.
  // Walks a single root-to-leaf path: when a node qualifies, so does its whole
  // right subtree, so both are added and the walk continues left; otherwise
  // the walk continues right.
  // NOTE(review): values reach `a` in top-down traversal order, not key order;
  // the result is order-independent only if Entry::combine is commutative.
  template<class aug>
  static void aug_sum_right(node* b, const K& key, aug& a) {
    while (b) {
      if (!Map::comp(Map::get_key(b), key)) {
        a.add_entry(Map::get_entry(b));
        if (b->rc) a.add_aug_val(Map::aug_val(b->rc));
        b = b->lc;
      } else {
        b = b->rc;
      }
    }
  }

  // Mirror image of aug_sum_right: folds into `a` every entry whose key is
  // left of or at `key`, i.e. for which Map::comp(key, entry_key) is false.
  // A qualifying node contributes itself and its whole left subtree, and the
  // walk continues right; otherwise it continues left.
  // NOTE(review): same traversal-order caveat as aug_sum_right.
  template<class aug>
  static void aug_sum_left(node* b, const K& key, aug& a) {
    while (b) {
      if (!Map::comp(key, Map::get_key(b))) {
        a.add_entry(Map::get_entry(b));
        if (b->lc) a.add_aug_val(Map::aug_val(b->lc));
        b = b->rc;
      } else {
        b = b->lc;
      }
    }
  }

  // Folds into `a` every entry whose key lies between key_left and key_right
  // (both inclusive). Map::range_root may return nullptr, in which case
  // nothing is added; presumably it yields the topmost node whose key is in
  // the range — confirm against map_ops.h. Below that node, its left subtree
  // contributes the keys at or right of key_left and its right subtree the
  // keys at or left of key_right.
  template<class aug>
  static void aug_sum_range(node* b, const K& key_left, const K& key_right, aug& a) {
    node* r = Map::range_root(b, key_left, key_right);
    if (!r) return;
    aug_sum_right(r->lc, key_left, a);   // left side: right of or at key_left
    a.add_entry(Map::get_entry(r));      // the range root itself
    aug_sum_left(r->rc, key_right, a);   // right side: left of or at key_right
  }

  // Descends from `b` guided by predicate `f` on augmented values:
  //  - if f holds for the left subtree's aug value, `b` itself is returned
  //    unless f also holds for b's own entry, in which case the search
  //    continues in the right subtree;
  //  - otherwise the search continues in the left subtree.
  // Returns nullptr once the search reaches an empty subtree.
  // NOTE(review): Map::aug_val is called on b->lc even when it is null; this
  // assumes aug_val(nullptr) yields the empty aug value — confirm.
  template<typename Func>
  static node* aug_select(node* b, const Func& f) {
    if (b == nullptr) return nullptr;
    if (f(Map::aug_val(b->lc))) {
      if (f(Entry::from_entry(Map::get_entry(b))))
        return aug_select(b->rc, f);
      return b;
    }
    return aug_select(b->lc, f);
  }

  // Returns a tree containing exactly the entries e of `b` for which
  // f(Entry::from_entry(e)) is true.
  // A whole subtree is skipped as soon as f rejects its combined aug value,
  // so the caller must supply a monotone predicate: if f fails on a
  // subtree's combined aug value it must fail on every entry inside it.
  // `extra_ptr` (or b->ref_cnt > 1) marks `b` as shared; that flag is passed
  // down to the children and to GC::copy_if / GC::dec_if, which presumably
  // copy a shared node instead of reusing it — confirm against the GC policy.
  // NOTE(review): a pruned subtree is returned from without any GC call; if
  // ownership of it was transferred to this function its nodes may never be
  // released — confirm against the GC policy.
  template<class Func>
  static node* aug_filter(node* b, const Func& f, bool extra_ptr = false) {
    if (!b) return nullptr;
    // Prune: the predicate rejects the aggregate of this entire subtree.
    // (Previously this tested the raw aug value and ignored `f` altogether.)
    if (!f(Map::aug_val(b))) return nullptr;
    bool copy = extra_ptr || (b->ref_cnt > 1);
    auto P = utils::fork<node*>(Map::size(b) >= utils::node_limit,
      [&]() {return aug_filter(b->lc, f, copy);},
      [&]() {return aug_filter(b->rc, f, copy);});
    if (f(Entry::from_entry(Map::get_entry(b)))) {
      return Map::node_join(P.first, P.second, GC::copy_if(b, copy, extra_ptr));
    } else {
      GC::dec_if(b, copy, extra_ptr);
      return Map::join2(P.first, P.second);
    }
  }

  // Inserts entry `e` into `b` through Map::insert_j, forwarding `f`
  // (presumably used to merge values for an already-present key — confirm
  // against map_ops.h). The join callback handed to insert_j reattaches the
  // children `l` and `r` to node `m`. When Map::is_balanced(m) holds, it
  // avoids a full Map::node_join and only folds e's aug value into m via
  // Map::lazy_update (new aug = combine(aug(e), old aug)). Otherwise it falls
  // back to Map::node_join.
  // The lambdas capture locals by reference; this is safe because insert_j
  // uses them synchronously before this function returns.
  template <class Func>
  static node* insert_lazy(node* b, const ET& e, const Func& f) {
    aug_t av = Entry::from_entry(e);
    auto g = [&] (const aug_t& a) { return Entry::combine(av, a); };
    auto lazy_join = [&] (node* l, node* r, node* m) {
      m->rc = r; m->lc = l;
      if (Map::is_balanced(m)) {
        Map::lazy_update(m, g);
        return m;
      } else return Map::node_join(l, r, m);
    };
    return Map::insert_j(b, e, f, lazy_join, false);
  }
};