Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic distributions: use custom trait #793

Merged
merged 9 commits into from
May 15, 2019
1 change: 0 additions & 1 deletion rand_distr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,3 @@ appveyor = { repository = "rust-random/rand" }

[dependencies]
rand = { path = "..", version = ">=0.5, <=0.7" }
num-traits = "0.2"
30 changes: 17 additions & 13 deletions rand_distr/src/cauchy.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//! The Cauchy distribution.

use rand::Rng;
use crate::Distribution;
use std::f64::consts::PI;
use crate::{Distribution, Standard};
use crate::utils::Float;

/// The Cauchy distribution `Cauchy(median, scale)`.
///
Expand All @@ -28,9 +28,9 @@ use std::f64::consts::PI;
/// println!("{} is from a Cauchy(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct Cauchy {
median: f64,
scale: f64
pub struct Cauchy<N> {
median: N,
scale: N,
}

/// Error type returned from `Cauchy::new`.
Expand All @@ -40,11 +40,13 @@ pub enum Error {
ScaleTooSmall,
}

impl Cauchy {
impl<N: Float> Cauchy<N>
where Standard: Distribution<N>
{
/// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor.
pub fn new(median: f64, scale: f64) -> Result<Cauchy, Error> {
if !(scale > 0.0) {
pub fn new(median: N, scale: N) -> Result<Cauchy<N>, Error> {
if !(scale > N::from(0.0)) {
return Err(Error::ScaleTooSmall);
}
Ok(Cauchy {
Expand All @@ -54,13 +56,15 @@ impl Cauchy {
}
}

impl Distribution<f64> for Cauchy {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> f64 {
impl<N: Float> Distribution<N> for Cauchy<N>
where Standard: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> N {
// sample from [0, 1)
let x = rng.gen::<f64>();
let x = Standard.sample(rng);
// get standard cauchy random number
// note that π/2 is not exactly representable, even if x=0.5 the result is finite
let comp_dev = (PI * x).tan();
let comp_dev = (N::pi() * x).tan();
// shift and scale according to parameters
let result = self.median + self.scale * comp_dev;
result
Expand Down Expand Up @@ -99,7 +103,7 @@ mod test {
fn test_cauchy_mean() {
let cauchy = Cauchy::new(10.0, 5.0).unwrap();
let mut rng = crate::test::rng(123);
let mut sum = 0.0;
let mut sum = 0.0f64;
for _ in 0..1000 {
sum += cauchy.sample(&mut rng);
}
Expand Down
34 changes: 19 additions & 15 deletions rand_distr/src/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
//! The dirichlet distribution.

use rand::Rng;
use crate::Distribution;
use crate::gamma::Gamma;
use crate::{Distribution, Gamma, StandardNormal, Exp1, Open01};
use crate::utils::Float;

/// The dirichelet distribution `Dirichlet(alpha)`.
///
Expand All @@ -30,9 +30,9 @@ use crate::gamma::Gamma;
/// println!("{:?} is from a Dirichlet([1.0, 2.0, 3.0]) distribution", samples);
/// ```
#[derive(Clone, Debug)]
pub struct Dirichlet {
pub struct Dirichlet<N> {
/// Concentration parameters (alpha)
alpha: Vec<f64>,
alpha: Vec<N>,
}

/// Error type returned from `Dirchlet::new`.
Expand All @@ -46,18 +46,20 @@ pub enum Error {
SizeTooSmall,
}

impl Dirichlet {
impl<N: Float> Dirichlet<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
/// Construct a new `Dirichlet` with the given alpha parameter `alpha`.
///
/// Requires `alpha.len() >= 2`.
#[inline]
pub fn new<V: Into<Vec<f64>>>(alpha: V) -> Result<Dirichlet, Error> {
pub fn new<V: Into<Vec<N>>>(alpha: V) -> Result<Dirichlet<N>, Error> {
let a = alpha.into();
if a.len() < 2 {
return Err(Error::AlphaTooShort);
}
for i in 0..a.len() {
if !(a[i] > 0.0) {
if !(a[i] > N::from(0.0)) {
return Err(Error::AlphaTooSmall);
}
}
Expand All @@ -69,8 +71,8 @@ impl Dirichlet {
///
/// Requires `size >= 2`.
#[inline]
pub fn new_with_size(alpha: f64, size: usize) -> Result<Dirichlet, Error> {
if !(alpha > 0.0) {
pub fn new_with_size(alpha: N, size: usize) -> Result<Dirichlet<N>, Error> {
if !(alpha > N::from(0.0)) {
return Err(Error::AlphaTooSmall);
}
if size < 2 {
Expand All @@ -82,18 +84,20 @@ impl Dirichlet {
}
}

impl Distribution<Vec<f64>> for Dirichlet {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<f64> {
impl<N: Float> Distribution<Vec<N>> for Dirichlet<N>
where StandardNormal: Distribution<N>, Exp1: Distribution<N>, Open01: Distribution<N>
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<N> {
let n = self.alpha.len();
let mut samples = vec![0.0f64; n];
let mut sum = 0.0f64;
let mut samples = vec![N::from(0.0); n];
let mut sum = N::from(0.0);

for i in 0..n {
let g = Gamma::new(self.alpha[i], 1.0).unwrap();
let g = Gamma::new(self.alpha[i], N::from(1.0)).unwrap();
samples[i] = g.sample(rng);
sum += samples[i];
}
let invacc = 1.0 / sum;
let invacc = N::from(1.0) / sum;
for i in 0..n {
samples[i] *= invacc;
}
Expand Down
7 changes: 3 additions & 4 deletions rand_distr/src/exponential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@

use rand::Rng;
use crate::{ziggurat_tables, Distribution};
use crate::utils::ziggurat;
use num_traits::Float;
use crate::utils::{ziggurat, Float};

/// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
Expand Down Expand Up @@ -105,10 +104,10 @@ where Exp1: Distribution<N>
/// `lambda`.
#[inline]
pub fn new(lambda: N) -> Result<Exp<N>, Error> {
if !(lambda > N::zero()) {
if !(lambda > N::from(0.0)) {
return Err(Error::LambdaTooSmall);
}
Ok(Exp { lambda_inverse: N::one() / lambda })
Ok(Exp { lambda_inverse: N::from(1.0) / lambda })
}
}

Expand Down
Loading