Skip to content

Commit

Permalink
Rewrite findIndex, findIndexEnd and map to not pass constant arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
Boarders committed Jan 16, 2021
1 parent a36cb2a commit 0d4b6cc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 24 deletions.
48 changes: 27 additions & 21 deletions Data/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -442,16 +442,18 @@ append = mappend
-- | /O(n)/ 'map' @f xs@ is the ByteString obtained by applying @f@ to each
-- element of @xs@.
map :: (Word8 -> Word8) -> ByteString -> ByteString
map f (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \a ->
create len $ map_ 0 a
map f (BS fp len) = unsafeDupablePerformIO $ unsafeWithForeignPtr fp $ \srcPtr ->
create len $ \dstPtr -> m srcPtr dstPtr
where
map_ :: Int -> Ptr Word8 -> Ptr Word8 -> IO ()
map_ !n !p1 !p2
| n >= len = return ()
| otherwise = do
x <- peekByteOff p1 n
pokeByteOff p2 n (f x)
map_ (n+1) p1 p2
m !p1 !p2 = map_ 0
where
map_ :: Int -> IO ()
map_ !n
| n >= len = return ()
| otherwise = do
x <- peekByteOff p1 n
pokeByteOff p2 n (f x)
map_ (n+1)
{-# INLINE map #-}

-- | /O(n)/ 'reverse' @xs@ efficiently returns the elements of @xs@ in reverse order.
Expand Down Expand Up @@ -1342,13 +1344,15 @@ count w (BS x m) = accursedUnutterablePerformIO $ unsafeWithForeignPtr x $ \p ->
-- returns the index of the first element in the ByteString
-- satisfying the predicate.
findIndex :: (Word8 -> Bool) -> ByteString -> Maybe Int
findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \f -> go f 0
findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x g
where
go !ptr !n | n >= l = return Nothing
| otherwise = do w <- peek ptr
if k w
then return (Just n)
else go (ptr `plusPtr` 1) (n+1)
g !ptr = go 0
where
go !n | n >= l = return Nothing
| otherwise = do w <- peek $ ptr `plusPtr` n
if k w
then return (Just n)
else go (n+1)
{-# INLINE [1] findIndex #-}

-- | /O(n)/ The 'findIndexEnd' function takes a predicate and a 'ByteString' and
Expand All @@ -1357,13 +1361,15 @@ findIndex k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \f -> g
--
-- @since 0.10.12.0
findIndexEnd :: (Word8 -> Bool) -> ByteString -> Maybe Int
findIndexEnd k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x $ \ f -> go f (l-1)
findIndexEnd k (BS x l) = accursedUnutterablePerformIO $ withForeignPtr x g
where
go !ptr !n | n < 0 = return Nothing
| otherwise = do w <- peekByteOff ptr n
if k w
then return (Just n)
else go ptr (n-1)
g !ptr = go (l-1)
where
go !n | n < 0 = return Nothing
| otherwise = do w <- peekByteOff ptr n
if k w
then return (Just n)
else go (n-1)
{-# INLINE findIndexEnd #-}

-- | /O(n)/ The 'findIndices' function extends 'findIndex', by returning the
Expand Down
21 changes: 18 additions & 3 deletions bench/BenchAll.hs
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,13 @@ zeroes = L.replicate 10000 0
zeroOneRepeating :: L.ByteString
zeroOneRepeating = L.take 10000 (L.cycle (L.pack [0,1]))


largeTraversalInput :: S.ByteString
largeTraversalInput = S.concat (replicate 10 byteStringData)

smallTraversalInput :: S.ByteString
smallTraversalInput = S8.pack "The quick brown fox"

main :: IO ()
main = do
mapM_ putStrLn sanityCheckInfo
Expand Down Expand Up @@ -424,8 +431,16 @@ main = do
, bench "groupBy (>=)" $ nf (L.groupBy (>=)) zeroes
, bench "groupBy (>)" $ nf (L.groupBy (>)) zeroes
]
, bgroup "findIndex"
[ bench "findIndices" $ nf (sum . S.findIndices even) byteStringData
, bench "find" $ nf (S.find (>= 9998)) byteStringData
, bgroup "findIndex_"
[ bench "findIndices" $ nf (sum . S.findIndices (\x -> x == 129 || x == 72)) byteStringData
, bench "find" $ nf (S.find (>= 198)) byteStringData
]
, bgroup "findIndexEnd"
[ bench "findIndexEnd" $ nf (S.findIndexEnd (<= 57)) byteStringData
, bench "elemIndexInd" $ nf (S.elemIndexEnd 42) byteStringData
]
, bgroup "traversals"
[ bench "map (+1)" $ nf (S.map (+ 1)) largeTraversalInput
, bench "map (+1)" $ nf (S.map (+ 1)) smallTraversalInput
]
]

0 comments on commit 0d4b6cc

Please sign in to comment.