From f089a537dcbda449bacc78c6b1578f9aa5979024 Mon Sep 17 00:00:00 2001 From: Jonas Sicking Date: Sun, 24 Jun 2018 17:59:14 -0700 Subject: [PATCH] Implement HighPrecision for f32 and f64 --- benches/distributions.rs | 3 + src/distributions/float.rs | 515 ++++++++++++++++++++----------------- 2 files changed, 284 insertions(+), 234 deletions(-) diff --git a/benches/distributions.rs b/benches/distributions.rs index 5945e75a3e..a2933c0869 100644 --- a/benches/distributions.rs +++ b/benches/distributions.rs @@ -84,6 +84,9 @@ distr_int!(distr_uniform_i128, i128, Uniform::new(-123_456_789_123i128, 123_456_ distr_float!(distr_uniform_f32, f32, Uniform::new(2.26f32, 2.319)); distr_float!(distr_uniform_f64, f64, Uniform::new(2.26f64, 2.319)); +distr_float!(distr_highprecision1_f32, f32, HighPrecision::new(2.26f32, 2.319)); +distr_float!(distr_highprecision2_f32, f32, HighPrecision::new(-1.0f32 / 3.0, 2.319)); +distr_float!(distr_highprecision3_f32, f32, HighPrecision::new(0.001f32, 123_456_789_012_345.987)); distr_float!(distr_highprecision1_f64, f64, HighPrecision::new(2.26f64, 2.319)); distr_float!(distr_highprecision2_f64, f64, HighPrecision::new(-1.0f64 / 3.0, 2.319)); distr_float!(distr_highprecision3_f64, f64, HighPrecision::new(0.001f64, 123_456_789_012_345.987)); diff --git a/src/distributions/float.rs b/src/distributions/float.rs index 3576df17e7..522c9f3270 100644 --- a/src/distributions/float.rs +++ b/src/distributions/float.rs @@ -68,6 +68,31 @@ pub struct OpenClosed01; #[derive(Clone, Copy, Debug)] pub struct Open01; +/// A distribution to do high precision sampling of floating point numbers +/// uniformly in a given range. This is similar to Uniform, but samples +/// numbers with the full precision of the floating-point type used, at the +/// cost of slower performance. +#[derive(Clone, Copy, Debug)] +pub struct HighPrecision where F: HPFloatHelper { + low_as_int: F::SignedInt, + high_as_int: F::SignedInt, + exponent: u16, + distribution: F::SignedIntDistribution, +} + +impl HighPrecision { + /// Create a new HighPrecision distribution. Sampling from this + /// distribution will return values `>= low` and `< high`. + pub fn new(low: F, high: F) -> Self { + let parsed = F::parse_new(low, high); + HighPrecision { + low_as_int: parsed.0, + high_as_int: parsed.1, + exponent: parsed.2, + distribution: parsed.3, + } + } +} pub(crate) trait IntoFloat { type F; @@ -83,8 +108,15 @@ pub(crate) trait IntoFloat { fn into_float_with_exponent(self, exponent: i32) -> Self::F; } +pub trait HPFloatHelper: Sized { + type SignedInt; + type SignedIntDistribution; + fn parse_new(low: Self, high: Self) -> + (Self::SignedInt, Self::SignedInt, u16, Self::SignedIntDistribution); +} + macro_rules! float_impls { - ($ty:ty, $uty:ty, $fraction_bits:expr, $exponent_bias:expr) => { + ($ty:ty, $uty:ty, $ity:ty, $fraction_bits:expr, $exponent_bits:expr, $exponent_bias:expr) => { impl IntoFloat for $uty { type F = $ty; #[inline(always)] @@ -139,175 +171,166 @@ macro_rules! float_impls { fraction.into_float_with_exponent(0) - (1.0 - EPSILON / 2.0) } } - } -} -float_impls! { f32, u32, 23, 127 } -float_impls! { f64, u64, 52, 1023 } -/// A distribution to do high precision sampling of floating point numbers -/// uniformly in a given range. This is similar to Uniform, but samples -/// numbers with the full precision of the floating-point type used, at the -/// cost of slower performance. 
-#[derive(Clone, Copy, Debug)] -pub struct HighPrecision { - low_as_int: i64, - high_as_int: i64, - exponent: u16, - distribution: ::Sampler, -} + impl HPFloatHelper for $ty { + type SignedInt = $ity; + type SignedIntDistribution = <$ity as SampleUniform>::Sampler; -impl HighPrecision { - /// Returns a new HighPrecision distribution in the `[low, high)` range. - pub fn new(low: f64, high: f64) -> Self { - fn bitmask(bits: u64) -> u64 { - if bits >= 64 { -1i64 as u64 } else { (1u64 << bits) - 1u64 } - } - fn round_neg_inf_shr(val: i64, n: u16) -> i64 { - if n < 64 { - val >> n - } else if val >= 0 { - 0 - } else { - -1 - } - } - fn round_pos_inf_shr(val: i64, n: u16) -> i64 { - -round_neg_inf_shr(-val, n) - } - fn parse(val: f64) -> (i64, u16, i64) { - let bits = val.to_bits(); - let mut mant = (bits & bitmask(52)) as i64; - let mut exp = ((bits >> 52) & bitmask(11)) as u16; - let neg = (bits >> 63) == 1; - let mut as_int = (bits & bitmask(63)) as i64; - if exp != 0 { - mant |= 1i64 << 52; - } else { - exp = 1; - } - if neg { - mant *= -1; - as_int *= -1; - } - (mant, exp, as_int) - } + fn parse_new(low: $ty, high: $ty) -> + ($ity, $ity, u16, <$ity as SampleUniform>::Sampler) { + fn bitmask(bits: $uty) -> $uty { + if bits >= ::core::mem::size_of::<$uty>() as $uty * 8 { (-1 as $ity) as $uty } else { (1 as $uty << bits) - 1 } + } + fn round_neg_inf_shr(val: $ity, n: u16) -> $ity { + if n < ::core::mem::size_of::<$ity>() as u16 * 8 { + val >> n + } else if val >= 0 { + 0 + } else { + -1 + } + } + fn round_pos_inf_shr(val: $ity, n: u16) -> $ity { + -round_neg_inf_shr(-val, n) + } + fn parse(val: $ty) -> ($ity, u16, $ity) { + let bits = val.to_bits(); + let mut mant = (bits & bitmask($fraction_bits)) as $ity; + let mut exp = ((bits >> $fraction_bits) & bitmask($exponent_bits)) as u16; + let neg = (bits >> ($fraction_bits + $exponent_bits)) == 1; + let mut as_int = (bits & bitmask($fraction_bits + $exponent_bits)) as $ity; + if exp != 0 { + mant |= 1 as $ity << $fraction_bits; + } else { + exp = 1; + } + if neg { + mant *= -1; + as_int *= -1; + } + (mant, exp, as_int) + } - assert!(low.is_finite() && high.is_finite(), "HighPrecision::new called with non-finite limit"); - assert!(low < high, "HighPrecision::new called with low >= high"); + assert!(low.is_finite() && high.is_finite(), "HighPrecision::new called with non-finite limit"); + assert!(low < high, "HighPrecision::new called with low >= high"); - let (mut mant_low, exp_low, low_as_int) = parse(low); - let (mut mant_high, exp_high, high_as_int) = parse(high); + let (mut mant_low, exp_low, low_as_int) = parse(low); + let (mut mant_high, exp_high, high_as_int) = parse(high); - let exp; - if exp_high >= exp_low { - exp = exp_high; - mant_low = round_neg_inf_shr(mant_low, exp_high - exp_low); - } else { - exp = exp_low; - mant_high = round_pos_inf_shr(mant_high, exp_low - exp_high); - } + let exp; + if exp_high >= exp_low { + exp = exp_high; + mant_low = round_neg_inf_shr(mant_low, exp_high - exp_low); + } else { + exp = exp_low; + mant_high = round_pos_inf_shr(mant_high, exp_low - exp_high); + } - HighPrecision { - low_as_int, - high_as_int, - exponent: exp, - distribution: ::Sampler::new(mant_low, mant_high), + (low_as_int, + high_as_int, + exp, + <$ity as SampleUniform>::Sampler::new(mant_low, mant_high)) + } } - } -} -impl Distribution for HighPrecision { - fn sample(&self, rng: &mut R) -> f64 { - fn bitmask(bits: u64) -> u64 { - if bits >= 64 { -1i64 as u64 } else { (1u64 << bits) - 1u64 } - } - loop { - let signed_mant = 
self.distribution.sample(rng); - // Operate on the absolute value so that we can count bit-sizes - // correctly - let is_neg = signed_mant < 0; - let mut mantissa = signed_mant.abs() as u64; - - // If the resulting mantissa has too few bits, arithmetically add additional - // bits from rng. When `mant` represents a negative number, this means - // subtracting the generated bits. - const GOAL_ZEROS: u16 = 64 - 53; - let mut exp = self.exponent; - let mut zeros = mantissa.leading_zeros() as u16; - while zeros > GOAL_ZEROS && exp > 1 { - let additional = ::core::cmp::min(zeros - GOAL_ZEROS, exp - 1); - let additional_bits = rng.gen::() >> (64 - additional); - mantissa <<= additional; - if !is_neg { - mantissa |= additional_bits; - } else { - mantissa -= additional_bits; + impl Distribution<$ty> for HighPrecision<$ty> { + fn sample(&self, rng: &mut R) -> $ty { + fn bitmask(bits: $uty) -> $uty { + if bits >= ::core::mem::size_of::<$uty>() as $uty * 8 { (-1 as $ity) as $uty } else { (1 as $uty << bits) - 1 } } - exp -= additional; - zeros = mantissa.leading_zeros() as u16; - } + loop { + let signed_mant = self.distribution.sample(rng); + // Operate on the absolute value so that we can count bit-sizes + // correctly + let is_neg = signed_mant < 0; + let mut mantissa = signed_mant.abs() as $uty; + + // If the resulting mantissa has too few bits, arithmetically add additional + // bits from rng. When `mant` represents a negative number, this means + // subtracting the generated bits. + const GOAL_ZEROS: u16 = $exponent_bits; // full_size - $fraction_bits - 1 = $exponent_bits + let mut exp = self.exponent; + let mut zeros = mantissa.leading_zeros() as u16; + while zeros > GOAL_ZEROS && exp > 1 { + let additional = ::core::cmp::min(zeros - GOAL_ZEROS, exp - 1); + let additional_bits = rng.gen::<$uty>() >> (::core::mem::size_of::<$uty>() as u16 * 8 - additional); + mantissa <<= additional; + if !is_neg { + mantissa |= additional_bits; + } else { + mantissa -= additional_bits; + } + exp -= additional; + zeros = mantissa.leading_zeros() as u16; + } - // At this point, if we generate and add more bits, we're just - // going to have to round them off. Since we round towards negative - // infinity, i.e. the opposite direction of the added bits, we'll - // just get back to exactly where we are now. - - // There is an edgecase if the mantissa is negative 0x0010_0000_0000_0000. - // While this number is already 53 bits, if we subtract more - // geneated bits from this number, we will lose one bit at the top - // and get fewer total bits. That means that we can fit an extra - // bit at the end, which if it's a zero will prevent rounding from - // getting us back to the original value. - if mantissa == (1u64 << 52) && is_neg && exp > 1 && rng.gen::() { - mantissa = bitmask(53); - exp -= 1; - } + // At this point, if we generate and add more bits, we're just + // going to have to round them off. Since we round towards negative + // infinity, i.e. the opposite direction of the added bits, we'll + // just get back to exactly where we are now. + + // There is an edgecase if the mantissa is negative 0x0010_0000_0000_0000. + // While this number is already 53 bits, if we subtract more + // geneated bits from this number, we will lose one bit at the top + // and get fewer total bits. That means that we can fit an extra + // bit at the end, which if it's a zero will prevent rounding from + // getting us back to the original value. 
+ if mantissa == (1 as $uty << $fraction_bits) && is_neg && exp > 1 && rng.gen::() { + mantissa = bitmask($fraction_bits + 1); + exp -= 1; + } - // Handle underflow values - if mantissa < (1u64 << 52) { - debug_assert_eq!(exp, 1); - exp = 0; - } + // Handle underflow values + if mantissa < (1 as $uty << $fraction_bits) { + debug_assert_eq!(exp, 1); + exp = 0; + } - // Merge exponent and mantissa into final result - let mut res = (mantissa & bitmask(52)) | - ((exp as u64) << 52); - let mut res_as_int = res as i64; - if is_neg { - res_as_int *= -1; - res |= 1u64 << 63; - } + // Merge exponent and mantissa into final result + let mut res = (mantissa & bitmask($fraction_bits)) | + ((exp as $uty) << $fraction_bits); + let mut res_as_int = res as $ity; + if is_neg { + res_as_int *= -1; + res |= 1 as $uty << ($fraction_bits + $exponent_bits); + } - // Check that we're within the requested bounds. These checks can - // only fail on the side that was shifted and rounded during - // initial parsing. - if self.low_as_int <= res_as_int && res_as_int < self.high_as_int { - return f64::from_bits(res); - } + // Check that we're within the requested bounds. These checks can + // only fail on the side that was shifted and rounded during + // initial parsing. + if self.low_as_int <= res_as_int && res_as_int < self.high_as_int { + return <$ty>::from_bits(res); + } - // Assert that we got here due to rounding - #[cfg(debug_assertions)] - { - let exp_low = (self.low_as_int.abs() as u64 >> 52) & bitmask(11); - let exp_high = (self.high_as_int.abs() as u64 >> 52) & bitmask(11); - let mant_low = self.low_as_int.abs() as u64 & bitmask(52); - let mant_high = self.high_as_int.abs() as u64 & bitmask(52); - if res_as_int < self.low_as_int { - assert!(exp_low < exp_high); - assert!(mant_low & bitmask(exp_high - exp_low) != 0); - } - if res_as_int >= self.high_as_int { - assert!(exp_high < exp_low); - assert!(mant_high & bitmask(exp_low - exp_high) != 0); + // Assert that we got here due to rounding + #[cfg(debug_assertions)] + { + let exp_low = (self.low_as_int.abs() as $uty >> $fraction_bits) & bitmask($exponent_bits); + let exp_high = (self.high_as_int.abs() as $uty >> $fraction_bits) & bitmask($exponent_bits); + let mant_low = self.low_as_int.abs() as $uty & bitmask($fraction_bits); + let mant_high = self.high_as_int.abs() as $uty & bitmask($fraction_bits); + if res_as_int < self.low_as_int { + assert!(exp_low < exp_high); + assert!(mant_low & bitmask(exp_high - exp_low) != 0); + } + if res_as_int >= self.high_as_int { + assert!(exp_high < exp_low); + assert!(mant_high & bitmask(exp_low - exp_high) != 0); + } + } + + // If not start over. We're avoiding reusing any of the previous + // computation in order to avoid introducing bias, and to keep + // things simple since this should be rare. } } - - // If not start over. We're avoiding reusing any of the previous - // computation in order to avoid introducing bias, and to keep - // things simple since this should be rare. } + + } } +float_impls! { f32, u32, i32, 23, 8, 127 } +float_impls! 
{ f64, u64, i64, 52, 11, 1023 } #[cfg(test)] mod tests { @@ -372,95 +395,119 @@ mod tests { #[cfg(feature = "alloc")] #[test] fn test_highprecision() { - fn to_signed_bits(val: f64) -> i64 { - if val >= 0.0 { - val.to_bits() as i64 - } else { - -((-val).to_bits() as i64) - } - } - fn from_signed_bits(val: i64) -> f64 { - if val >= 0 { - f64::from_bits(val as u64) - } else { - -f64::from_bits(-val as u64) - } - } let mut r = ::test::rng(601); - let mut vals: Vec = - [0i64, - 0x0000_0f00_0000_0000, - 0x0001_0000_0000_0000, - 0x0004_0000_0000_0000, - 0x0008_0000_0000_0000, - 0x0010_0000_0000_0000, - 0x0020_0000_0000_0000, - 0x0040_0000_0000_0000, - 0x0100_0000_0000_0000, - 0x00cd_ef12_3456_789a, - 0x0100_ffff_ffff_ffff, - 0x010f_ffff_ffff_ffff, - 0x0400_1234_5678_abcd, - 0x7fef_ffff_ffff_ffff, - ].iter().cloned() - .flat_map(|x| (-2i64..3i64).map(move |y| x + y)) - .map(|x| f64::from_bits(x as u64)) - .flat_map(|x| vec![x, -x].into_iter()) - .filter(|x| x.is_finite()) - .collect(); - vals.sort_by(|a, b| a.partial_cmp(b).unwrap()); - vals.dedup(); - - for a in vals.iter().cloned() { - for b in vals.iter().cloned().filter(|&b| b > a) { - let hp = HighPrecision::new(a, b); - let a_bits = to_signed_bits(a); - let b_bits = to_signed_bits(b); - - // If a and b are "close enough", we can verify the full distribution - if (b_bits.wrapping_sub(a_bits) as u64) < 100 { - let mut counts = Vec::::with_capacity((b_bits - a_bits) as usize); - counts.resize((b_bits - a_bits) as usize, 0); - for _ in 0..1000 { - let res = hp.sample(&mut r); - counts[(to_signed_bits(res) - a_bits) as usize] += 1; - } - for (count, i) in counts.iter().zip(0i64..) { - let expected = 1000.0f64 * - (from_signed_bits(a_bits + i + 1) - - from_signed_bits(a_bits + i)) / (b - a); - let err = (*count as f64 - expected) / expected; - assert!(err.abs() <= 0.2); - } - } else { - // Rough estimate that the distribution is correct - let step = if (b - a).is_finite() { - (b - a) / 10.0 - } else { - b / 10.0 - a / 10.0 - }; - assert!(step.is_finite()); - let mut counts = Vec::::with_capacity(10); - counts.resize(10, 0); - for _ in 0..3000 { - let res = hp.sample(&mut r); - assert!(a <= res && res < b); - let index = if (res - a).is_finite() { - (res - a) / step + macro_rules! float_test { + ($ty:ty, $uty:ty, $ity:ty, $test_vals:expr) => {{ + let mut vals: Vec<$ty> = + $test_vals.iter().cloned() + .flat_map(|x| (-2 as $ity..3).map(move |y| x + y)) + .map(|x| <$ty>::from_bits(x as $uty)) + .flat_map(|x| vec![x, -x].into_iter()) + .filter(|x| x.is_finite()) + .collect(); + vals.sort_by(|a, b| a.partial_cmp(b).unwrap()); + vals.dedup(); + + for a in vals.iter().cloned() { + for b in vals.iter().cloned().filter(|&b| b > a) { + fn to_signed_bits(val: $ty) -> $ity { + if val >= 0.0 { + val.to_bits() as $ity + } else { + -((-val).to_bits() as $ity) + } + } + fn from_signed_bits(val: $ity) -> $ty { + if val >= 0 { + <$ty>::from_bits(val as $uty) + } else { + -<$ty>::from_bits(-val as $uty) + } + } + + let hp = HighPrecision::new(a, b); + let a_bits = to_signed_bits(a); + let b_bits = to_signed_bits(b); + + // If a and b are "close enough", we can verify the full distribution + if (b_bits.wrapping_sub(a_bits) as $uty) < 100 { + let mut counts = Vec::::with_capacity((b_bits - a_bits) as usize); + counts.resize((b_bits - a_bits) as usize, 0); + for _ in 0..1000 { + let res = hp.sample(&mut r); + counts[(to_signed_bits(res) - a_bits) as usize] += 1; + } + for (count, i) in counts.iter().zip(0 as $ity..) 
{ + let expected = 1000.0 as $ty * + (from_signed_bits(a_bits + i + 1) - + from_signed_bits(a_bits + i)) / (b - a); + let err = (*count as $ty - expected) / expected; + assert!(err.abs() <= 0.2); + } } else { - res / step - a / step - } as usize; - counts[index] += 1; - } - for count in &counts { - let expected = 3000.0f64 / 10.0; - let err = (*count as f64 - expected) / expected; - assert!(err.abs() <= 0.25); + // Rough estimate that the distribution is correct + let step = if (b - a).is_finite() { + (b - a) / 10.0 + } else { + b / 10.0 - a / 10.0 + }; + assert!(step.is_finite()); + let mut counts = Vec::::with_capacity(10); + counts.resize(10, 0); + for _ in 0..3000 { + let res = hp.sample(&mut r); + assert!(a <= res && res < b); + let index = if (res - a).is_finite() { + (res - a) / step + } else { + res / step - a / step + } as usize; + counts[::core::cmp::min(index, 9)] += 1; + } + for count in &counts { + let expected = 3000.0 as $ty / 10.0; + let err = (*count as $ty - expected) / expected; + assert!(err.abs() <= 0.25); + } + } } } - } + }} } + + float_test!(f64, u64, i64, + [0i64, + 0x0000_0f00_0000_0000, + 0x0001_0000_0000_0000, + 0x0004_0000_0000_0000, + 0x0008_0000_0000_0000, + 0x0010_0000_0000_0000, + 0x0020_0000_0000_0000, + 0x0040_0000_0000_0000, + 0x0100_0000_0000_0000, + 0x00cd_ef12_3456_789a, + 0x0100_ffff_ffff_ffff, + 0x010f_ffff_ffff_ffff, + 0x0400_1234_5678_abcd, + 0x7fef_ffff_ffff_ffff, + ]); + float_test!(f32, u32, i32, + [0i32, + 0x000f_0000, + 0x0008_0000, + 0x0020_0000, + 0x0040_0000, + 0x0080_0000, + 0x0100_0000, + 0x0200_0000, + 0x0800_0000, + 0x5678_abcd, + 0x0807_ffff, + 0x087f_ffff, + 0x4012_3456, + 0x7f7f_ffff, + ]); } }
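
For context, the new distribution is used like any other `Distribution`: construct it with `HighPrecision::new(low, high)` and draw values that are `>= low` and `< high`. A minimal usage sketch, assuming `HighPrecision` ends up re-exported from `rand::distributions` the way the benchmarks above use it:

    extern crate rand;

    use rand::distributions::{Distribution, HighPrecision};

    fn main() {
        let mut rng = rand::thread_rng();
        // Samples are >= low and < high, drawn with the full precision of f64.
        let hp = HighPrecision::new(-1.0f64 / 3.0, 2.319);
        for _ in 0..5 {
            let x = hp.sample(&mut rng);
            assert!(-1.0f64 / 3.0 <= x && x < 2.319);
            println!("{}", x);
        }
    }

Compared to `Uniform`, every representable value inside the range can be produced, at the cost of a slower sampling loop that occasionally rejects and redraws.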
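
To make the construction step concrete, here is a worked example of what `HighPrecision::new(0.25, 3.0)` computes for `f64`. The numbers are derived by hand from the `parse` and `round_neg_inf_shr` helpers in this patch; the variable names below are illustrative only and not part of the API:

    fn main() {
        let low = 0.25f64;  // 1.0 * 2^-2: biased exponent 1021, mantissa 1 << 52
        let high = 3.0f64;  // 1.5 * 2^1:  biased exponent 1024, mantissa 3 << 51

        let (mant_low, exp_low) = (1u64 << 52, 1021u16);
        let (mant_high, exp_high) = (3u64 << 51, 1024u16);
        assert_eq!(low.to_bits(), ((exp_low as u64) << 52) | (mant_low & ((1 << 52) - 1)));
        assert_eq!(high.to_bits(), ((exp_high as u64) << 52) | (mant_high & ((1 << 52) - 1)));

        // The bound with the smaller exponent is shifted to the common exponent,
        // rounding the low bound toward negative infinity.
        let exp = exp_high;                                  // 1024
        let aligned_low = mant_low >> (exp_high - exp_low);  // 1 << 49
        // The inner integer sampler draws from [aligned_low, mant_high), and at
        // this exponent one unit of mantissa is worth 2^(exp - 1023 - 52).
        let unit = (2.0f64).powi(exp as i32 - 1023 - 52);
        assert_eq!(aligned_low as f64 * unit, 0.25);
        assert_eq!(mant_high as f64 * unit, 3.0);
        println!("sampler range: [{:#x}, {:#x}), exponent {}", aligned_low, mant_high, exp);
    }

Rounding the shifted bound toward negative infinity (and toward positive infinity in the opposite case) is what makes the rejection test in `sample` necessary: the integer range handed to the inner sampler may be slightly wider than the requested float range.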
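
On the sampling side, the loop first draws an integer mantissa uniformly between the two aligned bounds and then, whenever that mantissa has fewer than 53 significant bits, tops it up with extra random bits, lowering the exponent by one for each bit added. The helper below is not from the patch; it is a simplified, f64-only sketch of that inner step, with the negative-power-of-two edge case and the final bounds rejection left out:

    extern crate rand;

    use rand::Rng;

    // Fill a short mantissa out to 53 significant bits, borrowing range from the
    // exponent. Extra bits are OR-ed in for positive values and subtracted for
    // negative ones (the mantissa holds the magnitude), mirroring the loop in
    // the patch's Distribution::sample implementation.
    fn top_up_mantissa<R: Rng>(rng: &mut R, mut mantissa: u64, mut exp: u16,
                               is_neg: bool) -> (u64, u16) {
        const GOAL_ZEROS: u16 = 64 - 53;
        let mut zeros = mantissa.leading_zeros() as u16;
        while zeros > GOAL_ZEROS && exp > 1 {
            let additional = (zeros - GOAL_ZEROS).min(exp - 1);
            let additional_bits = rng.gen::<u64>() >> (64 - additional);
            mantissa <<= additional;
            if !is_neg {
                mantissa |= additional_bits;
            } else {
                mantissa -= additional_bits;
            }
            exp -= additional;
            zeros = mantissa.leading_zeros() as u16;
        }
        (mantissa, exp)
    }

    fn main() {
        let mut rng = rand::thread_rng();
        // Start from the integer mantissa 1 at biased exponent 1023: the loop
        // shifts in 52 extra random bits and lowers the exponent to match.
        let (m, e) = top_up_mantissa(&mut rng, 1, 1023, false);
        assert_eq!(m.leading_zeros(), 11);
        assert_eq!(e, 1023 - 52);
        println!("mantissa = {:#x}, exponent = {}", m, e);
    }

A sample can fall outside the requested range only on the side whose mantissa was shifted and rounded while the bounds were being aligned; in that case the loop discards it and starts over rather than reusing any partial state, which keeps the result unbiased and the logic simple.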