-
-
Notifications
You must be signed in to change notification settings - Fork 436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Bernoulli distribution #411
Merged
Merged
Changes from all commits
Commits
Show all changes
24 commits
Select commit
Hold shift + click to select a range
4213bc0
Small improvements to binomial docs
vks eb47e3b
Document panics of gen_bool
vks 20eb2bf
Implement Bernoulli distribution
vks fe02da9
Implement `Distribution<T> for Bernoulli` only for bool
vks a2921b1
Benchmark gen_bool vs. Bernoulli::sample
vks 588f0d0
Improve bool benchmarks
vks 9e234de
Move implementation from gen_bool to Bernoulli::sample
vks 9ed874f
Use generated floats in Bernoulli::sample
vks 0dfa229
Use 64 bit for generating bools
vks 5bc6b49
Add note on precision
vks e07efce
Make sure Bernoulli::new(1.) always generates true
vks 06cd389
Add copyright notice
vks e0df2c4
Bernoulli: Add comment on precision based on @pitdicker's suggestion
vks 914f54a
Bernoulli: Correct remarks about precision
vks d47d3f0
Address review feedback
vks 959160f
Don't mention `Uniform` alternative
vks bfd81c4
Fix chained comparison
vks 6fe648f
Inline a few `Rng` methods
vks 7b1e068
Add test against possible undefined behavior
vks 67ba295
Fix imports
vks 0c85fb8
Improve bool test
vks 67c55ea
Fix tests on no_std
vks fcd0154
Clarify constant in Bernoulli
vks c440d3e
Remove inline annotations from generic functions
vks File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
// Copyright 2018 The Rust Project Developers. See the COPYRIGHT | ||
// file at the top-level directory of this distribution and at | ||
// https://rust-lang.org/COPYRIGHT. | ||
// | ||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or | ||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license | ||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your | ||
// option. This file may not be copied, modified, or distributed | ||
// except according to those terms. | ||
//! The Bernoulli distribution. | ||
|
||
use Rng; | ||
use distributions::Distribution; | ||
|
||
/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
///
/// # Example
///
/// ```rust
/// use rand::distributions::{Bernoulli, Distribution};
///
/// let d = Bernoulli::new(0.3);
/// let v = d.sample(&mut rand::thread_rng());
/// println!("{} is from a Bernoulli distribution", v);
/// ```
///
/// # Precision
///
/// This `Bernoulli` distribution uses 64 bits from the RNG (a `u64`),
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
/// represented.
#[derive(Clone, Copy, Debug)]
pub struct Bernoulli {
    /// Probability of success, relative to the maximal integer.
    p_int: u64,
}

impl Bernoulli {
    /// Construct a new `Bernoulli` with the given probability of success `p`.
    ///
    /// # Panics
    ///
    /// If `p < 0`, `p > 1`, or `p` is NaN.
    ///
    /// # Precision
    ///
    /// For `p = 1.0`, the resulting distribution will always generate true.
    /// For `p = 0.0`, the resulting distribution will always generate false.
    ///
    /// This method is accurate for any input `p` in the range `[0, 1]` which is
    /// a multiple of 2<sup>-64</sup>. (Note that not all multiples of
    /// 2<sup>-64</sup> in `[0, 1]` can be represented as a `f64`.)
    #[inline]
    pub fn new(p: f64) -> Bernoulli {
        // Both comparisons are false for NaN, so NaN is rejected as well.
        assert!((p >= 0.0) & (p <= 1.0), "Bernoulli::new not called with 0 <= p <= 1");
        // Technically, this should be 2^64 or `u64::MAX + 1` because we compare
        // using `<` when sampling. However, `u64::MAX` rounds to an `f64`
        // larger than `u64::MAX` anyway.
        const MAX_P_INT: f64 = ::core::u64::MAX as f64;
        let p_int = if p < 1.0 {
            // `p < 1.0` keeps the product strictly below `MAX_P_INT`, so the
            // conversion to `u64` cannot overflow.
            (p * MAX_P_INT) as u64
        } else {
            // Avoid overflow: `MAX_P_INT` cannot be represented as u64.
            ::core::u64::MAX
        };
        Bernoulli { p_int }
    }
}
|
||
impl Distribution<bool> for Bernoulli { | ||
#[inline] | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> bool { | ||
// Make sure to always return true for p = 1.0. | ||
if self.p_int == ::core::u64::MAX { | ||
return true; | ||
} | ||
let r: u64 = rng.gen(); | ||
r < self.p_int | ||
} | ||
} | ||
|
||
#[cfg(test)]
mod test {
    use Rng;
    use distributions::Distribution;
    use super::Bernoulli;

