Skip to content

Commit

Permalink
Merge pull request #498 from gksato/optimize-nextperm
Browse files Browse the repository at this point in the history
Optimize Mutable.nextPermutation and add {next/prev}permutation(By)
  • Loading branch information
Shimuuar authored Aug 19, 2024
2 parents fd76994 + 7415257 commit eb60526
Show file tree
Hide file tree
Showing 11 changed files with 458 additions and 51 deletions.
122 changes: 122 additions & 0 deletions vector/benchlib/Bench/Vector/Algo/NextPermutation.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Bench.Vector.Algo.NextPermutation (generatePermTests) where

import qualified Data.Vector.Unboxed as V
import qualified Data.Vector.Unboxed.Mutable as M
import qualified Data.Vector.Generic.Mutable as G
import System.Random.Stateful
( StatefulGen, UniformRange(uniformRM) )

-- | Generate a list of benchmarks for permutation algorithms.
-- The list contains pairs of benchmark names and corresponding actions.
-- The actions are to be executed by the benchmarking framework.
--
-- The list contains the following benchmarks:
-- - @(next|prev)Permutation@ on a small vector repeated until the end of the permutation cycle
-- - Bijective versions of @(next|prev)Permutation@ on a vector of size @n@, repeated @n@ times
-- - ascending permutation
-- - descending permutation
-- - random permutation
-- - Baseline for bijective versions: just copying a vector of size @n@. Note that the tests for
-- bijective versions begins with copying a vector.
generatePermTests :: StatefulGen g IO => g -> Int -> IO [(String, IO ())]
generatePermTests gen useSize = do
let !k = useSizeToPermLen useSize
let !vasc = V.generate useSize id
!vdesc = V.generate useSize (useSize-1-)
!vrnd <- randomPermutationWith gen useSize
return
[ ("nextPermutation (small vector, until end)", loopPermutations k)
, ("nextPermutationBijective (ascending perm of size n, n times)", repeatNextPermutation vasc useSize)
, ("nextPermutationBijective (descending perm of size n, n times)", repeatNextPermutation vdesc useSize)
, ("nextPermutationBijective (random perm of size n, n times)", repeatNextPermutation vrnd useSize)
, ("prevPermutation (small vector, until end)", loopRevPermutations k)
, ("prevPermutationBijective (ascending perm of size n, n times)", repeatPrevPermutation vasc useSize)
, ("prevPermutationBijective (descending perm of size n, n times)", repeatPrevPermutation vdesc useSize)
, ("prevPermutationBijective (random perm of size n, n times)", repeatPrevPermutation vrnd useSize)
, ("baseline for *Bijective (just copying the vector of size n)", V.thaw vrnd >> return ())
]

-- | Given a PRNG and a length @n@, generate a random permutation of @[0..n-1]@.
randomPermutationWith :: (StatefulGen g IO) => g -> Int -> IO (V.Vector Int)
randomPermutationWith gen n = do
v <- M.generate n id
V.forM_ (V.generate (n-1) id) $ \ !i -> do
j <- uniformRM (i,n-1) gen
M.swap v i j
V.unsafeFreeze v

-- | Given @useSize@ benchmark option, compute the largest @n <= 12@ such that @n! <= useSize@.
-- Repeat-nextPermutation-until-end benchmark will use @n@ as the length of the vector.
-- Note that 12 is the largest @n@ such that @n!@ can be represented as an 'Int32'.
useSizeToPermLen :: Int -> Int
useSizeToPermLen us = case V.findIndex (> max 0 us) $ V.scanl' (*) 1 $ V.generate 12 (+1) of
Just i -> i-1
Nothing -> 12

-- | A bijective version of @G.nextPermutation@ that reverses the vector
-- if it is already in descending order.
-- "Bijective" here means that the function forms a cycle over all permutations
-- of the vector's elements.
--
-- This has a nice property that should be benchmarked:
-- this function takes amortized constant time each call,
-- if successively called either Omega(n) times on a single vector having distinct elements,
-- or arbitrary times on a single vector initially in strictly ascending order.
nextPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
nextPermutationBijective v = do
res <- G.nextPermutation v
if res then return True else G.reverse v >> return False

