Implement Bernoulli distribution #411

Merged on May 15, 2018 (24 commits)
36 changes: 31 additions & 5 deletions benches/misc.rs
@@ -11,14 +11,15 @@ use rand::prelude::*;
use rand::seq::*;

#[bench]
fn misc_gen_bool(b: &mut Bencher) {
fn misc_gen_bool_const(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
// Can be evaluated at compile time.
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_bool(0.18);
}
black_box(accum);
accum
})
}

@@ -27,12 +28,37 @@ fn misc_gen_bool_var(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut p = 0.18;
black_box(&mut p); // Avoid constant folding.
for _ in 0..::RAND_BENCH_N {
black_box(rng.gen_bool(p));
}
})
}

#[bench]
fn misc_bernoulli_const(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
let d = rand::distributions::Bernoulli::new(0.18);
b.iter(|| {
// Can be evaluated at compile time.
let mut accum = true;
for _ in 0..::RAND_BENCH_N {
accum ^= rng.gen_bool(p);
p += 0.0001;
accum ^= rng.sample(d);
}
accum
})
}

#[bench]
fn misc_bernoulli_var(b: &mut Bencher) {
let mut rng = SmallRng::from_rng(&mut thread_rng()).unwrap();
b.iter(|| {
let mut p = 0.18;
black_box(&mut p); // Avoid constant folding.
let d = rand::distributions::Bernoulli::new(p);
for _ in 0..::RAND_BENCH_N {
black_box(rng.sample(d));
}
black_box(accum);
})
}

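The `_const` and `_var` benchmark pairs above differ only in whether the optimizer can see `p` as a compile-time constant. A minimal sketch of that distinction follows. It uses the stable `std::hint::black_box` instead of the nightly `test::black_box` the benchmarks rely on, and both `XorShift64` and this `gen_bool` are hypothetical stand-ins, not the crate's `SmallRng` or `Rng::gen_bool`.

```rust
use std::hint::black_box;

/// Tiny xorshift64 generator; a hypothetical stand-in for `SmallRng`.
/// The seed must be non-zero.
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Hypothetical `gen_bool`: compares a uniform float in [0, 1) with `p`.
fn gen_bool(rng: &mut XorShift64, p: f64) -> bool {
    // Keep the top 53 bits and scale by 2^-53.
    let u = (rng.next_u64() >> 11) as f64 / 9007199254740992.0;
    u < p
}

/// `p` is a literal, so the compiler is free to fold it into the loop.
fn accumulate_const(seed: u64, n: u32) -> bool {
    let mut rng = XorShift64(seed);
    let mut accum = true;
    for _ in 0..n {
        accum ^= gen_bool(&mut rng, 0.18);
    }
    accum
}

/// `p` passes through `black_box`, so it is opaque to the optimizer.
fn accumulate_var(seed: u64, n: u32) -> bool {
    let mut rng = XorShift64(seed);
    let p = black_box(0.18);
    let mut accum = true;
    for _ in 0..n {
        accum ^= gen_bool(&mut rng, p);
    }
    accum
}
```

Both functions return the same value for the same seed, because `black_box` is an identity function at runtime; only the information available to the optimizer changes, which is what the two benchmark variants are meant to isolate.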
120 changes: 120 additions & 0 deletions src/distributions/bernoulli.rs
@@ -0,0 +1,120 @@
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// https://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Bernoulli distribution.

use Rng;
use distributions::Distribution;

/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Bernoulli, Distribution};
///
/// let d = Bernoulli::new(0.3);
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a Bernoulli distribution", v);
/// ```
///
/// # Precision
///
/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
/// represented.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
/// Probability of success, relative to the maximal integer.
p_int: u64,
}

impl Bernoulli {
/// Construct a new `Bernoulli` with the given probability of success `p`.
///
/// # Panics
///
/// If `p < 0` or `p > 1`.
///
/// # Precision
///
/// For `p = 1.0`, the resulting distribution will always generate true.
/// For `p = 0.0`, the resulting distribution will always generate false.
///
/// This method is accurate for any input `p` in the range `[0, 1]` which is
/// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
/// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
#[inline]
pub fn new(p: f64) -> Bernoulli {
assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 1");
// Technically, this should be 2^64 or `u64::MAX + 1` because we compare
// using `<` when sampling. However, `u64::MAX` rounds to an `f64`
// larger than `u64::MAX` anyway.
const MAX_P_INT: f64 = ::core::u64::MAX as f64;
let p_int = if p < 1.0 {
(p * MAX_P_INT) as u64
} else {
// Avoid overflow: `MAX_P_INT` cannot be represented as u64.
::core::u64::MAX
};
Bernoulli { p_int }
}
}

impl Distribution<bool> for Bernoulli {
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool {
// Make sure to always return true for p = 1.0.
if self.p_int == ::core::u64::MAX {
return true;
}
let r: u64 = rng.gen();
r < self.p_int
}
}

#[cfg(test)]
mod test {
use Rng;
use distributions::Distribution;
use super::Bernoulli;

#[test]
fn test_trivial() {
let mut r = ::test::rng(1);
let always_false = Bernoulli::new(0.0);
let always_true = Bernoulli::new(1.0);
for _ in 0..5 {
assert_eq!(r.sample::<bool, _>(&always_false), false);
assert_eq!(r.sample::<bool, _>(&always_true), true);
assert_eq!(Distribution::<bool>::sample(&always_false, &mut r), false);
assert_eq!(Distribution::<bool>::sample(&always_true, &mut r), true);
}
}

#[test]
fn test_average() {
const P: f64 = 0.3;
let d = Bernoulli::new(P);
const N: u32 = 10_000_000;
Member

10 million samples make this a little slow. I guess leave it as-is for now, though; maybe as part of #357 we can divide tests into two sets (fast and slow) or something.

Collaborator Author

Unfortunately you need a lot of samples to get decent statistics. (The relative error probably scales with 1/sqrt(n_samples).)


let mut sum: u32 = 0;
let mut rng = ::test::rng(2);
for _ in 0..N {
if d.sample(&mut rng) {
sum += 1;
}
}
let avg = (sum as f64) / (N as f64);

assert!((avg - P).abs() < 1e-3);
}
}
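The review thread inside `test_average` notes that the error shrinks roughly as 1/sqrt(n_samples). A small sketch of that arithmetic follows, assuming independent draws, for which the standard deviation of the sample mean is sqrt(p(1-p)/n). The helper names are illustrative and not part of the PR.

```rust
/// Standard deviation of the sample mean of `n` independent Bernoulli(`p`) draws.
fn std_error(p: f64, n: u64) -> f64 {
    (p * (1.0 - p) / n as f64).sqrt()
}

/// How many standard errors fit inside the absolute tolerance `tol`.
fn tolerance_in_sigmas(p: f64, n: u64, tol: f64) -> f64 {
    tol / std_error(p, n)
}
```

With the test's values (`P = 0.3`, `N = 10_000_000`, tolerance `1e-3`) the standard error is about 1.45e-4, so the tolerance sits near 6.9 standard errors. With only 100,000 samples the same tolerance would be under one standard error, and the assertion would be expected to fail frequently.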
12 changes: 8 additions & 4 deletions src/distributions/binomial.rs
@@ -31,13 +31,17 @@ use std::f64::consts::PI;
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Binomial {
n: u64, // number of trials
p: f64, // probability of success
/// Number of trials.
n: u64,
/// Probability of success.
p: f64,
}

impl Binomial {
/// Construct a new `Binomial` with the given shape parameters
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`.
/// Construct a new `Binomial` with the given shape parameters `n` (number
/// of trials) and `p` (probability of success).
///
/// Panics if `p <= 0` or `p >= 1`.
Member

The p==0 and p==1 cases are trivial — should we support them? This might require extra branching however.

Collaborator Author

Maybe, but this is orthogonal to this PR.

pub fn new(n: u64, p: f64) -> Binomial {
assert!(p > 0.0, "Binomial::new called with p <= 0");
assert!(p < 1.0, "Binomial::new called with p >= 1");
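The review question on `Binomial::new` asks whether `p == 0` and `p == 1` could be supported at the cost of extra branching. The sketch below shows what that branching might look like. It is purely hypothetical: `binomial_naive` is not the crate's `Binomial` and does not use its algorithm. It counts successes one trial at a time and takes the random source as a closure (an assumption made so the example stays self-contained).

```rust
/// Hypothetical helper, not the crate's `Binomial`: counts successes in `n`
/// trials, each succeeding when a raw 64-bit draw falls below a threshold.
fn binomial_naive<F: FnMut() -> u64>(n: u64, p: f64, mut next_u64: F) -> u64 {
    assert!((0.0..=1.0).contains(&p), "p must be in [0, 1]");
    // The trivial cases are handled up front, at the cost of two branches.
    if p == 0.0 {
        return 0;
    }
    if p == 1.0 {
        return n;
    }
    // Scale `p` by 2^64 to get an integer threshold.
    let threshold = (p * 18446744073709551616.0) as u64;
    (0..n).filter(|_| next_u64() < threshold).count() as u64
}
```

The two early returns never touch the random source, so the degenerate cases stay exact regardless of how the threshold is rounded.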
2 changes: 2 additions & 0 deletions src/distributions/mod.rs
@@ -178,6 +178,7 @@ pub use self::uniform::Uniform as Range;
#[doc(inline)] pub use self::poisson::Poisson;
#[cfg(feature = "std")]
#[doc(inline)] pub use self::binomial::Binomial;
#[doc(inline)] pub use self::bernoulli::Bernoulli;

pub mod uniform;
#[cfg(feature="std")]
@@ -190,6 +191,7 @@ pub mod uniform;
#[doc(hidden)] pub mod poisson;
#[cfg(feature = "std")]
#[doc(hidden)] pub mod binomial;
#[doc(hidden)] pub mod bernoulli;

mod float;
mod integer;
24 changes: 9 additions & 15 deletions src/lib.rs
@@ -318,7 +318,6 @@ pub trait Rng: RngCore {
/// println!("{}", x);
/// println!("{:?}", rng.gen::<(f64, bool)>());
/// ```
#[inline(always)]
fn gen<T>(&mut self) -> T where Standard: Distribution<T> {
Standard.sample(self)
}
@@ -474,6 +473,8 @@ pub trait Rng: RngCore {

/// Return a bool with a probability `p` of being true.
///
/// This is a wrapper around [`distributions::Bernoulli`].
///
/// # Example
///
/// ```rust
@@ -483,20 +484,15 @@ pub trait Rng: RngCore {
/// println!("{}", rng.gen_bool(1.0 / 3.0));
/// ```
///
/// # Accuracy note
/// # Panics
///
/// If `p` < 0 or `p` > 1.
///
/// `gen_bool` uses 32 bits of the RNG, so if you use it to generate close
/// to or more than `2^32` results, a tiny bias may become noticeable.
/// A notable consequence of the method used here is that the worst case is
/// `rng.gen_bool(0.0)`: it has a chance of 1 in `2^32` of being true, while
/// it should always be false. But using `gen_bool` to consume *many* values
/// from an RNG just to consistently generate `false` does not match with
/// the intent of this method.
/// [`distributions::Bernoulli`]: distributions/bernoulli/struct.Bernoulli.html
#[inline]
fn gen_bool(&mut self, p: f64) -> bool {
assert!(p >= 0.0 && p <= 1.0);
// If `p` is constant, this will be evaluated at compile-time.
let p_int = (p * f64::from(core::u32::MAX)) as u32;
self.gen::<u32>() <= p_int
let d = distributions::Bernoulli::new(p);
self.sample(d)
}

/// Return a random element from `values`.
@@ -897,7 +893,6 @@ pub fn weak_rng() -> XorShiftRng {
/// [`thread_rng`]: fn.thread_rng.html
/// [`Standard`]: distributions/struct.Standard.html
#[cfg(feature="std")]
#[inline]
pub fn random<T>() -> T where Standard: Distribution<T> {
thread_rng().gen()
}
@@ -918,7 +913,6 @@ pub fn random<T>() -> T where Standard: Distribution<T> {
/// println!("{:?}", sample);
/// ```
#[cfg(feature="std")]
#[inline(always)]
#[deprecated(since="0.4.0", note="renamed to seq::sample_iter")]
pub fn sample<T, I, R>(rng: &mut R, iterable: I, amount: usize) -> Vec<T>
where I: IntoIterator<Item=T>,
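The `gen_bool` change in `src/lib.rs` replaces a 32-bit threshold compared with `<=` by the 64-bit `Bernoulli` threshold compared with `<`. The sketch below puts the two comparisons side by side. To keep it deterministic, the raw random integer is passed in as a parameter instead of being drawn from an RNG; that parameter and the function names are assumptions for illustration only.

```rust
/// Old scheme from the removed lines: 32-bit threshold, compared with `<=`.
fn old_gen_bool(raw: u32, p: f64) -> bool {
    assert!(p >= 0.0 && p <= 1.0);
    let p_int = (p * f64::from(u32::MAX)) as u32;
    raw <= p_int
}

/// New scheme following `Bernoulli`: 64-bit threshold, compared with `<`,
/// with `p = 1.0` special-cased so that it always yields true.
fn new_gen_bool(raw: u64, p: f64) -> bool {
    assert!((p >= 0.0) & (p <= 1.0));
    let p_int = if p < 1.0 {
        (p * (u64::MAX as f64)) as u64
    } else {
        u64::MAX
    };
    p_int == u64::MAX || raw < p_int
}
```

With `p = 0.0`, the old comparison returns true when the raw draw happens to be `0`, which is the 1 in `2^32` bias described in the removed accuracy note. The new comparison returns false for every raw value in that case.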
23 changes: 23 additions & 0 deletions tests/bool.rs
@@ -0,0 +1,23 @@
#![no_std]

extern crate rand;

use rand::SeedableRng;
use rand::rngs::SmallRng;
use rand::distributions::{Distribution, Bernoulli};

/// This test should make sure that we don't accidentally have undefined
/// behavior for large probabilities due to
/// https://github.com/rust-lang/rust/issues/10184.
/// Expressions like `1.0*(u64::MAX as f64) as u64` have to be avoided.
#[test]
fn large_probability() {
let p = 1. - ::core::f64::EPSILON / 2.;
assert!(p < 1.);
let d = Bernoulli::new(p);
let mut rng = SmallRng::from_seed(
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]);
for _ in 0..10 {
assert!(d.sample(&mut rng), "extremely unlikely to fail by accident");
}
}
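The conversion that `large_probability` guards can be traced by hand. The sketch below copies the threshold computation from `Bernoulli::new` into a free function (`p_to_int` is a name chosen here, not one from the PR). `u64::MAX as f64` rounds up to exactly 2^64, which is why `p = 1.0` needs its own branch. Note that since Rust 1.45 float-to-integer `as` casts saturate instead of being undefined behavior, so on current compilers the branch matters for clarity more than for soundness.

```rust
/// Threshold computation following `Bernoulli::new` in this PR.
fn p_to_int(p: f64) -> u64 {
    assert!((p >= 0.0) & (p <= 1.0));
    // `u64::MAX as f64` rounds up to 2^64, which does not fit in a `u64`.
    const MAX_P_INT: f64 = u64::MAX as f64;
    if p < 1.0 {
        (p * MAX_P_INT) as u64
    } else {
        // Avoid converting 2^64 itself.
        u64::MAX
    }
}
```

For the test's `p = 1 - EPSILON / 2`, which is `1 - 2^-53`, the product is exactly `2^64 - 2^11`, so the threshold is `u64::MAX - 2047`. Each sample then fails with probability `2^-53`, which is why ten consecutive successes are treated as a safe expectation.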