Skip to content

Commit

Permalink
Implement choose/choose_mut/choose_from_iterator on both Rng and on s…
Browse files Browse the repository at this point in the history
…lice/Iterator

Implement Rng.choose_weighted() and Rng.choose_weighted_with_total()
  • Loading branch information
sicking committed Jun 12, 2018
1 parent ec3d7ef commit 2a786a8
Show file tree
Hide file tree
Showing 2 changed files with 396 additions and 6 deletions.
301 changes: 295 additions & 6 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,7 @@ pub mod prelude;
pub mod prng;
pub mod rngs;
#[cfg(feature = "alloc")] pub mod seq;
#[cfg(feature = "alloc")] pub use seq::{SliceRandom, IteratorRandom};

////////////////////////////////////////////////////////////////////////////////
// Compatibility re-exports. Documentation is hidden; will be removed eventually.
Expand Down Expand Up @@ -595,6 +596,161 @@ pub trait Rng: RngCore {
}
}

/// Returns one random element of the `Iterator`, or `None` if the
/// `Iterator` returns no items. If you have a slice, it's significantly
/// faster to call the [`choose`] or [`choose_mut`] functions using the
/// slice instead. However it expected to be faster than dumping the
/// Iterator into a slice and then calling [`choose`]/[`choose_mut`] on
/// the slice.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
///
/// let choices = std::iter::repeat(0)
/// .scan((1, 1), |state, _| { let (a, b) = *state; *state = (b, a+b); Some(a) })
/// .take(40);
/// let mut rng = thread_rng();
/// // Randomly choose one of the first 40 fibonacci numbers
/// println!("{}", rng.choose_from_iterator(choices).unwrap());
/// assert_eq!(rng.choose_from_iterator(std::iter::empty::<i32>()), None);
/// ```
/// [`choose`]: trait.Rng.html#method.choose
/// [`choose_mut`]: trait.Rng.html#method.choose_mut
fn choose_from_iterator<I: Iterator>(&mut self, mut iterable: I) -> Option<I::Item> {
let mut val = iterable.next();
if val.is_none() {
return val;
}

for (i, elem) in iterable.enumerate() {
if self.gen_range(0, i + 2) == 0 {
val = Some(elem);
}
}
val
}

/// Return a random element from `items` where. The chance of a given item
/// being picked, is proportional to the corresponding value in `weights`.
/// `weights` and `items` must return exactly the same number of values.
///
/// All values returned by `weights` must be `>= 0`.
///
/// This function iterates over `weights` twice. Once to get the total
/// weight, and once while choosing the random value. If you know the total
/// weight, or plan to call this function multiple times, you should
/// consider using [`choose_weighted_with_total`] instead.
///
/// Return `None` if `items` and `weights` is empty.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
///
/// let choices = ['a', 'b', 'c'];
/// let weights = [2, 1, 1];
/// let mut rng = thread_rng();
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// println!("{}", rng.choose_weighted(choices.iter(), weights.iter().cloned()).unwrap());
/// ```
/// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
fn choose_weighted<IterItems, IterWeights>(&mut self,
items: IterItems,
weights: IterWeights) -> Option<IterItems::Item>
where IterItems: Iterator,
IterWeights: Iterator+Clone,
IterWeights::Item: SampleUniform +
Default +
core::ops::Add<IterWeights::Item, Output=IterWeights::Item> +
core::cmp::PartialOrd<IterWeights::Item> +
Clone { // Clone is only needed for debug assertions
let total_weight: IterWeights::Item =
weights.clone().fold(Default::default(), |acc, w| {
assert!(w >= Default::default(), "Weight must be larger than zero");
acc + w
});
self.choose_weighted_with_total(items, weights, total_weight)
}