-- | A bijective version of @G.prevPermutation@ that reverses the vector
-- if it is already in ascending order.
-- "Bijective" here means that the function forms a cycle over all permutations
-- of the vector's elements.
--
-- This has a nice property that should be benchmarked:
-- this function takes amortized constant time each call,
-- if successively called either Omega(n) times on a single vector having distinct elements,
-- or arbitrary times on a single vector initially in strictly descending order.
prevPermutationBijective :: (G.MVector v a, Ord a) => v G.RealWorld a -> IO Bool
prevPermutationBijective v = do
res <- G.prevPermutation v
if res then return True else G.reverse v >> return False

-- | Repeat @nextPermutation@ on @[0..n-1]@ until the end.
loopPermutations :: Int -> IO ()
loopPermutations n = do
v <- M.generate n id
let loop = do
res <- M.nextPermutation v
if res then loop else return ()
loop

-- | Repeat @prevPermutation@ on @[n-1,n-2..0]@ until the end.
loopRevPermutations :: Int -> IO ()
loopRevPermutations n = do
v <- M.generate n (n-1-)
let loop = do
res <- M.prevPermutation v
if res then loop else return ()
loop

-- | Repeat @nextPermutationBijective@ on a given vector given times.
repeatNextPermutation :: V.Vector Int -> Int -> IO ()
repeatNextPermutation !v !n = do
!mv <- V.thaw v
let loop !i | i <= 0 = return ()
loop !i = do
_ <- nextPermutationBijective mv
loop (i-1)
loop n

