Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added serde1 feature to Serialize/Deserialize WeightedIndex #974

Merged
merged 14 commits into from
May 18, 2020
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ A [separate changelog is kept for rand_core](rand_core/CHANGELOG.md).

You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.html) useful.

## [Unreleased]
### Additions
- Added a `serde1` feature and added Serialize/Deserialize to `UniformInt` and `WeightedIndex` (#974)

## [0.7.3] - 2020-01-10
### Fixes
- The `Bernoulli` distribution constructors now reports an error on NaN and on
Expand Down
5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ appveyor = { repository = "rust-random/rand" }
# Meta-features:
default = ["std", "std_rng"]
nightly = ["simd_support"] # enables all features requiring nightly rust
serde1 = [] # does nothing, deprecated
serde1 = ["serde"]
CGMossa marked this conversation as resolved.
Show resolved Hide resolved

# Option (enabled by default): without "std" rand uses libcore; this option
# enables functionality expected to be available on a standard platform.
Expand Down Expand Up @@ -58,6 +58,7 @@ members = [
rand_core = { path = "rand_core", version = "0.5.1" }
rand_pcg = { path = "rand_pcg", version = "0.2", optional = true }
log = { version = "0.4.4", optional = true }
serde = { version = "1.0.103", features = ["derive"], optional = true }

[dependencies.packed_simd]
# NOTE: so far no version works reliably due to dependence on unstable features
Expand All @@ -81,6 +82,8 @@ rand_hc = { path = "rand_hc", version = "0.2", optional = true }
rand_pcg = { path = "rand_pcg", version = "0.2" }
# Only for benches:
rand_hc = { path = "rand_hc", version = "0.2" }
# Only to test serde1
bincode = "1.2.1"

[package.metadata.docs.rs]
all-features = true
12 changes: 12 additions & 0 deletions src/distributions/bernoulli.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::distributions::Distribution;
use crate::Rng;
use core::{fmt, u64};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
/// The Bernoulli distribution.
///
/// This is a special case of the Binomial distribution where `n = 1`.
Expand All @@ -32,6 +34,7 @@ use core::{fmt, u64};
/// so only probabilities that are multiples of 2<sup>-64</sup> can be
/// represented.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Bernoulli {
/// Probability of success, relative to the maximal integer.
p_int: u64,
Expand Down Expand Up @@ -143,6 +146,15 @@ mod test {
use crate::distributions::Distribution;
use crate::Rng;

#[test]
#[cfg(feature="serde1")]
fn test_serializing_deserializing_bernoulli() {
let coin_flip = Bernoulli::new(0.5).unwrap();
let de_coin_flip : Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();

assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
}

#[test]
fn test_trivial() {
let mut r = crate::test::rng(1);
Expand Down
5 changes: 5 additions & 0 deletions src/distributions/float.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use crate::Rng;
use core::mem;
#[cfg(feature = "simd_support")] use packed_simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A distribution to sample floating point numbers uniformly in the half-open
/// interval `(0, 1]`, i.e. including 1 but not 0.
///
Expand All @@ -39,6 +42,7 @@ use core::mem;
/// [`Open01`]: crate::distributions::Open01
/// [`Uniform`]: crate::distributions::uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct OpenClosed01;

/// A distribution to sample floating point numbers uniformly in the open
Expand All @@ -65,6 +69,7 @@ pub struct OpenClosed01;
/// [`OpenClosed01`]: crate::distributions::OpenClosed01
/// [`Uniform`]: crate::distributions::uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Open01;


Expand Down
4 changes: 4 additions & 0 deletions src/distributions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ pub mod uniform;
#[cfg(feature = "alloc")] pub mod weighted;
#[cfg(feature = "alloc")] mod weighted_index;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

mod float;
#[doc(hidden)]
pub mod hidden_export {
Expand Down Expand Up @@ -320,6 +323,7 @@ where
///
/// [`Uniform`]: uniform::Uniform
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Standard;


Expand Down
4 changes: 4 additions & 0 deletions src/distributions/other.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,9 @@ use core::num::Wrapping;
use crate::distributions::{Distribution, Standard, Uniform};
use crate::Rng;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

// ----- Sampling distributions -----

/// Sample a `char`, uniformly distributed over ASCII letters and numbers:
Expand All @@ -34,6 +37,7 @@ use crate::Rng;
/// println!("Random chars: {}", chars);
/// ```
#[derive(Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Alphanumeric;


Expand Down
59 changes: 58 additions & 1 deletion src/distributions/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,11 @@ use crate::Rng;
#[allow(unused_imports)] // rustc doesn't detect that this is actually used
use crate::distributions::utils::Float;


#[cfg(feature = "simd_support")] use packed_simd::*;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// Sample values uniformly between two bounds.
///
/// [`Uniform::new`] and [`Uniform::new_inclusive`] construct a uniform
Expand Down Expand Up @@ -159,6 +161,7 @@ use crate::distributions::utils::Float;
/// [`new`]: Uniform::new
/// [`new_inclusive`]: Uniform::new_inclusive
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Uniform<X: SampleUniform>(X::Sampler);

impl<X: SampleUniform> Uniform<X> {
Expand Down Expand Up @@ -347,6 +350,7 @@ where Borrowed: SampleUniform
/// multiply by `range`, the result is in the high word. Then comparing the low
/// word against `zone` makes sure our distribution is uniform.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformInt<X> {
low: X,
range: X,
Expand Down Expand Up @@ -644,6 +648,7 @@ uniform_simd_int_impl! {
/// [`new_inclusive`]: UniformSampler::new_inclusive
/// [`Standard`]: crate::distributions::Standard
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformFloat<X> {
low: X,
scale: X,
Expand Down Expand Up @@ -837,12 +842,14 @@ uniform_float_impl! { f64x8, u64x8, f64, u64, 64 - 52 }
/// Unless you are implementing [`UniformSampler`] for your own types, this type
/// should not be used directly, use [`Uniform`] instead.
#[derive(Clone, Copy, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct UniformDuration {
mode: UniformDurationMode,
offset: u32,
}

#[derive(Debug, Copy, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum UniformDurationMode {
Small {
secs: u64,
Expand Down Expand Up @@ -967,6 +974,56 @@ mod tests {
use super::*;
use crate::rngs::mock::StepRng;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_uniform_duration() {
let distr = UniformDuration::new(std::time::Duration::from_secs(10), std::time::Duration::from_secs(60));
let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
assert_eq!(
distr.offset, de_distr.offset
);
match (distr.mode, de_distr.mode) {
(UniformDurationMode::Small {secs: a_secs, nanos: a_nanos}, UniformDurationMode::Small {secs, nanos}) => {
assert_eq!(a_secs, secs);

assert_eq!(a_nanos.0.low, nanos.0.low);
assert_eq!(a_nanos.0.range, nanos.0.range);
assert_eq!(a_nanos.0.z, nanos.0.z);
}
(UniformDurationMode::Medium {nanos: a_nanos} , UniformDurationMode::Medium {nanos}) => {
assert_eq!(a_nanos.0.low, nanos.0.low);
assert_eq!(a_nanos.0.range, nanos.0.range);
assert_eq!(a_nanos.0.z, nanos.0.z);
}
(UniformDurationMode::Large {max_secs:a_max_secs, max_nanos:a_max_nanos, secs:a_secs}, UniformDurationMode::Large {max_secs, max_nanos, secs} ) => {
assert_eq!(a_max_secs, max_secs);
assert_eq!(a_max_nanos, max_nanos);

assert_eq!(a_secs.0.low, secs.0.low);
assert_eq!(a_secs.0.range, secs.0.range);
assert_eq!(a_secs.0.z, secs.0.z);
}
_ => panic!("`UniformDurationMode` was not serialized/deserialized correctly")
}
}

#[test]
#[cfg(feature = "serde1")]
fn test_uniform_serialization() {
let unit_box: Uniform<i32> = Uniform::new(-1, 1);
let de_unit_box: Uniform<i32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();

assert_eq!(unit_box.0.low, de_unit_box.0.low);
assert_eq!(unit_box.0.range, de_unit_box.0.range);
assert_eq!(unit_box.0.z, de_unit_box.0.z);

let unit_box: Uniform<f32> = Uniform::new(-1., 1.);
let de_unit_box: Uniform<f32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();

assert_eq!(unit_box.0.low, de_unit_box.0.low);
assert_eq!(unit_box.0.scale, de_unit_box.0.scale);
}

#[should_panic]
#[test]
fn test_uniform_bad_limits_equal_int() {
Expand Down
21 changes: 21 additions & 0 deletions src/distributions/weighted_index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ use core::fmt;
// Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature = "std"))] use crate::alloc::vec::Vec;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A distribution using weighted sampling of discrete items
///
/// Sampling a `WeightedIndex` distribution returns the index of a randomly
Expand Down Expand Up @@ -73,6 +76,7 @@ use core::fmt;
/// [`Uniform<X>`]: crate::distributions::uniform::Uniform
/// [`RngCore`]: crate::RngCore
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct WeightedIndex<X: SampleUniform + PartialOrd> {
cumulative_weights: Vec<X>,
total_weight: X,
Expand Down Expand Up @@ -236,6 +240,23 @@ where X: SampleUniform + PartialOrd
mod test {
use super::*;

#[cfg(feature = "serde1")]
#[test]
fn test_weightedindex_serde1() {
let weighted_index = WeightedIndex::new(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();

let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
let de_weighted_index: WeightedIndex<i32> =
bincode::deserialize(&ser_weighted_index).unwrap();

assert_eq!(
de_weighted_index.cumulative_weights,
weighted_index.cumulative_weights
);
assert_eq!(de_weighted_index.total_weight, weighted_index.total_weight);
}


#[test]
#[cfg_attr(miri, ignore)] // Miri is too slow
fn test_weightedindex() {
Expand Down
19 changes: 19 additions & 0 deletions src/rngs/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

use rand_core::{impls, Error, RngCore};

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A simple implementation of `RngCore` for testing purposes.
///
/// This generates an arithmetic sequence (i.e. adds a constant each step)
Expand All @@ -25,6 +28,7 @@ use rand_core::{impls, Error, RngCore};
/// assert_eq!(sample, [2, 3, 4]);
/// ```
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StepRng {
v: u64,
a: u64,
Expand Down Expand Up @@ -65,3 +69,18 @@ impl RngCore for StepRng {
Ok(())
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_step_rng() {
let some_rng = StepRng::new(42, 7);
let de_some_rng: StepRng = bincode::deserialize(&bincode::serialize(&some_rng).unwrap()).unwrap();
assert_eq!(some_rng.v, de_some_rng.v);
assert_eq!(some_rng.a, de_some_rng.a);

}
}
21 changes: 21 additions & 0 deletions src/seq/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,14 @@ use crate::alloc::collections::BTreeSet;
use crate::distributions::{uniform::SampleUniform, Distribution, Uniform};
use crate::Rng;

#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};

/// A vector of indices.
///
/// Multiple internal representations are possible.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum IndexVec {
#[doc(hidden)]
U32(Vec<u32>),
Expand Down Expand Up @@ -376,6 +380,23 @@ where
#[cfg(test)]
mod test {
use super::*;

#[test]
#[cfg(feature = "serde1")]
fn test_serialization_index_vec() {
let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]);
let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
match (some_index_vec, de_some_index_vec) {
(IndexVec::U32(a), IndexVec::U32(b)) => {
assert_eq!(a,b);
},
(IndexVec::USize(a), IndexVec::USize(b)) => {
assert_eq!(a,b);
CGMossa marked this conversation as resolved.
Show resolved Hide resolved
},
_ => {panic!("failed to seralize/deserialize `IndexVec`")}
}
}

#[cfg(all(feature = "alloc", not(feature = "std")))] use crate::alloc::vec;
#[cfg(feature = "std")] use std::vec;

Expand Down