/// Return a random element from `items` where. The chance of a given item
/// being picked, is proportional to the corresponding value in `weights`.
/// `weights` and `items` must return exactly the same number of values.
///
/// All values returned by `weights` must be `>= 0`.
///
/// `total_weight` must be exactly the sum of all values returned by
/// `weights`. Builds with debug_assertions turned on will assert that this
/// equality holds. Simply storing the result of `weights.sum()` and using
/// that as `total_weight` should work.
///
/// Return `None` if `items` and `weights` is empty.
///
/// # Example
///
/// ```
/// use rand::{thread_rng, Rng};
///
/// let choices = ['a', 'b', 'c'];
/// let weights = [2, 1, 1];
/// let mut rng = thread_rng();
/// // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
/// println!("{}", rng.choose_weighted_with_total(choices.iter(), weights.iter().cloned(), 4).unwrap());
/// ```
/// [`choose_weighted_with_total`]: trait.Rng.html#method.choose_weighted_with_total
fn choose_weighted_with_total<IterItems, IterWeights>(&mut self,
mut items: IterItems,
mut weights: IterWeights,
total_weight: IterWeights::Item) -> Option<IterItems::Item>
where IterItems: Iterator,
IterWeights: Iterator,
IterWeights::Item: SampleUniform +
Default +
core::ops::Add<IterWeights::Item, Output=IterWeights::Item> +
core::cmp::PartialOrd<IterWeights::Item> +
Clone { // Clone is only needed for debug assertions

if total_weight == Default::default() {
debug_assert!(items.next().is_none());
return None;
}

// Only used when debug_assertions are turned on
let mut debug_result = None;
let debug_total_weight = if cfg!(debug_assertions) { Some(total_weight.clone()) } else { None };

let chosen_weight = self.gen_range(Default::default(), total_weight);
let mut cumulative_weight: IterWeights::Item = Default::default();

for item in items {
let weight_opt = weights.next();
assert!(weight_opt.is_some(), "`weights` returned fewer items than `items` did");
let weight = weight_opt.unwrap();
assert!(weight >= Default::default(), "Weight must be larger than zero");

cumulative_weight = cumulative_weight + weight;

if cumulative_weight > chosen_weight {
if !cfg!(debug_assertions) {
return Some(item);
}
if debug_result.is_none() {
debug_result = Some(item);
}
}
}

assert!(weights.next().is_none(), "`weights` returned more items than `items` did");
debug_assert!(debug_total_weight.unwrap() == cumulative_weight);
if cfg!(debug_assertions) && debug_result.is_some() {
return debug_result;
}

panic!("total_weight did not match up with sum of weights");
}

