-
-
Notifications
You must be signed in to change notification settings - Fork 436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Poisson: Fix undefined behavior and support f64 output #795
Changes from all commits
31aad14
c03e2c8
638b6be
eec9bba
90833ca
2553cb5
15b9a39
84d89a7
ec99801
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,8 +10,8 @@ | |
//! The Poisson distribution. | ||
|
||
use rand::Rng; | ||
use crate::{Distribution, Cauchy}; | ||
use crate::utils::log_gamma; | ||
use crate::{Distribution, Cauchy, Standard}; | ||
use crate::utils::Float; | ||
|
||
/// The Poisson distribution `Poisson(lambda)`. | ||
/// | ||
|
@@ -24,17 +24,17 @@ use crate::utils::log_gamma; | |
/// use rand_distr::{Poisson, Distribution}; | ||
/// | ||
/// let poi = Poisson::new(2.0).unwrap(); | ||
/// let v = poi.sample(&mut rand::thread_rng()); | ||
/// let v: u64 = poi.sample(&mut rand::thread_rng()); | ||
/// println!("{} is from a Poisson(2) distribution", v); | ||
/// ``` | ||
#[derive(Clone, Copy, Debug)] | ||
pub struct Poisson { | ||
lambda: f64, | ||
pub struct Poisson<N> { | ||
lambda: N, | ||
// precalculated values | ||
exp_lambda: f64, | ||
log_lambda: f64, | ||
sqrt_2lambda: f64, | ||
magic_val: f64, | ||
exp_lambda: N, | ||
log_lambda: N, | ||
sqrt_2lambda: N, | ||
magic_val: N, | ||
} | ||
|
||
/// Error type returned from `Poisson::new`. | ||
|
@@ -44,48 +44,51 @@ pub enum Error { | |
ShapeTooSmall, | ||
} | ||
|
||
impl Poisson { | ||
impl<N: Float> Poisson<N> | ||
where Standard: Distribution<N> | ||
{ | ||
/// Construct a new `Poisson` with the given shape parameter | ||
/// `lambda`. | ||
pub fn new(lambda: f64) -> Result<Poisson, Error> { | ||
if !(lambda > 0.0) { | ||
pub fn new(lambda: N) -> Result<Poisson<N>, Error> { | ||
if !(lambda > N::from(0.0)) { | ||
return Err(Error::ShapeTooSmall); | ||
} | ||
let log_lambda = lambda.ln(); | ||
Ok(Poisson { | ||
lambda, | ||
exp_lambda: (-lambda).exp(), | ||
log_lambda, | ||
sqrt_2lambda: (2.0 * lambda).sqrt(), | ||
magic_val: lambda * log_lambda - log_gamma(1.0 + lambda), | ||
sqrt_2lambda: (N::from(2.0) * lambda).sqrt(), | ||
magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(), | ||
}) | ||
} | ||
} | ||
|
||
impl Distribution<u64> for Poisson { | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { | ||
impl<N: Float> Distribution<N> for Poisson<N> | ||
where Standard: Distribution<N> | ||
{ | ||
#[inline] | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N { | ||
// using the algorithm from Numerical Recipes in C | ||
|
||
// for low expected values use the Knuth method | ||
if self.lambda < 12.0 { | ||
let mut result = 0; | ||
let mut p = 1.0; | ||
if self.lambda < N::from(12.0) { | ||
let mut result = N::from(0.); | ||
let mut p = N::from(1.0); | ||
while p > self.exp_lambda { | ||
p *= rng.gen::<f64>(); | ||
result += 1; | ||
p *= rng.gen::<N>(); | ||
result += N::from(1.); | ||
} | ||
result - 1 | ||
result - N::from(1.) | ||
} | ||
// high expected values - rejection method | ||
else { | ||
let mut int_result: u64; | ||
|
||
// we use the Cauchy distribution as the comparison distribution | ||
// f(x) ~ 1/(1+x^2) | ||
let cauchy = Cauchy::new(0.0, 1.0).unwrap(); | ||
let cauchy = Cauchy::new(N::from(0.0), N::from(1.0)).unwrap(); | ||
let mut result; | ||
|
||
loop { | ||
let mut result; | ||
let mut comp_dev; | ||
|
||
loop { | ||
|
@@ -94,32 +97,41 @@ impl Distribution<u64> for Poisson { | |
// shift the peak of the comparison ditribution | ||
result = self.sqrt_2lambda * comp_dev + self.lambda; | ||
// repeat the drawing until we are in the range of possible values | ||
if result >= 0.0 { | ||
if result >= N::from(0.0) { | ||
break; | ||
} | ||
} | ||
// now the result is a random variable greater than 0 with Cauchy distribution | ||
// the result should be an integer value | ||
result = result.floor(); | ||
int_result = result as u64; | ||
|
||
// this is the ratio of the Poisson distribution to the comparison distribution | ||
// the magic value scales the distribution function to a range of approximately 0-1 | ||
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1 | ||
// this doesn't change the resulting distribution, only increases the rate of failed drawings | ||
let check = 0.9 * (1.0 + comp_dev * comp_dev) | ||
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp(); | ||
let check = N::from(0.9) * (N::from(1.0) + comp_dev * comp_dev) | ||
* (result * self.log_lambda - (N::from(1.0) + result).log_gamma() - self.magic_val).exp(); | ||
|
||
// check with uniform random value - if below the threshold, we are within the target distribution | ||
if rng.gen::<f64>() <= check { | ||
if rng.gen::<N>() <= check { | ||
break; | ||
} | ||
} | ||
int_result | ||
result | ||
} | ||
} | ||
} | ||
|
||
impl<N: Float> Distribution<u64> for Poisson<N> | ||
where Standard: Distribution<N> | ||
{ | ||
#[inline] | ||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { | ||
vks marked this conversation as resolved.
Show resolved
Hide resolved
|
||
let result: N = self.sample(rng); | ||
result.to_u64().unwrap() | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use crate::Distribution; | ||
|
@@ -129,27 +141,82 @@ mod test { | |
fn test_poisson_10() { | ||
let poisson = Poisson::new(10.0).unwrap(); | ||
let mut rng = crate::test::rng(123); | ||
let mut sum = 0; | ||
let mut sum_u64 = 0; | ||
let mut sum_f64 = 0.; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What is the goal of this test? We know both values should be the same since (a) both should be non-negative integers and (b) we should not be going anywhere close to the limits/accuracy of either type. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The goal is to test that the code path works as expected, not the precision. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's the same code path, aside from the extra cast. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would like to trigger the path with the cast, making sure it works. I think it is important because of the potential undefined behavior. |
||
for _ in 0..1000 { | ||
sum += poisson.sample(&mut rng); | ||
let s_u64: u64 = poisson.sample(&mut rng); | ||
let s_f64: f64 = poisson.sample(&mut rng); | ||
sum_u64 += s_u64; | ||
sum_f64 += s_f64; | ||
} | ||
let avg_u64 = (sum_u64 as f64) / 1000.0; | ||
let avg_f64 = sum_f64 / 1000.0; | ||
println!("Poisson averages: {} (u64) {} (f64)", avg_u64, avg_f64); | ||
for &avg in &[avg_u64, avg_f64] { | ||
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Poisson average: {}", avg); | ||
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
fn test_poisson_15() { | ||
// Take the 'high expected values' path | ||
let poisson = Poisson::new(15.0).unwrap(); | ||
let mut rng = crate::test::rng(123); | ||
let mut sum = 0; | ||
let mut sum_u64 = 0; | ||
let mut sum_f64 = 0.; | ||
for _ in 0..1000 { | ||
sum += poisson.sample(&mut rng); | ||
let s_u64: u64 = poisson.sample(&mut rng); | ||
let s_f64: f64 = poisson.sample(&mut rng); | ||
sum_u64 += s_u64; | ||
sum_f64 += s_f64; | ||
} | ||
let avg_u64 = (sum_u64 as f64) / 1000.0; | ||
let avg_f64 = sum_f64 / 1000.0; | ||
println!("Poisson average: {} (u64) {} (f64)", avg_u64, avg_f64); | ||
for &avg in &[avg_u64, avg_f64] { | ||
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_poisson_10_f32() { | ||
let poisson = Poisson::new(10.0f32).unwrap(); | ||
let mut rng = crate::test::rng(123); | ||
let mut sum_u64 = 0; | ||
let mut sum_f32 = 0.; | ||
for _ in 0..1000 { | ||
let s_u64: u64 = poisson.sample(&mut rng); | ||
let s_f32: f32 = poisson.sample(&mut rng); | ||
sum_u64 += s_u64; | ||
sum_f32 += s_f32; | ||
} | ||
let avg_u64 = (sum_u64 as f32) / 1000.0; | ||
let avg_f32 = sum_f32 / 1000.0; | ||
println!("Poisson averages: {} (u64) {} (f32)", avg_u64, avg_f32); | ||
for &avg in &[avg_u64, avg_f32] { | ||
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
} | ||
|
||
#[test] | ||
fn test_poisson_15_f32() { | ||
// Take the 'high expected values' path | ||
let poisson = Poisson::new(15.0f32).unwrap(); | ||
let mut rng = crate::test::rng(123); | ||
let mut sum_u64 = 0; | ||
let mut sum_f32 = 0.; | ||
for _ in 0..1000 { | ||
let s_u64: u64 = poisson.sample(&mut rng); | ||
let s_f32: f32 = poisson.sample(&mut rng); | ||
sum_u64 += s_u64; | ||
sum_f32 += s_f32; | ||
} | ||
let avg_u64 = (sum_u64 as f32) / 1000.0; | ||
let avg_f32 = sum_f32 / 1000.0; | ||
println!("Poisson average: {} (u64) {} (f32)", avg_u64, avg_f32); | ||
for &avg in &[avg_u64, avg_f32] { | ||
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
let avg = (sum as f64) / 1000.0; | ||
println!("Poisson average: {}", avg); | ||
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough | ||
} | ||
|
||
#[test] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a large function — I don't think
#[inline]
makes any sense here.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not
#[inline(always)]
, so the compiler is still free to make that choice. It might make sense to inline it into theu64
sampling, so that some bound checks can be eliminated.