WeightedChoice could use the Walker or Vose Alias method for O(1) sampling (instead of O(log n)) #601
Thanks. For now I'm more interested in getting the API right, and this isn't a core function, so it doesn't have the highest priority IMO.
Yeah, definitely. That said, unfortunately I suspect an implementation of this may require changing the current API (both the old
I created an implementation for myself. I will put it here so that anyone who needs it can use it. It is based on this, and it should be possible to tweak it to use integer weights by scaling what is 100%.

```rust
use rand::distributions::Distribution;
use rand::Rng;
use std::collections::VecDeque;

pub struct AliasMethodWeightedIndex {
    aliases: Vec<usize>,
    no_alias_odds: Vec<f64>,
}

impl AliasMethodWeightedIndex {
    pub fn new(weights: Vec<f64>) -> Self {
        debug_assert!(weights.iter().all(|&w| w >= 0.0));

        let weight_sum = pairwise_sum_f64(weights.as_slice());
        if !weight_sum.is_finite() {
            panic!("Sum of weights not finite.");
        }

        // Rescale the weights so they sum to n, i.e. average to 1.
        let n = weights.len();
        let mut no_alias_odds = weights;
        for p in no_alias_odds.iter_mut() {
            *p *= n as f64 / weight_sum;
        }

        // Split indices into indices with small weights and indices with big weights.
        // Instead of two `Vec`s with unknown capacity we use a single `VecDeque` with
        // known capacity. The front represents smalls and the back represents bigs. We
        // also need to keep track of the size of each virtual `Vec`.
        let mut smalls_bigs = VecDeque::with_capacity(n);
        let mut smalls_len = 0_usize;
        let mut bigs_len = 0_usize;
        for (index, &weight) in no_alias_odds.iter().enumerate() {
            if weight < 1.0 {
                smalls_bigs.push_front(index);
                smalls_len += 1;
            } else {
                smalls_bigs.push_back(index);
                bigs_len += 1;
            }
        }

        // Pair each small index with a big index that donates its missing probability.
        let mut aliases = vec![0; n];
        while smalls_len > 0 && bigs_len > 0 {
            let s = smalls_bigs.pop_front().unwrap();
            smalls_len -= 1;
            let b = smalls_bigs.pop_back().unwrap();
            bigs_len -= 1;

            aliases[s] = b;
            no_alias_odds[b] = no_alias_odds[s] + no_alias_odds[b] - 1.0;

            if no_alias_odds[b] < 1.0 {
                smalls_bigs.push_front(b);
                smalls_len += 1;
            } else {
                smalls_bigs.push_back(b);
                bigs_len += 1;
            }
        }

        // The remaining indices should have no-alias odds of about 1. They would be
        // exactly 1 if it were not for numeric inaccuracy.
        for index in smalls_bigs.into_iter() {
            // Because p = 1 we don't need to set an alias. It will never be accessed.
            no_alias_odds[index] = 1.0;
        }

        Self {
            aliases,
            no_alias_odds,
        }
    }
}

impl Distribution<usize> for AliasMethodWeightedIndex {
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        let candidate = rng.gen_range(0, self.no_alias_odds.len());
        if rng.gen_bool(self.no_alias_odds[candidate]) {
            candidate
        } else {
            self.aliases[candidate]
        }
    }
}

pub fn pairwise_sum_f64(values: &[f64]) -> f64 {
    if values.len() <= 32 {
        values.iter().sum()
    } else {
        let mid = values.len() / 2;
        let (a, b) = values.split_at(mid);
        pairwise_sum_f64(a) + pairwise_sum_f64(b)
    }
}
```

Note that I used a pairwise summation algorithm to improve accuracy when there are many floating-point weights. I benchmarked it to find a good size for the base case and found that it is about twice as fast as simple loop/iterator summation on my machine. I don't know why that is, because I would have expected a little bit of overhead instead, but I haven't investigated further.
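A minimal usage sketch of the struct above, assuming the same rand 0.6-era API as the snippet (the example weights and counts are purely illustrative):

```rust
use rand::distributions::Distribution;

fn main() {
    // Weights 1:2:3:4, so index 3 should be drawn about four times as often as index 0.
    let dist = AliasMethodWeightedIndex::new(vec![1.0, 2.0, 3.0, 4.0]);
    let mut rng = rand::thread_rng();

    let mut counts = [0usize; 4];
    for _ in 0..100_000 {
        counts[dist.sample(&mut rng)] += 1;
    }
    println!("{:?}", counts); // roughly [10000, 20000, 30000, 40000]
}
```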
@zroug thanks for the code. We would welcome a PR if you can see how to integrate this.
I believe we can close this now.
Hm, it doesn't look to me like the code has changed (e.g. #692 hasn't landed), so this doesn't seem fixed? Is the issue-management approach written down somewhere (so I can understand)?
No, the code hasn't landed yet, but we have a PR in reasonable shape. Do we need a tracking issue too? Maybe I should just follow "standard" practice and keep this open for now then.
#692 is merged now.
The alias method (https://en.wikipedia.org/wiki/Alias_method) allows O(1) sampling of variates while still requiring only O(n) preprocessing, which is significantly better than the O(log n) of the current method (binary search, sometimes called Roulette Wheel Selection).
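For comparison, here is a minimal sketch of the O(log n) approach referred to above: a cumulative-weight table searched with binary search. The struct and method names are illustrative only, not rand's actual `WeightedIndex` implementation:

```rust
use rand::Rng;

pub struct CdfWeightedIndex {
    cumulative: Vec<f64>, // running totals of the weights
}

impl CdfWeightedIndex {
    pub fn new(weights: &[f64]) -> Self {
        // O(n) preprocessing: build the cumulative distribution.
        let mut cumulative = Vec::with_capacity(weights.len());
        let mut total = 0.0;
        for &w in weights {
            total += w;
            cumulative.push(total);
        }
        Self { cumulative }
    }

    pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
        // O(log n) per sample: draw a point in [0, total) and binary-search for the
        // first cumulative weight that exceeds it.
        let total = *self.cumulative.last().expect("at least one weight");
        let x = rng.gen::<f64>() * total;
        let i = match self
            .cumulative
            .binary_search_by(|probe| probe.partial_cmp(&x).unwrap())
        {
            Ok(i) | Err(i) => i,
        };
        i.min(self.cumulative.len() - 1)
    }
}
```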