/// Shuffle a mutable slice in place.
///
/// This applies Durstenfeld's algorithm for the [Fisher–Yates shuffle](
Expand Down Expand Up @@ -846,6 +1002,7 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
#[cfg(test)]
mod test {
use rngs::mock::StepRng;
#[cfg(feature="std")] use core::panic::catch_unwind;
use super::*;
#[cfg(all(not(feature="std"), feature="alloc"))] use alloc::boxed::Box;

Expand Down Expand Up @@ -976,15 +1133,50 @@ mod test {
#[test]
fn test_choose() {
let mut r = rng(107);
assert_eq!(r.choose(&[1, 1, 1]).map(|&x|x), Some(1));
let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'];
let mut chosen = [0i32; 14];
for _ in 0..1000 {
let picked = *r.choose(&chars).unwrap();
chosen[(picked as usize) - ('a' as usize)] += 1;
}
for count in chosen.iter() {
let err = *count - (1000 / (chars.len() as i32));
assert!(-20 <= err && err <= 20);
}

let v: &[isize] = &[];
assert_eq!(r.choose(v), None);
chosen.iter_mut().for_each(|x| *x = 0);
for _ in 0..1000 {
*r.choose_mut(&mut chosen).unwrap() += 1;
}
for count in chosen.iter() {
let err = *count - (1000 / (chosen.len() as i32));
assert!(-20 <= err && err <= 20);
}

let mut v: [isize; 0] = [];
assert_eq!(r.choose(&v), None);
assert_eq!(r.choose_mut(&mut v), None);
}

#[test]
fn test_shuffle() {
fn test_choose_from_iterator() {
let mut r = rng(108);
let mut chosen = [0i32; 9];
for _ in 0..1000 {
let picked = r.choose_from_iterator(0..9).unwrap();
chosen[picked] += 1;
}
for count in chosen.iter() {
let err = *count - 1000 / 9;
assert!(-25 <= err && err <= 25);
}

assert_eq!(r.choose_from_iterator(0..0), None);
}

#[test]
fn test_shuffle() {
let mut r = rng(109);
let empty: &mut [isize] = &mut [];
r.shuffle(empty);
let mut one = [1];
Expand All @@ -1005,7 +1197,7 @@ mod test {
#[test]
fn test_rng_trait_object() {
use distributions::{Distribution, Standard};
let mut rng = rng(109);
let mut rng = rng(110);
let mut r = &mut rng as &mut RngCore;
r.next_u32();
r.gen::<i32>();
Expand All @@ -1021,7 +1213,7 @@ mod test {
#[cfg(feature="alloc")]
fn test_rng_boxed_trait() {
use distributions::{Distribution, Standard};
let rng = rng(110);
let rng = rng(111);
let mut r = Box::new(rng) as Box<RngCore>;
r.next_u32();
r.gen::<i32>();
Expand Down Expand Up @@ -1049,6 +1241,7 @@ mod test {
}

#[test]
<<<<<<< HEAD
fn test_gen_ratio_average() {
const NUM: u32 = 3;
const DENOM: u32 = 10;
Expand All @@ -1063,5 +1256,101 @@ mod test {
}
let avg = (sum as f64) / (N as f64);
assert!((avg - (NUM as f64)/(DENOM as f64)).abs() < 1e-3);
=======
fn test_choose_weighted() {
let mut r = rng(112);
let chars = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'];
let weights = [1u32, 2, 3, 0, 5, 6, 7, 1, 2, 3, 4, 5, 6, 7];
let total_weight = weights.iter().sum();
assert_eq!(chars.len(), weights.len());

let mut chosen = [0i32; 14];
for _ in 0..1000 {
let picked = *r.choose_weighted(chars.iter(),
weights.iter().cloned()).unwrap();
chosen[(picked as usize) - ('a' as usize)] += 1;
}
for (i, count) in chosen.iter().enumerate() {
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
assert!(-25 <= err && err <= 25);
}

// Mutable items
chosen.iter_mut().for_each(|x| *x = 0);
for _ in 0..1000 {
*r.choose_weighted(chosen.iter_mut(),
weights.iter().cloned()).unwrap() += 1;
}
for (i, count) in chosen.iter().enumerate() {
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
assert!(-25 <= err && err <= 25);
}

// choose_weighted_with_total
chosen.iter_mut().for_each(|x| *x = 0);
for _ in 0..1000 {
let picked = *r.choose_weighted_with_total(chars.iter(),
weights.iter().cloned(),
total_weight).unwrap();
chosen[(picked as usize) - ('a' as usize)] += 1;
}
for (i, count) in chosen.iter().enumerate() {
let err = *count - ((weights[i] * 1000 / total_weight) as i32);
assert!(-25 <= err && err <= 25);
}
}

#[test]
#[cfg(all(feature="std",
not(target_arch = "wasm32"),
not(target_arch = "asmjs")))]
fn test_choose_weighted_assertions() {
fn inner_delta(delta: i32) {
let items = vec![1, 2, 3];
let mut r = rng(113);
if cfg!(debug_assertions) || delta == 0 {
r.choose_weighted_with_total(items.iter(),
items.iter().cloned(),
6+delta);
} else {
loop {
r.choose_weighted_with_total(items.iter(),
items.iter().cloned(),
6+delta);
}
}
}

assert!(catch_unwind(|| inner_delta(0)).is_ok());
assert!(catch_unwind(|| inner_delta(1)).is_err());
assert!(catch_unwind(|| inner_delta(1000)).is_err());
if cfg!(debug_assertions) {
// The non-debug-assertions code can't detect too small total_weight
assert!(catch_unwind(|| inner_delta(-1)).is_err());
assert!(catch_unwind(|| inner_delta(-1000)).is_err());
}

fn inner_size(items: usize, weights: usize, with_total: bool) {
let mut r = rng(114);
if with_total {
r.choose_weighted_with_total(core::iter::repeat(1usize).take(items),
core::iter::repeat(1usize).take(weights),
weights);
} else {
r.choose_weighted(core::iter::repeat(1usize).take(items),
core::iter::repeat(1usize).take(weights));
}
}

assert!(catch_unwind(|| inner_size(2, 2, true)).is_ok());
assert!(catch_unwind(|| inner_size(2, 2, false)).is_ok());
assert!(catch_unwind(|| inner_size(2, 1, true)).is_err());
assert!(catch_unwind(|| inner_size(2, 1, false)).is_err());
if cfg!(debug_assertions) {
// The non-debug-assertions code can't detect too many weights
assert!(catch_unwind(|| inner_size(2, 3, true)).is_err());
assert!(catch_unwind(|| inner_size(2, 3, false)).is_err());
}
>>>>>>> Implement choose/choose_mut/choose_from_iterator on both Rng and on slice/Iterator
}
}
Loading

0 comments on commit 2a786a8

Please sign in to comment.