-- | Repeat @prevPermutationBijective@ on a given vector given times.
repeatPrevPermutation :: V.Vector Int -> Int -> IO ()
repeatPrevPermutation !v !n = do
!mv <- V.thaw v
let loop !i | i <= 0 = return ()
loop !i = do
_ <- prevPermutationBijective mv
loop (i-1)
loop n
23 changes: 13 additions & 10 deletions vector/benchmarks/Main.hs
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
{-# LANGUAGE BangPatterns #-}
module Main where

import Bench.Vector.Algo.MutableSet (mutableSet)
import Bench.Vector.Algo.ListRank (listRank)
import Bench.Vector.Algo.Rootfix (rootfix)
import Bench.Vector.Algo.Leaffix (leaffix)
import Bench.Vector.Algo.AwShCC (awshcc)
import Bench.Vector.Algo.HybCC (hybcc)
import Bench.Vector.Algo.Quickhull (quickhull)
import Bench.Vector.Algo.Spectral (spectral)
import Bench.Vector.Algo.Tridiag (tridiag)
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
import Bench.Vector.Algo.MutableSet (mutableSet)
import Bench.Vector.Algo.ListRank (listRank)
import Bench.Vector.Algo.Rootfix (rootfix)
import Bench.Vector.Algo.Leaffix (leaffix)
import Bench.Vector.Algo.AwShCC (awshcc)
import Bench.Vector.Algo.HybCC (hybcc)
import Bench.Vector.Algo.Quickhull (quickhull)
import Bench.Vector.Algo.Spectral (spectral)
import Bench.Vector.Algo.Tridiag (tridiag)
import Bench.Vector.Algo.FindIndexR (findIndexR, findIndexR_naive, findIndexR_manual)
import Bench.Vector.Algo.NextPermutation (generatePermTests)

import Bench.Vector.TestData.ParenTree (parenTree)
import Bench.Vector.TestData.Graph (randomGraph)
Expand Down Expand Up @@ -50,6 +51,7 @@ main = do
!ds <- randomVector useSize
!sp <- randomVector (floor $ sqrt $ fromIntegral useSize)
vi <- MV.new useSize
permTests <- generatePermTests gen useSize

defaultMainWithIngredients ingredients $ bgroup "All"
[ bench "listRank" $ whnf listRank useSize
Expand All @@ -66,4 +68,5 @@ main = do
, bench "findIndexR_manual" $ whnf findIndexR_manual ((<indexFindThreshold), as)
, bench "minimumOn" $ whnf (U.minimumOn (\x -> x*x*x)) as
, bench "maximumOn" $ whnf (U.maximumOn (\x -> x*x*x)) as
, bgroup "(next|prev)Permutation" $ map (\(name, act) -> bench name $ whnfIO act) permTests
]
11 changes: 11 additions & 0 deletions vector/changelog.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
# Changes in version 0.13.2.0

* We had some improvements on `*.Mutable.{next,prev}Permutation{,By}`
[#498](https://github.com/haskell/vector/pull/498):
* Add `*.Mutable.prevPermutation{,By}` and `*.Mutable.nextPermutationBy`
* Improve time performance. We may now expect good specialization supported by inlining.
The implementation has also been algorithmically updated: in the previous implementation
the full enumeration of all the permutations of `[1..n]` took Omega(n*n!), but it now takes O(n!).
* Add tests for `{next,prev}Permutation`
* Add benchmarks for `{next,prev}Permutation`

# Changes in version 0.13.1.0

* Specialized variants of `findIndexR` are reexported for all vector
Expand Down
112 changes: 88 additions & 24 deletions vector/src/Data/Vector/Generic/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ module Data.Vector.Generic.Mutable (
ifoldr, ifoldr', ifoldrM, ifoldrM',

-- * Modifying vectors
nextPermutation,
nextPermutation, nextPermutationBy,
prevPermutation, prevPermutationBy,

-- ** Filling and copying
set, copy, move, unsafeCopy, unsafeMove,
Expand Down Expand Up @@ -91,9 +92,10 @@ import Data.Vector.Internal.Check
import Control.Monad.Primitive ( PrimMonad(..), RealWorld, stToPrim )

import Prelude
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..)
( Ord, Monad, Bool(..), Int, Maybe(..), Either(..), Ordering(..)
, return, otherwise, flip, const, seq, min, max, not, pure
, (>>=), (+), (-), (<), (<=), (>=), (==), (/=), (.), ($), (=<<), (>>), (<$>) )
, (>>=), (+), (-), (<), (<=), (>), (>=), (==), (/=), (.), ($), (=<<), (>>), (<$>) )
import Data.Bits ( Bits(shiftR) )

#include "vector.h"

Expand Down Expand Up @@ -1213,6 +1215,47 @@ partitionWithUnknown f s
-- Modifying vectors
-- -----------------


-- | Compute the (lexicographically) next permutation of the given vector in-place.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly descending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::next_permutation@.
nextPermutation :: (PrimMonad m, Ord e, MVector v e) => v (PrimState m) e -> m Bool
{-# INLINE nextPermutation #-}
nextPermutation = nextPermutationByLt (<)

-- | Compute the (lexicographically) next permutation of the given vector in-place,
-- using the provided comparison function.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly descending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::next_permutation@.
--
-- @since 0.13.2.0
nextPermutationBy :: (PrimMonad m, MVector v e) => (e -> e -> Ordering) -> v (PrimState m) e -> m Bool
{-# INLINE nextPermutationBy #-}
nextPermutationBy cmp = nextPermutationByLt (\x y -> cmp x y == LT)

-- | Compute the (lexicographically) previous permutation of the given vector in-place.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly ascending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
--
-- @since 0.13.2.0
prevPermutation :: (PrimMonad m, Ord e, MVector v e) => v (PrimState m) e -> m Bool
{-# INLINE prevPermutation #-}
prevPermutation = nextPermutationByLt (>)

-- | Compute the (lexicographically) previous permutation of the given vector in-place,
-- using the provided comparison function.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly ascending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
--
-- @since 0.13.2.0
prevPermutationBy :: (PrimMonad m, MVector v e) => (e -> e -> Ordering) -> v (PrimState m) e -> m Bool
{-# INLINE prevPermutationBy #-}
prevPermutationBy cmp = nextPermutationByLt (\x y -> cmp x y == GT)

{-
http://en.wikipedia.org/wiki/Permutation#Algorithms_to_generate_permutations
Expand All @@ -1224,30 +1267,51 @@ a given permutation. It changes the given permutation in-place.
2. Find the largest index l greater than k such that a[k] < a[l].
3. Swap the value of a[k] with that of a[l].
4. Reverse the sequence from a[k + 1] up to and including the final element a[n]
The algorithm has been updated to look up the k in Step 1 beginning from the
last of the vector; which renders the algorithm to achieve the average time
complexity of O(1) each call. The worst case time complexity is still O(n).
The orginal implementation, which scanned the vector from the left, had the
time complexity of O(n) on the best case.
-}

-- | Compute the (lexicographically) next permutation of the given vector in-place.
-- Returns False when the input is the last permutation.
nextPermutation :: (PrimMonad m,Ord e,MVector v e) => v (PrimState m) e -> m Bool
nextPermutation v
| dim < 2 = return False
| otherwise = do
val <- unsafeRead v 0
(k,l) <- loop val (-1) 0 val 1
if k < 0
then return False
else unsafeSwap v k l >>
reverse (unsafeSlice (k+1) (dim-k-1) v) >>
return True
where loop !kval !k !l !prev !i
| i == dim = return (k,l)
| otherwise = do
cur <- unsafeRead v i
-- TODO: make tuple unboxed
let (kval',k') = if prev < cur then (prev,i-1) else (kval,k)
l' = if kval' < cur then i else l
loop kval' k' l' cur (i+1)
dim = length v
-- Here, the first argument should be a less-than comparison function.
-- Returns False when the input is the last permutation; in this case the vector
-- will not get updated, as opposed to the behavior of the C++ function
-- @std::next_permutation@.
nextPermutationByLt :: (PrimMonad m, MVector v e) => (e -> e -> Bool) -> v (PrimState m) e -> m Bool
{-# INLINE nextPermutationByLt #-}
nextPermutationByLt lt v
| dim < 2 = return False
| otherwise = stToPrim $ do
!vlast <- unsafeRead v (dim - 1)
decrLoop (dim - 2) vlast
where
dim = length v
-- find the largest index k such that a[k] < a[k + 1], and then pass to the rest.
decrLoop !i !vi1 | i >= 0 = do
!vi <- unsafeRead v i
if vi `lt` vi1 then swapLoop i vi (i+1) vi1 dim else decrLoop (i-1) vi
decrLoop _ !_ = return False
-- find the largest index l greater than k such that a[k] < a[l], and do the rest.
swapLoop !k !vk = go
where
-- binary search.
go !l !vl !r | r - l <= 1 = do
-- Done; do the rest of the algorithm.
unsafeWrite v k vl
unsafeWrite v l vk
reverse $ unsafeSlice (k + 1) (dim - k - 1) v
return True
go !l !vl !r = do
!vmid <- unsafeRead v mid
if vk `lt` vmid
then go mid vmid r
else go l vl mid
where
!mid = l + (r - l) `shiftR` 1


-- $setup
-- >>> import Prelude ((*))
39 changes: 37 additions & 2 deletions vector/src/Data/Vector/Mutable.hs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ module Data.Vector.Mutable (
ifoldr, ifoldr', ifoldrM, ifoldrM',

-- * Modifying vectors
nextPermutation,
nextPermutation, nextPermutationBy,
prevPermutation, prevPermutationBy,

-- ** Filling and copying
set, copy, move, unsafeCopy, unsafeMove,
Expand Down Expand Up @@ -574,11 +575,45 @@ unsafeMove = G.unsafeMove
-- -----------------

-- | Compute the (lexicographically) next permutation of the given vector in-place.
-- Returns False when the input is the last permutation.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly descending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::next_permutation@.
nextPermutation :: (PrimMonad m, Ord e) => MVector (PrimState m) e -> m Bool
{-# INLINE nextPermutation #-}
nextPermutation = G.nextPermutation

-- | Compute the (lexicographically) next permutation of the given vector in-place,
-- using the provided comparison function.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly descending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::next_permutation@.
--
-- @since 0.13.2.0
nextPermutationBy :: PrimMonad m => (e -> e -> Ordering) -> MVector (PrimState m) e -> m Bool
{-# INLINE nextPermutationBy #-}
nextPermutationBy = G.nextPermutationBy

-- | Compute the (lexicographically) previous permutation of the given vector in-place.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly ascending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
--
-- @since 0.13.2.0
prevPermutation :: (PrimMonad m, Ord e) => MVector (PrimState m) e -> m Bool
{-# INLINE prevPermutation #-}
prevPermutation = G.prevPermutation

-- | Compute the (lexicographically) previous permutation of the given vector in-place,
-- using the provided comparison function.
-- Returns False when the input is the last item in the enumeration, i.e., if it is in
-- weakly ascending order. In this case the vector will not get updated,
-- as opposed to the behavior of the C++ function @std::prev_permutation@.
--
-- @since 0.13.2.0
prevPermutationBy :: PrimMonad m => (e -> e -> Ordering) -> MVector (PrimState m) e -> m Bool
{-# INLINE prevPermutationBy #-}
prevPermutationBy = G.prevPermutationBy

-- Folds
-- -----

Expand Down
Loading

0 comments on commit eb60526

Please sign in to comment.