From dd38432c3317c040688464fb914040a5d302a123 Mon Sep 17 00:00:00 2001
From: Paul Dicker <pitdicker@gmail.com>
Date: Fri, 8 Jun 2018 20:27:03 +0200
Subject: [PATCH] WIP

---
 src/seq.rs | 108 ++++++++++++++++++++++++++++++-----------------------
 1 file changed, 61 insertions(+), 47 deletions(-)

diff --git a/src/seq.rs b/src/seq.rs
index 2bf66d1fd7c..15aba93541b 100644
--- a/src/seq.rs
+++ b/src/seq.rs
@@ -160,7 +160,7 @@ pub fn sample_slice_ref<'a, R, T>(rng: &mut R, slice: &'a [T], amount: usize,
 ///
 /// Panics if `amount > length` or if `length` is not reprentable as a `u32`.
 pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize,
-    shuffled: bool) -> Vec<u32>
+    shuffled: bool) -> Vec<usize>
     where R: Rng + ?Sized,
 {
     if amount > length {
@@ -169,9 +169,7 @@ pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize,
     if length > (::core::u32::MAX as usize) {
         panic!("`length` is not representable as `u32`");
     }
-    let amount = amount as u32;
-    let length = length as u32;
-    
+
     // Choice of algorithm here depends on both length and amount. See:
     // https://github.com/rust-lang-nursery/rand/pull/479
     // We do some calculations with u64 to avoid overflow.
@@ -183,7 +181,9 @@ pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize,
         if C[0][j] * (length as u64) < (C[1][j] + m4) * amount as u64 {
             sample_indices_inplace(rng, length, amount)
         } else if shuffled {
-            sample_indices_floyd_shuffled(rng, length, amount)
+            let mut indices = sample_indices_floyd(rng, length, amount);
+            rng.shuffle(&mut indices);
+            indices
         } else {
             sample_indices_floyd(rng, length, amount)
         }
@@ -205,18 +205,32 @@ pub fn sample_indices<R>(rng: &mut R, length: usize, amount: usize,
 /// sampled elements).
 ///
 /// This implementation uses `O(amount)` memory and `O(amount^2)` time.
-fn sample_indices_floyd<R>(rng: &mut R, length: u32, amount: u32) -> Vec<u32>
+fn sample_indices_floyd<R>(rng: &mut R, length: usize, amount: usize)
+    -> Vec<usize>
     where R: Rng + ?Sized,
 {
     debug_assert!(amount <= length);
-    let mut indices = Vec::with_capacity(amount as usize);
-    for j in length - amount .. length {
-        let t = rng.gen_range(0, j + 1);
-        if indices.contains(&t) {
-            indices.push(j)
-        } else {
-            indices.push(t)
-        };
+    let mut indices = Vec::with_capacity(amount);
+    if length <= ::core::u32::MAX as usize {
+        let length = length as u32;
+        let amount = amount as u32;
+        for j in length - amount .. length {
+            let t = rng.gen_range(0, j + 1) as usize;
+            if indices.contains(&t) {
+                indices.push(j as usize)
+            } else {
+                indices.push(t)
+            };
+        }
+    } else {
+        for j in length - amount .. length {
+            let t = rng.gen_range(0, j + 1);
+            if indices.contains(&t) {
+                indices.push(j)
+            } else {
+                indices.push(t)
+            };
+        }
     }
     indices
 }
@@ -228,8 +242,8 @@ fn sample_indices_floyd<R>(rng: &mut R, length: u32, amount: u32) -> Vec<u32>
 /// more than double the time since our implementation is already `O(amount^2)`.
 ///
 /// This implementation uses `O(amount)` memory and `O(amount^2)` time.
-fn sample_indices_floyd_shuffled<R>(rng: &mut R, length: u32, amount: u32)
-    -> Vec<u32>
+fn sample_indices_floyd_shuffled<R>(rng: &mut R, length: usize, amount: usize)
+    -> Vec<usize>
     where R: Rng + ?Sized,
 {
     debug_assert!(amount <= length);
@@ -265,19 +279,19 @@ fn sample_indices_floyd_shuffled<R>(rng: &mut R, length: u32, amount: u32)
 /// This is likely the fastest for small lengths since it avoids the need for
 /// allocations. Set-up is `O(length)` time and memory and shuffling is
 /// `O(amount)` time.
-fn sample_indices_inplace<R>(rng: &mut R, length: u32, amount: u32)
-    -> Vec<u32>
+fn sample_indices_inplace<R>(rng: &mut R, length: usize, amount: usize)
+    -> Vec<usize>
     where R: Rng + ?Sized,
 {
     debug_assert!(amount <= length);
-    let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
+    let mut indices: Vec<usize> = Vec::with_capacity(length);
     indices.extend(0..length);
     for i in 0..amount {
-        let j: u32 = rng.gen_range(i, length);
-        indices.swap(i as usize, j as usize);
+        let j = rng.gen_range(i, length);
+        indices.swap(i, j);
     }
-    indices.truncate(amount as usize);
-    debug_assert_eq!(indices.len(), amount as usize);
+    indices.truncate(amount);
+    debug_assert_eq!(indices.len(), amount);
     indices
 }
 
@@ -288,16 +302,16 @@ fn sample_indices_inplace<R>(rng: &mut R, length: u32, amount: u32)
 /// especially useful when `amount <<< length`; e.g. selecting 3 non-repeating
 /// values from `1_000_000`. The algorithm is `O(amount)` time and memory,
 /// but due to overheads will often be slower than other approaches.
-fn sample_indices_cache<R>(rng: &mut R, length: u32, amount: u32)
-    -> Vec<u32>
+fn sample_indices_cache<R>(rng: &mut R, length: usize, amount: usize)
+    -> Vec<usize>
     where R: Rng + ?Sized,
 {
     debug_assert!(amount <= length);
-    #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount as usize);
+    #[cfg(feature="std")] let mut cache = HashMap::with_capacity(amount);
     #[cfg(not(feature="std"))] let mut cache = BTreeMap::new();
-    let mut indices = Vec::with_capacity(amount as usize);
+    let mut indices = Vec::with_capacity(amount);
     for i in 0..amount {
-        let j: u32 = rng.gen_range(i, length);
+        let j = rng.gen_range(i, length);
 
         // get the current values at i and j ...
         let x_i = match cache.get(&i) {
@@ -313,7 +327,7 @@ fn sample_indices_cache<R>(rng: &mut R, length: u32, amount: u32)
         cache.insert(j, x_i);
         indices.push(x_j);  // push at position i
     }
-    debug_assert_eq!(indices.len(), amount as usize);
+    debug_assert_eq!(indices.len(), amount);
     indices
 }
 
@@ -370,19 +384,19 @@ mod test {
         assert_eq!(&sample_indices_cache(&mut r, 1, 0)[..], [0; 0]);
         assert_eq!(&sample_indices_cache(&mut r, 1, 1)[..], [0]);
 
-        assert_eq!(&sample_indices_floyd(&mut r, 0, 0, false)[..], [0; 0]);
-        assert_eq!(&sample_indices_floyd(&mut r, 1, 0, false)[..], [0; 0]);
-        assert_eq!(&sample_indices_floyd(&mut r, 1, 1, false)[..], [0]);
-        assert_eq!(&sample_indices_floyd(&mut r, 0, 0, true)[..], [0; 0]);
-        assert_eq!(&sample_indices_floyd(&mut r, 1, 0, true)[..], [0; 0]);
-        assert_eq!(&sample_indices_floyd(&mut r, 1, 1, true)[..], [0]);
+        assert_eq!(&sample_indices_floyd(&mut r, 0, 0)[..], [0; 0]);
+        assert_eq!(&sample_indices_floyd(&mut r, 1, 0)[..], [0; 0]);
+        assert_eq!(&sample_indices_floyd(&mut r, 1, 1)[..], [0]);
+        assert_eq!(&sample_indices_floyd_shuffled(&mut r, 0, 0)[..], [0; 0]);
+        assert_eq!(&sample_indices_floyd_shuffled(&mut r, 1, 0)[..], [0; 0]);
+        assert_eq!(&sample_indices_floyd_shuffled(&mut r, 1, 1)[..], [0]);
         
         // These algorithms should be fast with big numbers. Test average.
         let sum = sample_indices_cache(&mut r, 1 << 25, 10)
             .iter().fold(0, |a, b| a + b);
         assert!(1 << 25 < sum && sum < (1 << 25) * 25);
         
-        let sum = sample_indices_floyd(&mut r, 1 << 25, 10, false)
+        let sum = sample_indices_floyd(&mut r, 1 << 25, 10)
             .iter().fold(0, |a, b| a + b);
         assert!(1 << 25 < sum && sum < (1 << 25) * 25);
 
@@ -419,10 +433,10 @@ mod test {
             let regular = sample_indices(
                 &mut xor_rng(seed), length, amount, false);
             assert_eq!(regular.len(), amount);
-            assert!(regular.iter().all(|e| *e < length as u32));
+            assert!(regular.iter().all(|e| *e < length));
 
             // also test that sampling the slice works
-            let vec: Vec<u32> = (0..(length as u32)).collect();
+            let vec: Vec<usize> = (0..length).collect();
             let result = sample_slice(&mut xor_rng(seed), &vec, amount, false);
             assert_eq!(result, regular);
 
@@ -444,14 +458,14 @@ mod test {
         
         // A small length and relatively large amount should use inplace
         r.fill(&mut seed);
-        let (length, amount): (u32, u32) = (100, 50);
-        let v1 = sample_indices(&mut xor_rng(seed), length as usize, amount as usize, true);
+        let (length, amount): (usize, usize) = (100, 50);
+        let v1 = sample_indices(&mut xor_rng(seed), length, amount, true);
         let v2 = sample_indices_inplace(&mut xor_rng(seed), length, amount);
         assert!(v1.iter().all(|e| *e < length));
         assert_eq!(v1, v2);
         
         // Test Floyd's alg does produce different results
-        let v3 = sample_indices_floyd(&mut xor_rng(seed), length, amount, true);
+        let v3 = sample_indices_floyd_shuffled(&mut xor_rng(seed), length, amount);
         assert!(v1 != v3);
         // However, the cache alg should produce the same results
         let v4 = sample_indices_cache(&mut xor_rng(seed), length, amount);
@@ -459,18 +473,18 @@ mod test {
         
         // A large length and small amount should use Floyd
         r.fill(&mut seed);
-        let (length, amount): (u32, u32) = (1<<20, 50);
-        let v1 = sample_indices(&mut xor_rng(seed), length as usize, amount as usize, true);
-        let v2 = sample_indices_floyd(&mut xor_rng(seed), length, amount, true);
+        let (length, amount): (usize, usize) = (1<<20, 50);
+        let v1 = sample_indices(&mut xor_rng(seed), length, amount, true);
+        let v2 = sample_indices_floyd_shuffled(&mut xor_rng(seed), length, amount);
         assert!(v1.iter().all(|e| *e < length));
         assert_eq!(v1, v2);
         
         // A large length and larger amount should use cache
         r.fill(&mut seed);
-        let (length, amount): (u32, u32) = (1<<20, 600);
-        let v1 = sample_indices(&mut xor_rng(seed), length as usize, amount as usize, true);
+        let (length, amount): (usize, usize) = (1<<20, 600);
+        let v1 = sample_indices(&mut xor_rng(seed), length, amount, true);
         let v2 = sample_indices_cache(&mut xor_rng(seed), length, amount);
-        assert!(v1.iter().all(|e| *e < length as u32));
+        assert!(v1.iter().all(|e| *e < length));
         assert_eq!(v1, v2);
     }
 }