Skip to content

Commit

Permalink
Merge pull request #795 from vks/fp-poisson
Browse files Browse the repository at this point in the history
Poisson: Fix undefined behavior and support f64 output
  • Loading branch information
dhardy authored May 16, 2019
2 parents 4ef40b6 + ec99801 commit c0b8722
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 43 deletions.
151 changes: 109 additions & 42 deletions rand_distr/src/poisson.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//! The Poisson distribution.
use rand::Rng;
use crate::{Distribution, Cauchy};
use crate::utils::log_gamma;
use crate::{Distribution, Cauchy, Standard};
use crate::utils::Float;

/// The Poisson distribution `Poisson(lambda)`.
///
Expand All @@ -24,17 +24,17 @@ use crate::utils::log_gamma;
/// use rand_distr::{Poisson, Distribution};
///
/// let poi = Poisson::new(2.0).unwrap();
/// let v = poi.sample(&mut rand::thread_rng());
/// let v: u64 = poi.sample(&mut rand::thread_rng());
/// println!("{} is from a Poisson(2) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Poisson {
lambda: f64,
pub struct Poisson<N> {
lambda: N,
// precalculated values
exp_lambda: f64,
log_lambda: f64,
sqrt_2lambda: f64,
magic_val: f64,
exp_lambda: N,
log_lambda: N,
sqrt_2lambda: N,
magic_val: N,
}

/// Error type returned from `Poisson::new`.
Expand All @@ -44,48 +44,51 @@ pub enum Error {
ShapeTooSmall,
}

