diff --git a/examples/library-checker-static-range-sum.rs b/examples/library-checker-static-range-sum.rs index db38391..e8414cc 100644 --- a/examples/library-checker-static-range-sum.rs +++ b/examples/library-checker-static-range-sum.rs @@ -20,6 +20,6 @@ fn main() { fenwick.add(i, a); } for (l, r) in lrs { - println!("{}", fenwick.sum(l, r)); + println!("{}", fenwick.sum(l..r)); } } diff --git a/examples/practice2_j_segment_tree.rs b/examples/practice2_j_segment_tree.rs index ed5d043..44876a1 100644 --- a/examples/practice2_j_segment_tree.rs +++ b/examples/practice2_j_segment_tree.rs @@ -20,9 +20,9 @@ fn main() { segtree.set(x, v); } 2 => { - let l = input.next().unwrap().parse().unwrap(); + let l: usize = input.next().unwrap().parse().unwrap(); let r: usize = input.next().unwrap().parse().unwrap(); - println!("{}", segtree.prod(l, r + 1)); + println!("{}", segtree.prod(l..=r)); } 3 => { let x = input.next().unwrap().parse().unwrap(); diff --git a/examples/practice2_k_range_affine_range_sum.rs b/examples/practice2_k_range_affine_range_sum.rs index e07c542..20500a4 100644 --- a/examples/practice2_k_range_affine_range_sum.rs +++ b/examples/practice2_k_range_affine_range_sum.rs @@ -50,16 +50,16 @@ fn main() { for _ in 0..q { match input.next().unwrap().parse().unwrap() { 0 => { - let l = input.next().unwrap().parse().unwrap(); + let l: usize = input.next().unwrap().parse().unwrap(); let r = input.next().unwrap().parse().unwrap(); let b = input.next().unwrap().parse().unwrap(); let c = input.next().unwrap().parse().unwrap(); - segtree.apply_range(l, r, (b, c)); + segtree.apply_range(l..r, (b, c)); } 1 => { - let l = input.next().unwrap().parse().unwrap(); - let r = input.next().unwrap().parse().unwrap(); - println!("{}", segtree.prod(l, r).0); + let l: usize = input.next().unwrap().parse().unwrap(); + let r: usize = input.next().unwrap().parse().unwrap(); + println!("{}", segtree.prod(l..r).0); } _ => {} } diff --git a/examples/practice2_l_lazy_segment_tree.rs b/examples/practice2_l_lazy_segment_tree.rs index 61469d4..2958320 100644 --- a/examples/practice2_l_lazy_segment_tree.rs +++ b/examples/practice2_l_lazy_segment_tree.rs @@ -55,8 +55,8 @@ fn main() { let l = input.next().unwrap().parse().unwrap(); let r: usize = input.next().unwrap().parse().unwrap(); match t { - 1 => segtree.apply_range(l, r + 1, true), - 2 => println!("{}", segtree.prod(l, r + 1).2), + 1 => segtree.apply_range(l..=r, true), + 2 => println!("{}", segtree.prod(l..=r).2), _ => {} } } diff --git a/src/fenwicktree.rs b/src/fenwicktree.rs index 9256ff0..cfcf32a 100644 --- a/src/fenwicktree.rs +++ b/src/fenwicktree.rs @@ -1,3 +1,5 @@ +use std::ops::{Bound, RangeBounds}; + // Reference: https://en.wikipedia.org/wiki/Fenwick_tree pub struct FenwickTree { n: usize, @@ -34,10 +36,21 @@ impl> FenwickTree { } } /// Returns data[l] + ... + data[r - 1]. - pub fn sum(&self, l: usize, r: usize) -> T + pub fn sum(&self, range: R) -> T where T: std::ops::Sub, + R: RangeBounds, { + let r = match range.end_bound() { + Bound::Included(r) => r + 1, + Bound::Excluded(r) => *r, + Bound::Unbounded => self.n, + }; + let l = match range.start_bound() { + Bound::Included(l) => *l, + Bound::Excluded(l) => l + 1, + Bound::Unbounded => return self.accum(r), + }; self.accum(r) - self.accum(l) } } @@ -45,6 +58,7 @@ impl> FenwickTree { #[cfg(test)] mod tests { use super::*; + use std::ops::Bound::*; #[test] fn fenwick_tree_works() { @@ -53,8 +67,15 @@ mod tests { for i in 0..5 { bit.add(i, i as i64 + 1); } - assert_eq!(bit.sum(0, 5), 15); - assert_eq!(bit.sum(0, 4), 10); - assert_eq!(bit.sum(1, 3), 5); + assert_eq!(bit.sum(0..5), 15); + assert_eq!(bit.sum(0..4), 10); + assert_eq!(bit.sum(1..3), 5); + + assert_eq!(bit.sum(..), 15); + assert_eq!(bit.sum(..2), 3); + assert_eq!(bit.sum(..=2), 6); + assert_eq!(bit.sum(1..), 14); + assert_eq!(bit.sum(1..=3), 9); + assert_eq!(bit.sum((Excluded(0), Included(2))), 5); } } diff --git a/src/lazysegtree.rs b/src/lazysegtree.rs index 47020a8..d430b34 100644 --- a/src/lazysegtree.rs +++ b/src/lazysegtree.rs @@ -73,7 +73,27 @@ impl LazySegtree { self.d[p].clone() } - pub fn prod(&mut self, mut l: usize, mut r: usize) -> ::S { + pub fn prod(&mut self, range: R) -> ::S + where + R: RangeBounds, + { + // Trivial optimization + if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded { + return self.all_prod(); + } + + let mut r = match range.end_bound() { + Bound::Included(r) => r + 1, + Bound::Excluded(r) => *r, + Bound::Unbounded => self.n, + }; + let mut l = match range.start_bound() { + Bound::Included(l) => *l, + Bound::Excluded(l) => l + 1, + // TODO: There are another way of optimizing [0..r) + Bound::Unbounded => 0, + }; + assert!(l <= r && r <= self.n); if l == r { return F::identity_element(); @@ -124,7 +144,22 @@ impl LazySegtree { self.update(p >> i); } } - pub fn apply_range(&mut self, mut l: usize, mut r: usize, f: F::F) { + pub fn apply_range(&mut self, range: R, f: F::F) + where + R: RangeBounds, + { + let mut r = match range.end_bound() { + Bound::Included(r) => r + 1, + Bound::Excluded(r) => *r, + Bound::Unbounded => self.n, + }; + let mut l = match range.start_bound() { + Bound::Included(l) => *l, + Bound::Excluded(l) => l + 1, + // TODO: There are another way of optimizing [0..r) + Bound::Unbounded => 0, + }; + assert!(l <= r && r <= self.n); if l == r { return; @@ -287,7 +322,10 @@ where } // TODO is it useful? -use std::fmt::{Debug, Error, Formatter, Write}; +use std::{ + fmt::{Debug, Error, Formatter, Write}, + ops::{Bound, RangeBounds}, +}; impl Debug for LazySegtree where F: MapMonoid, @@ -314,6 +352,8 @@ where #[cfg(test)] mod tests { + use std::ops::{Bound::*, RangeBounds}; + use crate::{LazySegtree, MapMonoid, Max}; struct MaxAdd; @@ -361,9 +401,13 @@ mod tests { internal[6] = 0; check_segtree(&internal, &mut segtree); - segtree.apply_range(3, 8, 2); + segtree.apply_range(3..8, 2); internal[3..8].iter_mut().for_each(|e| *e += 2); check_segtree(&internal, &mut segtree); + + segtree.apply_range(2..=5, 7); + internal[2..=5].iter_mut().for_each(|e| *e += 7); + check_segtree(&internal, &mut segtree); } //noinspection DuplicatedCode @@ -373,12 +417,20 @@ mod tests { for i in 0..n { assert_eq!(segtree.get(i), base[i]); } + + check(base, segtree, ..); for i in 0..=n { + check(base, segtree, ..i); + check(base, segtree, i..); + if i < n { + check(base, segtree, ..=i); + } for j in i..=n { - assert_eq!( - segtree.prod(i, j), - base[i..j].iter().max().copied().unwrap_or(i32::min_value()) - ); + check(base, segtree, i..j); + if j < n { + check(base, segtree, i..=j); + check(base, segtree, (Excluded(i), Included(j))); + } } } assert_eq!( @@ -413,4 +465,15 @@ mod tests { } } } + + fn check(base: &[i32], segtree: &mut LazySegtree, range: impl RangeBounds) { + let expected = base + .iter() + .enumerate() + .filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i))) + .max() + .copied() + .unwrap_or(i32::min_value()); + assert_eq!(segtree.prod(range), expected); + } } diff --git a/src/segtree.rs b/src/segtree.rs index b543aa3..573b9ff 100644 --- a/src/segtree.rs +++ b/src/segtree.rs @@ -3,7 +3,7 @@ use crate::internal_type_traits::{BoundedAbove, BoundedBelow, One, Zero}; use std::cmp::{max, min}; use std::convert::Infallible; use std::marker::PhantomData; -use std::ops::{Add, Mul}; +use std::ops::{Add, Bound, Mul, RangeBounds}; // TODO Should I split monoid-related traits to another module? pub trait Monoid { @@ -107,7 +107,27 @@ impl Segtree { self.d[p + self.size].clone() } - pub fn prod(&self, mut l: usize, mut r: usize) -> M::S { + pub fn prod(&self, range: R) -> M::S + where + R: RangeBounds, + { + // Trivial optimization + if range.start_bound() == Bound::Unbounded && range.end_bound() == Bound::Unbounded { + return self.all_prod(); + } + + let mut r = match range.end_bound() { + Bound::Included(r) => r + 1, + Bound::Excluded(r) => *r, + Bound::Unbounded => self.n, + }; + let mut l = match range.start_bound() { + Bound::Included(l) => *l, + Bound::Excluded(l) => l + 1, + // TODO: There are another way of optimizing [0..r) + Bound::Unbounded => 0, + }; + assert!(l <= r && r <= self.n); let mut sml = M::identity(); let mut smr = M::identity(); @@ -240,6 +260,7 @@ where mod tests { use crate::segtree::Max; use crate::Segtree; + use std::ops::{Bound::*, RangeBounds}; #[test] fn test_max_segtree() { @@ -272,12 +293,20 @@ mod tests { for i in 0..n { assert_eq!(segtree.get(i), base[i]); } + + check(base, segtree, ..); for i in 0..=n { + check(base, segtree, ..i); + check(base, segtree, i..); + if i < n { + check(base, segtree, ..=i); + } for j in i..=n { - assert_eq!( - segtree.prod(i, j), - base[i..j].iter().max().copied().unwrap_or(i32::min_value()) - ); + check(base, segtree, i..j); + if j < n { + check(base, segtree, i..=j); + check(base, segtree, (Excluded(i), Included(j))); + } } } assert_eq!( @@ -312,4 +341,15 @@ mod tests { } } } + + fn check(base: &[i32], segtree: &Segtree>, range: impl RangeBounds) { + let expected = base + .iter() + .enumerate() + .filter_map(|(i, a)| Some(a).filter(|_| range.contains(&i))) + .max() + .copied() + .unwrap_or(i32::min_value()); + assert_eq!(segtree.prod(range), expected); + } }