    /// `p = 0.0` must never yield `true` and `p = 1.0` must never yield
    /// `false`, whether sampled via `Rng::sample` or `Distribution::sample`.
    #[test]
    fn test_trivial() {
        let mut rng = ::test::rng(1);
        let never = Bernoulli::new(0.0);
        let always = Bernoulli::new(1.0);
        for _ in 0..5 {
            assert_eq!(rng.sample::<bool, _>(&never), false);
            assert_eq!(rng.sample::<bool, _>(&always), true);
            assert_eq!(Distribution::<bool>::sample(&never, &mut rng), false);
            assert_eq!(Distribution::<bool>::sample(&always, &mut rng), true);
        }
    }

    /// The observed success frequency over many samples must stay close to
    /// the configured probability.
    #[test]
    fn test_average() {
        const P: f64 = 0.3;
        const N: u32 = 10_000_000;

        let dist = Bernoulli::new(P);
        let mut rng = ::test::rng(2);

        // Count the successes among `N` independent draws.
        let hits = (0..N).filter(|_| dist.sample(&mut rng)).count();
        let mean = (hits as f64) / (N as f64);

        assert!((mean - P).abs() < 1e-3);
    }
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,13 +31,17 @@ use std::f64::consts::PI; | |
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Binomial { | ||
n: u64, // number of trials | ||
p: f64, // probability of success | ||
/// Number of trials. | ||
n: u64, | ||
/// Probability of success. | ||
p: f64, | ||
} | ||
|
||
impl Binomial { | ||
/// Construct a new `Binomial` with the given shape parameters | ||
/// `n`, `p`. Panics if `p <= 0` or `p >= 1`. | ||
/// Construct a new `Binomial` with the given shape parameters `n` (number | ||
/// of trials) and `p` (probability of success). | ||
/// | ||
/// Panics if `p <= 0` or `p >= 1`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The p==0 and p==1 cases are trivial — should we support them? This might require extra branching however. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe, but this is orthogonal to this PR. |
||
pub fn new(n: u64, p: f64) -> Binomial { | ||
assert!(p > 0.0, "Binomial::new called with p <= 0"); | ||
assert!(p < 1.0, "Binomial::new called with p >= 1"); | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
#![no_std] | ||
|
||
extern crate rand; | ||
|
||
use rand::SeedableRng; | ||
use rand::rngs::SmallRng; | ||
use rand::distributions::{Distribution, Bernoulli}; | ||
|
||
/// Guards against accidentally introducing undefined behavior for large
/// probabilities due to
/// https://github.com/rust-lang/rust/issues/10184.
/// Expressions like `1.0*(u64::MAX as f64) as u64` have to be avoided.
#[test]
fn large_probability() {
    // A probability just below one.
    let p = 1.0 - ::core::f64::EPSILON / 2.0;
    assert!(p < 1.0);

    let seed = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16];
    let mut rng = SmallRng::from_seed(seed);
    let d = Bernoulli::new(p);

    for _ in 0..10 {
        assert!(d.sample(&mut rng), "extremely unlikely to fail by accident");
    }
}
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
10 mil' makes this a little slow. I guess leave it as-is for now though; maybe as part of #357 we can divide tests into two sets (fast and slow) or something.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately you need a lot of samples to get decent statistics. (The relative error probably scales with `1/sqrt(n_samples)`.)