Skip to content

Commit

Permalink
Merge pull request #171 from haskell/lehins/use-array-for-shuffle
Browse files Browse the repository at this point in the history
Implement a faster and unbiased version of list shuffling
  • Loading branch information
lehins authored Dec 26, 2024
2 parents 6b30bd9 + a79f427 commit fb9dcb1
Show file tree
Hide file tree
Showing 7 changed files with 223 additions and 37 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
* Add `Seed`, `SeedGen`, `seedSize`, `mkSeed` and `unSeed`:
[#162](https://github.com/haskell/random/pull/162)
* Add `SplitGen` and `splitGen`: [#160](https://github.com/haskell/random/pull/160)
* Add `shuffleList` and `shuffleListM`: [#140](https://github.com/haskell/random/pull/140)
* Add `unifromShuffleList` and `unifromShuffleListM`: [#140](https://github.com/haskell/random/pull/140)
* Add `uniformWordR`: [#140](https://github.com/haskell/random/pull/140)
* Add `mkStdGen64`: [#155](https://github.com/haskell/random/pull/155)
* Add `uniformListRM`, `uniformList`, `uniformListR`, `uniforms` and `uniformRs`:
[#154](https://github.com/haskell/random/pull/154)
Expand Down
18 changes: 17 additions & 1 deletion bench/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ module Main (main) where
import Control.Monad
import Control.Monad.State.Strict
import Data.Int
import Data.List (sortOn)
import Data.Proxy
import Data.Typeable
import Data.Word
Expand Down Expand Up @@ -263,9 +264,15 @@ main = do
, env getStdGen $ \gen ->
bench "uniformByteArray 100MB" $ nf (\n -> uniformByteArray False n gen) sz100MiB
, env getStdGen $ \gen ->
bench "genByteString 100MB" $ nf (\k -> genByteString k gen) sz100MiB
bench "genByteString 100MB" $ nf (`genByteString` gen) sz100MiB
]
]
, env (pure [0 :: Integer .. 200000]) $ \xs ->
bgroup "shuffle"
[ env getStdGen $ bench "uniformShuffleList" . nf (uniformShuffleList xs)
, env getStdGen $ bench "uniformShuffleListM" . nf (`runStateGen` uniformShuffleListM xs)
, env getStdGen $ bench "naiveShuffleListM" . nf (`runStateGen` naiveShuffleListM xs)
]
]

pureUniformRFullBench ::
Expand Down Expand Up @@ -351,3 +358,12 @@ fillMutablePrimArrayM f ma g = do
go 0
unsafeFreezePrimArray ma
#endif


naiveShuffleListM :: StatefulGen g m => [a] -> g -> m [a]
naiveShuffleListM xs gen = do
is <- uniformListM n gen
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
where
!n = length xs
{-# INLINE naiveShuffleListM #-}
17 changes: 9 additions & 8 deletions src/System/Random.hs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ module System.Random
, uniformRs
, uniformList
, uniformListR
, shuffleList
, uniformShuffleList
-- ** Bytes
, uniformByteArray
, uniformByteString
Expand Down Expand Up @@ -94,6 +94,7 @@ import Data.IORef
import Data.Word
import Foreign.C.Types
import GHC.Exts
import System.Random.Array (shuffleListST)
import System.Random.GFinite (Finite)
import System.Random.Internal
import System.Random.Seed
Expand Down Expand Up @@ -294,18 +295,18 @@ uniformListR :: (UniformRange a, RandomGen g) => Int -> (a, a) -> g -> ([a], g)
uniformListR n r g = runStateGen g (uniformListRM n r)
{-# INLINE uniformListR #-}

-- | Shuffle elements of a list in a random order.
-- | Shuffle elements of a list in a uniformly random order.
--
-- ====__Examples__
--
-- >>> let gen = mkStdGen 2023
-- >>> shuffleList ['a'..'z'] gen
-- ("renlhfqmgptwksdiyavbxojzcu",StdGen {unStdGen = SMGen 9882508430712573120 1920468677557965761})
-- >>> uniformShuffleList "ELVIS" $ mkStdGen 252
-- ("LIVES",StdGen {unStdGen = SMGen 17676540583805057877 5302934877338729551})
--
-- @since 1.3.0
shuffleList :: RandomGen g => [a] -> g -> ([a], g)
shuffleList xs g = runStateGen g (shuffleListM xs)
{-# INLINE shuffleList #-}
uniformShuffleList :: RandomGen g => [a] -> g -> ([a], g)
uniformShuffleList xs g =
runStateGenST g $ \gen -> shuffleListST (`uniformWordR` gen) xs
{-# INLINE uniformShuffleList #-}

-- | Generates a 'ByteString' of the specified size using a pure pseudo-random
-- number generator. See 'uniformByteStringM' for the monadic version.
Expand Down
156 changes: 156 additions & 0 deletions src/System/Random/Array.hs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,17 @@ module System.Random.Array
, byteArrayToShortByteString
, getSizeOfMutableByteArray
, shortByteStringToByteString
-- ** MutableArray
, Array (..)
, MutableArray (..)
, newMutableArray
, freezeMutableArray
, writeArray
, shuffleListM
, shuffleListST
) where

import Control.Monad.Trans (lift, MonadTrans)
import Control.Monad (when)
import Control.Monad.ST
import Data.Array.Byte (ByteArray(..), MutableByteArray(..))
Expand All @@ -54,6 +63,10 @@ import Data.ByteString (ByteString)
wordSizeInBits :: Int
wordSizeInBits = finiteBitSize (0 :: Word)

----------------
-- Byte Array --
----------------

-- Architecture independent helpers:

sizeOfByteArray :: ByteArray -> Int
Expand Down Expand Up @@ -204,3 +217,146 @@ pinnedByteArrayToForeignPtr ba# =
ForeignPtr (byteArrayContents# ba#) (PlainPtr (unsafeCoerce# ba#))
{-# INLINE pinnedByteArrayToForeignPtr #-}
#endif

-----------------
-- Boxed Array --
-----------------

data Array a = Array (Array# a)

data MutableArray s a = MutableArray (MutableArray# s a)

newMutableArray :: Int -> a -> ST s (MutableArray s a)
newMutableArray (I# n#) a =
ST $ \s# ->
case newArray# n# a s# of
(# s'#, ma# #) -> (# s'#, MutableArray ma# #)
{-# INLINE newMutableArray #-}

freezeMutableArray :: MutableArray s a -> ST s (Array a)
freezeMutableArray (MutableArray ma#) =
ST $ \s# ->
case unsafeFreezeArray# ma# s# of
(# s'#, a# #) -> (# s'#, Array a# #)
{-# INLINE freezeMutableArray #-}

sizeOfMutableArray :: MutableArray s a -> Int
sizeOfMutableArray (MutableArray ma#) = I# (sizeofMutableArray# ma#)
{-# INLINE sizeOfMutableArray #-}

readArray :: MutableArray s a -> Int -> ST s a
readArray (MutableArray ma#) (I# i#) = ST (readArray# ma# i#)
{-# INLINE readArray #-}

writeArray :: MutableArray s a -> Int -> a -> ST s ()
writeArray (MutableArray ma#) (I# i#) a = st_ (writeArray# ma# i# a)
{-# INLINE writeArray #-}

swapArray :: MutableArray s a -> Int -> Int -> ST s ()
swapArray ma i j = do
x <- readArray ma i
y <- readArray ma j
writeArray ma j x
writeArray ma i y
{-# INLINE swapArray #-}

-- | Write contents of the list into the mutable array. Make sure that array is big
-- enough or segfault will happen.
fillMutableArrayFromList :: MutableArray s a -> [a] -> ST s ()
fillMutableArrayFromList ma = go 0
where
go _ [] = pure ()
go i (x:xs) = writeArray ma i x >> go (i + 1) xs
{-# INLINE fillMutableArrayFromList #-}

readListFromMutableArray :: MutableArray s a -> ST s [a]
readListFromMutableArray ma = go (len - 1) []
where
len = sizeOfMutableArray ma
go i !acc
| i >= 0 = do
x <- readArray ma i
go (i - 1) (x : acc)
| otherwise = pure acc
{-# INLINE readListFromMutableArray #-}


-- | Generate a list of indices that will be used for swapping elements in uniform shuffling:
--
-- @
-- [ (0, n - 1)
-- , (0, n - 2)
-- , (0, n - 3)
-- , ...
-- , (0, 3)
-- , (0, 2)
-- , (0, 1)
-- ]
-- @
genSwapIndices
:: Monad m
=> (Word -> m Word)
-- ^ Action that generates a Word in the supplied range.
-> Word
-- ^ Number of index swaps to generate.
-> m [Int]
genSwapIndices genWordR n = go 1 []
where
go i !acc
| i >= n = pure acc
| otherwise = do
x <- genWordR i
let !xi = fromIntegral x
go (i + 1) (xi : acc)
{-# INLINE genSwapIndices #-}


-- | Implementation of mutable version of Fisher-Yates shuffle. Unfortunately, we cannot generally
-- interleave pseudo-random number generation and mutation of `ST` monad, therefore we have to
-- pre-generate all of the index swaps with `genSwapIndices` and store them in a list before we can
-- perform the actual swaps.
shuffleListM :: Monad m => (Word -> m Word) -> [a] -> m [a]
shuffleListM genWordR ls
| len <= 1 = pure ls
| otherwise = do
swapIxs <- genSwapIndices genWordR (fromIntegral len)
pure $ runST $ do
ma <- newMutableArray len $ error "Impossible: shuffleListM"
fillMutableArrayFromList ma ls

-- Shuffle elements of the mutable array according to the uniformly generated index swap list
let goSwap _ [] = pure ()
goSwap i (j:js) = swapArray ma i j >> goSwap (i - 1) js
goSwap (len - 1) swapIxs

readListFromMutableArray ma
where
len = length ls
{-# INLINE shuffleListM #-}

-- | This is a ~x2-x3 more efficient version of `shuffleListM`. It is more efficient because it does
-- not need to pregenerate a list of indices and instead generates them on demand. Because of this the
-- result that will be produced will differ for the same generator, since the order in which index
-- swaps are generated is reversed.
--
-- Unfortunately, most stateful generator monads can't handle `MonadTrans`, so this version is only
-- used for implementing the pure shuffle.
shuffleListST :: (Monad (t (ST s)), MonadTrans t) => (Word -> t (ST s) Word) -> [a] -> t (ST s) [a]
shuffleListST genWordR ls
| len <= 1 = pure ls
| otherwise = do
ma <- lift $ newMutableArray len $ error "Impossible: shuffleListST"
lift $ fillMutableArrayFromList ma ls

-- Shuffle elements of the mutable array according to the uniformly generated index swap
let goSwap i =
when (i > 0) $ do
j <- genWordR $ (fromIntegral :: Int -> Word) i
lift $ swapArray ma i ((fromIntegral :: Word -> Int) j)
goSwap (i - 1)
goSwap (len - 1)

lift $ readListFromMutableArray ma
where
len = length ls
{-# INLINE shuffleListST #-}
44 changes: 20 additions & 24 deletions src/System/Random/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ module System.Random.Internal
, Uniform(..)
, uniformViaFiniteM
, UniformRange(..)
, uniformWordR
, uniformDouble01M
, uniformDoublePositive01M
, uniformFloat01M
Expand All @@ -65,7 +66,6 @@ module System.Random.Internal
, uniformEnumRM
, uniformListM
, uniformListRM
, shuffleListM
, isInRangeOrd
, isInRangeEnum

Expand Down Expand Up @@ -108,7 +108,6 @@ import Data.ByteString (ByteString)
import Data.ByteString.Short.Internal (ShortByteString(SBS))
import Data.IORef (IORef, newIORef)
import Data.Int
import Data.List (sortOn)
import Data.Word
import Foreign.C.Types
import Foreign.Storable (Storable)
Expand Down Expand Up @@ -221,7 +220,6 @@ class RandomGen g where
-- /Note/ - This function will be removed from the type class in the next major release as
-- it is no longer needed because of `unsafeUniformFillMutableByteArray`.
--
--
-- @since 1.2.0
genShortByteString :: Int -> g -> (ShortByteString, g)
genShortByteString n g =
Expand Down Expand Up @@ -273,10 +271,10 @@ class RandomGen g where
{-# DEPRECATED split "In favor of `splitGen`" #-}

-- | Pseudo-random generators that can be split into two separate and independent
-- psuedo-random generators can have an instance for this type class.
-- psuedo-random generators should provide an instance for this type class.
--
-- Historically this functionality was included in the `RandomGen` type class in the
-- `split` function, however, few pseudo-random generators posses this property of
-- `split` function, however, few pseudo-random generators possess this property of
-- splittability. This lead the old `split` function being usually implemented in terms of
-- `error`.
--
Expand Down Expand Up @@ -784,25 +782,6 @@ uniformListRM :: (StatefulGen g m, UniformRange a) => Int -> (a, a) -> g -> m [a
uniformListRM n range gen = replicateM n (uniformRM range gen)
{-# INLINE uniformListRM #-}

-- | Shuffle elements of a list in a random order.
--
-- ====__Examples__
--
-- >>> import System.Random.Stateful
-- >>> let pureGen = mkStdGen 2023
-- >>> g <- newIOGenM pureGen
-- >>> shuffleListM ['a'..'z'] g :: IO String
-- "renlhfqmgptwksdiyavbxojzcu"
--
-- @since 1.3.0
shuffleListM :: StatefulGen g m => [a] -> g -> m [a]
shuffleListM xs gen = do
is <- uniformListM n gen
pure $ map snd $ sortOn fst $ zip (is :: [Int]) xs
where
!n = length xs
{-# INLINE shuffleListM #-}

-- | The standard pseudo-random number generator.
newtype StdGen = StdGen { unStdGen :: SM.SMGen }
deriving (Show, RandomGen, SplitGen, NFData)
Expand Down Expand Up @@ -1128,6 +1107,23 @@ instance UniformRange Word where
{-# INLINE uniformRM #-}
isInRange = isInRangeOrd

-- | Architecture specific `Word` generation in the specified lower range
--
-- @since 1.3.0
uniformWordR ::
StatefulGen g m
=> Word
-- ^ Maximum value to generate
-> g
-- ^ Stateful generator
-> m Word
uniformWordR r
| wordSizeInBits == 64 =
fmap (fromIntegral :: Word64 -> Word) . uniformWord64R ((fromIntegral :: Word -> Word64) r)
| otherwise =
fmap (fromIntegral :: Word32 -> Word) . uniformWord32R ((fromIntegral :: Word -> Word32) r)
{-# INLINE uniformWordR #-}

instance Uniform Word8 where
uniformM = uniformWord8
{-# INLINE uniformM #-}
Expand Down
Loading

0 comments on commit fb9dcb1

Please sign in to comment.