Skip to content

Commit

Permalink
Merge pull request #101 from TonalidadeHidrica/range-bounds
Browse files Browse the repository at this point in the history
  • Loading branch information
qryxip authored Apr 8, 2023
2 parents c48fcaa + 686372b commit 44b3805
Show file tree
Hide file tree
Showing 7 changed files with 152 additions and 28 deletions.
2 changes: 1 addition & 1 deletion examples/library-checker-static-range-sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ fn main() {
fenwick.add(i, a);
}
for (l, r) in lrs {
println!("{}", fenwick.sum(l, r));
println!("{}", fenwick.sum(l..r));
}
}
4 changes: 2 additions & 2 deletions examples/practice2_j_segment_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ fn main() {
segtree.set(x, v);
}
2 => {
let l = input.next().unwrap().parse().unwrap();
let l: usize = input.next().unwrap().parse().unwrap();
let r: usize = input.next().unwrap().parse().unwrap();
println!("{}", segtree.prod(l, r + 1));
println!("{}", segtree.prod(l..=r));
}
3 => {
let x = input.next().unwrap().parse().unwrap();
Expand Down
10 changes: 5 additions & 5 deletions examples/practice2_k_range_affine_range_sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,16 @@ fn main() {
for _ in 0..q {
match input.next().unwrap().parse().unwrap() {
0 => {
let l = input.next().unwrap().parse().unwrap();
let l: usize = input.next().unwrap().parse().unwrap();
let r = input.next().unwrap().parse().unwrap();
let b = input.next().unwrap().parse().unwrap();
let c = input.next().unwrap().parse().unwrap();
segtree.apply_range(l, r, (b, c));
segtree.apply_range(l..r, (b, c));
}
1 => {
let l = input.next().unwrap().parse().unwrap();
let r = input.next().unwrap().parse().unwrap();
println!("{}", segtree.prod(l, r).0);
let l: usize = input.next().unwrap().parse().unwrap();
let r: usize = input.next().unwrap().parse().unwrap();
println!("{}", segtree.prod(l..r).0);
}
_ => {}
}
Expand Down
4 changes: 2 additions & 2 deletions examples/practice2_l_lazy_segment_tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@ fn main() {
let l = input.next().unwrap().parse().unwrap();
let r: usize = input.next().unwrap().parse().unwrap();
match t {
1 => segtree.apply_range(l, r + 1, true),
2 => println!("{}", segtree.prod(l, r + 1).2),
1 => segtree.apply_range(l..=r, true),
2 => println!("{}", segtree.prod(l..=r).2),
_ => {}
}
}
Expand Down
29 changes: 25 additions & 4 deletions src/fenwicktree.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::{Bound, RangeBounds};

// Reference: https://en.wikipedia.org/wiki/Fenwick_tree
pub struct FenwickTree<T> {
n: usize,
Expand Down Expand Up @@ -34,17 +36,29 @@ impl<T: Clone + std::ops::AddAssign<T>> FenwickTree<T> {
}
}
/// Returns data[l] + ... + data[r - 1].
pub fn sum(&self, l: usize, r: usize) -> T
pub fn sum<R>(&self, range: R) -> T
where
T: std::ops::Sub<Output = T>,
R: RangeBounds<usize>,
{
let r = match range.end_bound() {
Bound::Included(r) => r + 1,
Bound::Excluded(r) => *r,
Bound::Unbounded => self.n,
};
let l = match range.start_bound() {
Bound::Included(l) => *l,
Bound::Excluded(l) => l + 1,
Bound::Unbounded => return self.accum(r),
};
self.accum(r) - self.accum(l)
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::ops::Bound::*;

#[test]
fn fenwick_tree_works() {
Expand All @@ -53,8 +67,15 @@ mod tests {
for i in 0..5 {
bit.add(i, i as i64 + 1);
}
assert_eq!(bit.sum(0, 5), 15);
assert_eq!(bit.sum(0, 4), 10);
assert_eq!(bit.sum(1, 3), 5);
assert_eq!(bit.sum(0..5), 15);
assert_eq!(bit.sum(0..4), 10);
assert_eq!(bit.sum(1..3), 5);

assert_eq!(bit.sum(..), 15);
assert_eq!(bit.sum(..2), 3);
assert_eq!(bit.sum(..=2), 6);
assert_eq!(bit.sum(1..), 14);
assert_eq!(bit.sum(1..=3), 9);
assert_eq!(bit.sum((Excluded(0), Included(2))), 5);
}
}
79 changes: 71 additions & 8 deletions src/lazysegtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,27 @@ impl<F: MapMonoid> LazySegtree<F> {
self.d[p].clone()
}

pub fn prod(&mut self, mut l: usize, mut r: usize) -> <F::M as Monoid>::S {
pub fn prod<R>(&mut self, range: R) -> <F::M as Monoid>::S
where
R: RangeBounds<usize>,
{
// Trivial optimization
if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
return self.all_prod();
}

let mut r = match range.end_bound() {
Bound::Included(r) => r + 1,
Bound::Excluded(r) => *r,
Bound::Unbounded => self.n,
};
let mut l = match range.start_bound() {
Bound::Included(l) => *l,
Bound::Excluded(l) => l + 1,
// TODO: There are another way of optimizing [0..r)
Bound::Unbounded => 0,
};

assert!(l <= r && r <= self.n);
if l == r {
return F::identity_element();
Expand Down Expand Up @@ -124,7 +144,22 @@ impl<F: MapMonoid> LazySegtree<F> {
self.update(p >> i);
}
}
pub fn apply_range(&mut self, mut l: usize, mut r: usize, f: F::F) {
pub fn apply_range<R>(&mut self, range: R, f: F::F)
where
R: RangeBounds<usize>,
{
let mut r = match range.end_bound() {
Bound::Included(r) => r + 1,
Bound::Excluded(r) => *r,
Bound::Unbounded => self.n,
};
let mut l = match range.start_bound() {
Bound::Included(l) => *l,
Bound::Excluded(l) => l + 1,
// TODO: There are another way of optimizing [0..r)
Bound::Unbounded => 0,
};

assert!(l <= r && r <= self.n);
if l == r {
return;
Expand Down Expand Up @@ -287,7 +322,10 @@ where
}

// TODO is it useful?
use std::fmt::{Debug, Error, Formatter, Write};
use std::{
fmt::{Debug, Error, Formatter, Write},
ops::{Bound, RangeBounds},
};
impl<F> Debug for LazySegtree<F>
where
F: MapMonoid,
Expand All @@ -314,6 +352,8 @@ where

#[cfg(test)]
mod tests {
use std::ops::{Bound::*, RangeBounds};

use crate::{LazySegtree, MapMonoid, Max};

struct MaxAdd;
Expand Down Expand Up @@ -361,9 +401,13 @@ mod tests {
internal[6] = 0;
check_segtree(&internal, &mut segtree);

segtree.apply_range(3, 8, 2);
segtree.apply_range(3..8, 2);
internal[3..8].iter_mut().for_each(|e| *e += 2);
check_segtree(&internal, &mut segtree);

segtree.apply_range(2..=5, 7);
internal[2..=5].iter_mut().for_each(|e| *e += 7);
check_segtree(&internal, &mut segtree);
}

//noinspection DuplicatedCode
Expand All @@ -373,12 +417,20 @@ mod tests {
for i in 0..n {
assert_eq!(segtree.get(i), base[i]);
}

check(base, segtree, ..);
for i in 0..=n {
check(base, segtree, ..i);
check(base, segtree, i..);
if i < n {
check(base, segtree, ..=i);
}
for j in i..=n {
assert_eq!(
segtree.prod(i, j),
base[i..j].iter().max().copied().unwrap_or(i32::min_value())
);
check(base, segtree, i..j);
if j < n {
check(base, segtree, i..=j);
check(base, segtree, (Excluded(i), Included(j)));
}
}
}
assert_eq!(
Expand Down Expand Up @@ -413,4 +465,15 @@ mod tests {
}
}
}

fn check(base: &[i32], segtree: &mut LazySegtree<MaxAdd>, range: impl RangeBounds<usize>) {
let expected = base
.iter()
.enumerate()
.filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
.max()
.copied()
.unwrap_or(i32::min_value());
assert_eq!(segtree.prod(range), expected);
}
}
52 changes: 46 additions & 6 deletions src/segtree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use crate::internal_type_traits::{BoundedAbove, BoundedBelow, One, Zero};
use std::cmp::{max, min};
use std::convert::Infallible;
use std::marker::PhantomData;
use std::ops::{Add, Mul};
use std::ops::{Add, Bound, Mul, RangeBounds};

// TODO Should I split monoid-related traits to another module?
pub trait Monoid {
Expand Down Expand Up @@ -107,7 +107,27 @@ impl<M: Monoid> Segtree<M> {
self.d[p + self.size].clone()
}

pub fn prod(&self, mut l: usize, mut r: usize) -> M::S {
pub fn prod<R>(&self, range: R) -> M::S
where
R: RangeBounds<usize>,
{
// Trivial optimization
if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded {
return self.all_prod();
}

let mut r = match range.end_bound() {
Bound::Included(r) => r + 1,
Bound::Excluded(r) => *r,
Bound::Unbounded => self.n,
};
let mut l = match range.start_bound() {
Bound::Included(l) => *l,
Bound::Excluded(l) => l + 1,
// TODO: There are another way of optimizing [0..r)
Bound::Unbounded => 0,
};

assert!(l <= r && r <= self.n);
let mut sml = M::identity();
let mut smr = M::identity();
Expand Down Expand Up @@ -240,6 +260,7 @@ where
mod tests {
use crate::segtree::Max;
use crate::Segtree;
use std::ops::{Bound::*, RangeBounds};

#[test]
fn test_max_segtree() {
Expand Down Expand Up @@ -272,12 +293,20 @@ mod tests {
for i in 0..n {
assert_eq!(segtree.get(i), base[i]);
}

check(base, segtree, ..);
for i in 0..=n {
check(base, segtree, ..i);
check(base, segtree, i..);
if i < n {
check(base, segtree, ..=i);
}
for j in i..=n {
assert_eq!(
segtree.prod(i, j),
base[i..j].iter().max().copied().unwrap_or(i32::min_value())
);
check(base, segtree, i..j);
if j < n {
check(base, segtree, i..=j);
check(base, segtree, (Excluded(i), Included(j)));
}
}
}
assert_eq!(
Expand Down Expand Up @@ -312,4 +341,15 @@ mod tests {
}
}
}

fn check(base: &[i32], segtree: &Segtree<Max<i32>>, range: impl RangeBounds<usize>) {
let expected = base
.iter()
.enumerate()
.filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i)))
.max()
.copied()
.unwrap_or(i32::min_value());
assert_eq!(segtree.prod(range), expected);
}
}

0 comments on commit 44b3805

Please sign in to comment.