impl Poisson {
impl<N: Float> Poisson<N>
where Standard: Distribution<N>
{
/// Construct a new `Poisson` with the given shape parameter
/// `lambda`.
pub fn new(lambda: f64) -> Result<Poisson, Error> {
if !(lambda > 0.0) {
pub fn new(lambda: N) -> Result<Poisson<N>, Error> {
if !(lambda > N::from(0.0)) {
return Err(Error::ShapeTooSmall);
}
let log_lambda = lambda.ln();
Ok(Poisson {
lambda,
exp_lambda: (-lambda).exp(),
log_lambda,
sqrt_2lambda: (2.0 * lambda).sqrt(),
magic_val: lambda * log_lambda - log_gamma(1.0 + lambda),
sqrt_2lambda: (N::from(2.0) * lambda).sqrt(),
magic_val: lambda * log_lambda - (N::from(1.0) + lambda).log_gamma(),
})
}
}

impl Distribution<u64> for Poisson {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
impl<N: Float> Distribution<N> for Poisson<N>
where Standard: Distribution<N>
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// using the algorithm from Numerical Recipes in C

// for low expected values use the Knuth method
if self.lambda < 12.0 {
let mut result = 0;
let mut p = 1.0;
if self.lambda < N::from(12.0) {
let mut result = N::from(0.);
let mut p = N::from(1.0);
while p > self.exp_lambda {
p *= rng.gen::<f64>();
result += 1;
p *= rng.gen::<N>();
result += N::from(1.);
}
result - 1
result - N::from(1.)
}
// high expected values - rejection method
else {
let mut int_result: u64;

// we use the Cauchy distribution as the comparison distribution
// f(x) ~ 1/(1+x^2)
let cauchy = Cauchy::new(0.0, 1.0).unwrap();
let cauchy = Cauchy::new(N::from(0.0), N::from(1.0)).unwrap();
let mut result;

loop {
let mut result;
let mut comp_dev;

loop {
Expand All @@ -94,32 +97,41 @@ impl Distribution<u64> for Poisson {
// shift the peak of the comparison ditribution
result = self.sqrt_2lambda * comp_dev + self.lambda;
// repeat the drawing until we are in the range of possible values
if result >= 0.0 {
if result >= N::from(0.0) {
break;
}
}
// now the result is a random variable greater than 0 with Cauchy distribution
// the result should be an integer value
result = result.floor();
int_result = result as u64;

// this is the ratio of the Poisson distribution to the comparison distribution
// the magic value scales the distribution function to a range of approximately 0-1
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
// this doesn't change the resulting distribution, only increases the rate of failed drawings
let check = 0.9 * (1.0 + comp_dev * comp_dev)
* (result * self.log_lambda - log_gamma(1.0 + result) - self.magic_val).exp();
let check = N::from(0.9) * (N::from(1.0) + comp_dev * comp_dev)
* (result * self.log_lambda - (N::from(1.0) + result).log_gamma() - self.magic_val).exp();

// check with uniform random value - if below the threshold, we are within the target distribution
if rng.gen::<f64>() <= check {
if rng.gen::<N>() <= check {
break;
}
}
int_result
result
}
}
}

impl<N: Float> Distribution<u64> for Poisson<N>
where Standard: Distribution<N>
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
let result: N = self.sample(rng);
result.to_u64().unwrap()
}
}

#[cfg(test)]
mod test {
use crate::Distribution;
Expand All @@ -129,27 +141,82 @@ mod test {
fn test_poisson_10() {
let poisson = Poisson::new(10.0).unwrap();
let mut rng = crate::test::rng(123);
let mut sum = 0;
let mut sum_u64 = 0;
let mut sum_f64 = 0.;
for _ in 0..1000 {
sum += poisson.sample(&mut rng);
let s_u64: u64 = poisson.sample(&mut rng);
let s_f64: f64 = poisson.sample(&mut rng);
sum_u64 += s_u64;
sum_f64 += s_f64;
}
let avg_u64 = (sum_u64 as f64) / 1000.0;
let avg_f64 = sum_f64 / 1000.0;
println!("Poisson averages: {} (u64) {} (f64)", avg_u64, avg_f64);
for &avg in &[avg_u64, avg_f64] {
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
}
let avg = (sum as f64) / 1000.0;
println!("Poisson average: {}", avg);
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
}

#[test]
fn test_poisson_15() {
// Take the 'high expected values' path
let poisson = Poisson::new(15.0).unwrap();
let mut rng = crate::test::rng(123);
let mut sum = 0;
let mut sum_u64 = 0;
let mut sum_f64 = 0.;
for _ in 0..1000 {
sum += poisson.sample(&mut rng);
let s_u64: u64 = poisson.sample(&mut rng);
let s_f64: f64 = poisson.sample(&mut rng);
sum_u64 += s_u64;
sum_f64 += s_f64;
}
let avg_u64 = (sum_u64 as f64) / 1000.0;
let avg_f64 = sum_f64 / 1000.0;
println!("Poisson average: {} (u64) {} (f64)", avg_u64, avg_f64);
for &avg in &[avg_u64, avg_f64] {
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
}
}

#[test]
fn test_poisson_10_f32() {
let poisson = Poisson::new(10.0f32).unwrap();
let mut rng = crate::test::rng(123);
let mut sum_u64 = 0;
let mut sum_f32 = 0.;
for _ in 0..1000 {
let s_u64: u64 = poisson.sample(&mut rng);
let s_f32: f32 = poisson.sample(&mut rng);
sum_u64 += s_u64;
sum_f32 += s_f32;
}
let avg_u64 = (sum_u64 as f32) / 1000.0;
let avg_f32 = sum_f32 / 1000.0;
println!("Poisson averages: {} (u64) {} (f32)", avg_u64, avg_f32);
for &avg in &[avg_u64, avg_f32] {
assert!((avg - 10.0).abs() < 0.5); // not 100% certain, but probable enough
}
}

#[test]
fn test_poisson_15_f32() {
// Take the 'high expected values' path
let poisson = Poisson::new(15.0f32).unwrap();
let mut rng = crate::test::rng(123);
let mut sum_u64 = 0;
let mut sum_f32 = 0.;
for _ in 0..1000 {
let s_u64: u64 = poisson.sample(&mut rng);
let s_f32: f32 = poisson.sample(&mut rng);
sum_u64 += s_u64;
sum_f32 += s_f32;
}
let avg_u64 = (sum_u64 as f32) / 1000.0;
let avg_f32 = sum_f32 / 1000.0;
println!("Poisson average: {} (u64) {} (f32)", avg_u64, avg_f32);
for &avg in &[avg_u64, avg_f32] {
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
}
let avg = (sum as f64) / 1000.0;
println!("Poisson average: {}", avg);
assert!((avg - 15.0).abs() < 0.5); // not 100% certain, but probable enough
}

#[test]
Expand Down
53 changes: 52 additions & 1 deletion rand_distr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,13 @@ pub trait Float: Copy + Sized + cmp::PartialOrd
fn pi() -> Self;
/// Support approximate representation of a f64 value
fn from(x: f64) -> Self;
/// Support converting to an unsigned integer.
fn to_u64(self) -> Option<u64>;

/// Take the absolute value of self
fn abs(self) -> Self;
/// Take the largest integer less than or equal to self
fn floor(self) -> Self;

/// Take the exponential of self
fn exp(self) -> Self;
Expand All @@ -48,34 +52,81 @@ pub trait Float: Copy + Sized + cmp::PartialOrd

/// Take the tangent of self
fn tan(self) -> Self;
/// Take the logarithm of the gamma function of self
fn log_gamma(self) -> Self;
}

impl Float for f32 {
#[inline]
fn pi() -> Self { core::f32::consts::PI }
#[inline]
fn from(x: f64) -> Self { x as f32 }
#[inline]
fn to_u64(self) -> Option<u64> {
if self >= 0. && self <= ::core::u64::MAX as f32 {
Some(self as u64)
} else {
None
}
}

#[inline]
fn abs(self) -> Self { self.abs() }
#[inline]
fn floor(self) -> Self { self.floor() }

#[inline]
fn exp(self) -> Self { self.exp() }
#[inline]
fn ln(self) -> Self { self.ln() }
#[inline]
fn sqrt(self) -> Self { self.sqrt() }
#[inline]
fn powf(self, power: Self) -> Self { self.powf(power) }

#[inline]
fn tan(self) -> Self { self.tan() }
#[inline]
fn log_gamma(self) -> Self {
let result = log_gamma(self as f64);
assert!(result <= ::core::f32::MAX as f64);
assert!(result >= ::core::f32::MIN as f64);
result as f32
}
}

impl Float for f64 {
#[inline]
fn pi() -> Self { core::f64::consts::PI }
#[inline]
fn from(x: f64) -> Self { x }
#[inline]
fn to_u64(self) -> Option<u64> {
if self >= 0. && self <= ::core::u64::MAX as f64 {
Some(self as u64)
} else {
None
}
}

#[inline]
fn abs(self) -> Self { self.abs() }
#[inline]
fn floor(self) -> Self { self.floor() }

#[inline]
fn exp(self) -> Self { self.exp() }
#[inline]
fn ln(self) -> Self { self.ln() }
#[inline]
fn sqrt(self) -> Self { self.sqrt() }
#[inline]
fn powf(self, power: Self) -> Self { self.powf(power) }

#[inline]
fn tan(self) -> Self { self.tan() }
#[inline]
fn log_gamma(self) -> Self { log_gamma(self) }
}

/// Calculates ln(gamma(x)) (natural logarithm of the gamma
Expand Down Expand Up @@ -109,7 +160,7 @@ pub(crate) fn log_gamma(x: f64) -> f64 {
// the first few terms of the series for Ag(x)
let mut a = 1.000000000190015;
let mut denom = x;
for coeff in &coefficients {
for &coeff in &coefficients {
denom += 1.0;
a += coeff / denom;
}
Expand Down

0 comments on commit c0b8722

Please sign in to comment.