From a379504cd80394a3fd7ae2badaa4da0ca033cb96 Mon Sep 17 00:00:00 2001 From: Stein Somers Date: Mon, 11 Mar 2019 00:44:24 +0100 Subject: [PATCH] improve worst-case performance of BTreeSet intersection --- src/liballoc/benches/btree/set.rs | 150 ++++++++++++++++++-------- src/liballoc/benches/lib.rs | 1 + src/liballoc/collections/btree/set.rs | 96 ++++++++++++----- 3 files changed, 175 insertions(+), 72 deletions(-) diff --git a/src/liballoc/benches/btree/set.rs b/src/liballoc/benches/btree/set.rs index 08e1db5fbb74d..bc12c3d6982d2 100644 --- a/src/liballoc/benches/btree/set.rs +++ b/src/liballoc/benches/btree/set.rs @@ -1,36 +1,35 @@ use std::collections::BTreeSet; +use std::collections::btree_set::Intersection; use rand::{thread_rng, Rng}; use test::{black_box, Bencher}; -fn random(n1: u32, n2: u32) -> [BTreeSet; 2] { +fn random(n1: usize, n2: usize) -> [BTreeSet; 2] { let mut rng = thread_rng(); - let mut set1 = BTreeSet::new(); - let mut set2 = BTreeSet::new(); - for _ in 0..n1 { - let i = rng.gen::(); - set1.insert(i); - } - for _ in 0..n2 { - let i = rng.gen::(); - set2.insert(i); + let mut sets = [BTreeSet::new(), BTreeSet::new()]; + for i in 0..2 { + while sets[i].len() < [n1, n2][i] { + sets[i].insert(rng.gen()); + } } - [set1, set2] + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets } -fn staggered(n1: u32, n2: u32) -> [BTreeSet; 2] { - let mut even = BTreeSet::new(); - let mut odd = BTreeSet::new(); - for i in 0..n1 { - even.insert(i * 2); - } - for i in 0..n2 { - odd.insert(i * 2 + 1); +fn stagger(n1: usize, factor: usize) -> [BTreeSet; 2] { + let n2 = n1 * factor; + let mut sets = [BTreeSet::new(), BTreeSet::new()]; + for i in 0..(n1 + n2) { + let b = i % (factor + 1) != 0; + sets[b as usize].insert(i as u32); } - [even, odd] + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets } -fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet; 2] { +fn neg_vs_pos(n1: usize, n2: usize) -> [BTreeSet; 2] { let mut neg = BTreeSet::new(); let mut pos = BTreeSet::new(); for i in -(n1 as i32)..=-1 { @@ -39,22 +38,38 @@ fn neg_vs_pos(n1: u32, n2: u32) -> [BTreeSet; 2] { for i in 1..=(n2 as i32) { pos.insert(i); } + assert_eq!(neg.len(), n1); + assert_eq!(pos.len(), n2); [neg, pos] } -fn pos_vs_neg(n1: u32, n2: u32) -> [BTreeSet; 2] { - let mut neg = BTreeSet::new(); - let mut pos = BTreeSet::new(); - for i in -(n1 as i32)..=-1 { - neg.insert(i); +fn pos_vs_neg(n1: usize, n2: usize) -> [BTreeSet; 2] { + let mut sets = neg_vs_pos(n2, n1); + sets.reverse(); + assert_eq!(sets[0].len(), n1); + assert_eq!(sets[1].len(), n2); + sets +} + +fn intersection_search(sets: &[BTreeSet; 2]) -> Intersection + where T: std::cmp::Ord +{ + Intersection::Search { + a_iter: sets[0].iter(), + b_set: &sets[1], } - for i in 1..=(n2 as i32) { - pos.insert(i); +} + +fn intersection_stitch(sets: &[BTreeSet; 2]) -> Intersection + where T: std::cmp::Ord +{ + Intersection::Stitch { + a_iter: sets[0].iter(), + b_iter: sets[1].iter(), } - [pos, neg] } -macro_rules! set_intersection_bench { +macro_rules! intersection_bench { ($name: ident, $sets: expr) => { #[bench] pub fn $name(b: &mut Bencher) { @@ -68,21 +83,64 @@ macro_rules! set_intersection_bench { }) } }; + ($name: ident, $sets: expr, $intersection_kind: ident) => { + #[bench] + pub fn $name(b: &mut Bencher) { + // setup + let sets = $sets; + assert!(sets[0].len() >= 1); + assert!(sets[1].len() >= sets[0].len()); + + // measure + b.iter(|| { + let x = $intersection_kind(&sets).count(); + black_box(x); + }) + } + }; } -set_intersection_bench! {intersect_random_100, random(100, 100)} -set_intersection_bench! {intersect_random_10k, random(10_000, 10_000)} -set_intersection_bench! {intersect_random_10_vs_10k, random(10, 10_000)} -set_intersection_bench! {intersect_random_10k_vs_10, random(10_000, 10)} -set_intersection_bench! {intersect_staggered_100, staggered(100, 100)} -set_intersection_bench! {intersect_staggered_10k, staggered(10_000, 10_000)} -set_intersection_bench! {intersect_staggered_10_vs_10k, staggered(10, 10_000)} -set_intersection_bench! {intersect_staggered_10k_vs_10, staggered(10_000, 10)} -set_intersection_bench! {intersect_neg_vs_pos_100, neg_vs_pos(100, 100)} -set_intersection_bench! {intersect_neg_vs_pos_10k, neg_vs_pos(10_000, 10_000)} -set_intersection_bench! {intersect_neg_vs_pos_10_vs_10k,neg_vs_pos(10, 10_000)} -set_intersection_bench! {intersect_neg_vs_pos_10k_vs_10,neg_vs_pos(10_000, 10)} -set_intersection_bench! {intersect_pos_vs_neg_100, pos_vs_neg(100, 100)} -set_intersection_bench! {intersect_pos_vs_neg_10k, pos_vs_neg(10_000, 10_000)} -set_intersection_bench! {intersect_pos_vs_neg_10_vs_10k,pos_vs_neg(10, 10_000)} -set_intersection_bench! {intersect_pos_vs_neg_10k_vs_10,pos_vs_neg(10_000, 10)} +intersection_bench! {intersect_100_neg_vs_100_pos, neg_vs_pos(100, 100)} +intersection_bench! {intersect_100_neg_vs_10k_pos, neg_vs_pos(100, 10_000)} +intersection_bench! {intersect_100_pos_vs_100_neg, pos_vs_neg(100, 100)} +intersection_bench! {intersect_100_pos_vs_10k_neg, pos_vs_neg(100, 10_000)} +intersection_bench! {intersect_10k_neg_vs_100_pos, neg_vs_pos(10_000, 100)} +intersection_bench! {intersect_10k_neg_vs_10k_pos, neg_vs_pos(10_000, 10_000)} +intersection_bench! {intersect_10k_pos_vs_100_neg, pos_vs_neg(10_000, 100)} +intersection_bench! {intersect_10k_pos_vs_10k_neg, pos_vs_neg(10_000, 10_000)} +intersection_bench! {intersect_random_100_vs_100_actual,random(100, 100)} +intersection_bench! {intersect_random_100_vs_100_search,random(100, 100), intersection_search} +intersection_bench! {intersect_random_100_vs_100_stitch,random(100, 100), intersection_stitch} +intersection_bench! {intersect_random_100_vs_10k_actual,random(100, 10_000)} +intersection_bench! {intersect_random_100_vs_10k_search,random(100, 10_000), intersection_search} +intersection_bench! {intersect_random_100_vs_10k_stitch,random(100, 10_000), intersection_stitch} +intersection_bench! {intersect_random_10k_vs_10k_actual,random(10_000, 10_000)} +intersection_bench! {intersect_random_10k_vs_10k_search,random(10_000, 10_000), intersection_search} +intersection_bench! {intersect_random_10k_vs_10k_stitch,random(10_000, 10_000), intersection_stitch} +intersection_bench! {intersect_stagger_100_actual, stagger(100, 1)} +intersection_bench! {intersect_stagger_100_search, stagger(100, 1), intersection_search} +intersection_bench! {intersect_stagger_100_stitch, stagger(100, 1), intersection_stitch} +intersection_bench! {intersect_stagger_10k_actual, stagger(10_000, 1)} +intersection_bench! {intersect_stagger_10k_search, stagger(10_000, 1), intersection_search} +intersection_bench! {intersect_stagger_10k_stitch, stagger(10_000, 1), intersection_stitch} +intersection_bench! {intersect_stagger_1_actual, stagger(1, 1)} +intersection_bench! {intersect_stagger_1_search, stagger(1, 1), intersection_search} +intersection_bench! {intersect_stagger_1_stitch, stagger(1, 1), intersection_stitch} +intersection_bench! {intersect_stagger_diff1_actual, stagger(100, 1 << 1)} +intersection_bench! {intersect_stagger_diff1_search, stagger(100, 1 << 1), intersection_search} +intersection_bench! {intersect_stagger_diff1_stitch, stagger(100, 1 << 1), intersection_stitch} +intersection_bench! {intersect_stagger_diff2_actual, stagger(100, 1 << 2)} +intersection_bench! {intersect_stagger_diff2_search, stagger(100, 1 << 2), intersection_search} +intersection_bench! {intersect_stagger_diff2_stitch, stagger(100, 1 << 2), intersection_stitch} +intersection_bench! {intersect_stagger_diff3_actual, stagger(100, 1 << 3)} +intersection_bench! {intersect_stagger_diff3_search, stagger(100, 1 << 3), intersection_search} +intersection_bench! {intersect_stagger_diff3_stitch, stagger(100, 1 << 3), intersection_stitch} +intersection_bench! {intersect_stagger_diff4_actual, stagger(100, 1 << 4)} +intersection_bench! {intersect_stagger_diff4_search, stagger(100, 1 << 4), intersection_search} +intersection_bench! {intersect_stagger_diff4_stitch, stagger(100, 1 << 4), intersection_stitch} +intersection_bench! {intersect_stagger_diff5_actual, stagger(100, 1 << 5)} +intersection_bench! {intersect_stagger_diff5_search, stagger(100, 1 << 5), intersection_search} +intersection_bench! {intersect_stagger_diff5_stitch, stagger(100, 1 << 5), intersection_stitch} +intersection_bench! {intersect_stagger_diff6_actual, stagger(100, 1 << 6)} +intersection_bench! {intersect_stagger_diff6_search, stagger(100, 1 << 6), intersection_search} +intersection_bench! {intersect_stagger_diff6_stitch, stagger(100, 1 << 6), intersection_stitch} diff --git a/src/liballoc/benches/lib.rs b/src/liballoc/benches/lib.rs index 4bf5ec10c41e7..c9cf318cc07df 100644 --- a/src/liballoc/benches/lib.rs +++ b/src/liballoc/benches/lib.rs @@ -1,5 +1,6 @@ #![feature(repr_simd)] #![feature(test)] +#![feature(benches_btree_set)] extern crate test; diff --git a/src/liballoc/collections/btree/set.rs b/src/liballoc/collections/btree/set.rs index 2be6455ad5903..6d92693b153f6 100644 --- a/src/liballoc/collections/btree/set.rs +++ b/src/liballoc/collections/btree/set.rs @@ -3,7 +3,7 @@ use core::borrow::Borrow; use core::cmp::Ordering::{self, Less, Greater, Equal}; -use core::cmp::{min, max}; +use core::cmp::max; use core::fmt::{self, Debug}; use core::iter::{Peekable, FromIterator, FusedIterator}; use core::ops::{BitOr, BitAnd, BitXor, Sub, RangeBounds}; @@ -163,18 +163,34 @@ impl fmt::Debug for SymmetricDifference<'_, T> { /// [`BTreeSet`]: struct.BTreeSet.html /// [`intersection`]: struct.BTreeSet.html#method.intersection #[stable(feature = "rust1", since = "1.0.0")] -pub struct Intersection<'a, T: 'a> { - a: Peekable>, - b: Peekable>, +pub enum Intersection<'a, T: 'a> { + #[doc(hidden)] + #[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")] + Stitch { + a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets + b_iter: Iter<'a, T>, + }, + #[doc(hidden)] + #[unstable(feature = "benches_btree_set", reason = "benchmarks for pull #58577", issue = "0")] + Search { + a_iter: Iter<'a, T>, // for size_hint, should be the smaller of the sets + b_set: &'a BTreeSet, + }, } #[stable(feature = "collection_debug", since = "1.17.0")] impl fmt::Debug for Intersection<'_, T> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_tuple("Intersection") - .field(&self.a) - .field(&self.b) - .finish() + match self { + Intersection::Stitch { a_iter, b_iter } => f + .debug_tuple("Intersection") + .field(&a_iter) + .field(&b_iter) + .finish(), + Intersection::Search { a_iter, b_set: _ } => { + f.debug_tuple("Intersection").field(&a_iter).finish() + } + } } } @@ -326,9 +342,22 @@ impl BTreeSet { /// ``` #[stable(feature = "rust1", since = "1.0.0")] pub fn intersection<'a>(&'a self, other: &'a BTreeSet) -> Intersection<'a, T> { - Intersection { - a: self.iter().peekable(), - b: other.iter().peekable(), + let (a_set, b_set) = if self.len() <= other.len() { + (self, other) + } else { + (other, self) + }; + if a_set.len() > b_set.len() / 16 { + Intersection::Stitch { + a_iter: a_set.iter(), + b_iter: b_set.iter(), + } + } else { + // Iterate small set only and find matches in large set. + Intersection::Search { + a_iter: a_set.iter(), + b_set, + } } } @@ -1072,9 +1101,15 @@ impl FusedIterator for SymmetricDifference<'_, T> {} #[stable(feature = "rust1", since = "1.0.0")] impl Clone for Intersection<'_, T> { fn clone(&self) -> Self { - Intersection { - a: self.a.clone(), - b: self.b.clone(), + match self { + Intersection::Stitch { a_iter, b_iter } => Intersection::Stitch { + a_iter: a_iter.clone(), + b_iter: b_iter.clone(), + }, + Intersection::Search { a_iter, b_set } => Intersection::Search { + a_iter: a_iter.clone(), + b_set, + }, } } } @@ -1083,24 +1118,33 @@ impl<'a, T: Ord> Iterator for Intersection<'a, T> { type Item = &'a T; fn next(&mut self) -> Option<&'a T> { - loop { - match Ord::cmp(self.a.peek()?, self.b.peek()?) { - Less => { - self.a.next(); - } - Equal => { - self.b.next(); - return self.a.next(); - } - Greater => { - self.b.next(); + match self { + Intersection::Stitch { a_iter, b_iter } => { + let mut a_next = a_iter.next()?; + let mut b_next = b_iter.next()?; + loop { + match Ord::cmp(a_next, b_next) { + Less => a_next = a_iter.next()?, + Greater => b_next = b_iter.next()?, + Equal => return Some(a_next), + } } } + Intersection::Search { a_iter, b_set } => loop { + let a_next = a_iter.next()?; + if b_set.contains(&a_next) { + return Some(a_next); + } + }, } } fn size_hint(&self) -> (usize, Option) { - (0, Some(min(self.a.len(), self.b.len()))) + let max_size = match self { + Intersection::Stitch { a_iter, .. } => a_iter.len(), + Intersection::Search { a_iter, .. } => a_iter.len(), + }; + (0, Some(max_size)) } }