Skip to content

Commit

Permalink
Straighten folds and scans. (#364)
Browse files Browse the repository at this point in the history
* Add strict right folds.

* Add property checks.

* Add benchmarks.

* Inline strictness checks.

* Straighten scans.

* Fix whitespace.

* Use `===` for equality.

* Use infix operator for brevity.

* Add bench marks for lazy scans.

* Use standard recursion schemes.

* Dodge import conflicts on older GHC versions.

* Final considerations according to the last review.

* Final considerations according to one more last review.

* Add bench mark for lazy accumulating maps.

* Throw away `mapAccum[LR]Chunks`.

Turns out we do not really need it. We thought we need it to implement `scan[lr]`, but actually `mapAccum[LR]` is enough.
  • Loading branch information
kindaro authored Aug 19, 2021
1 parent 05a09c3 commit 99b7ff6
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 32 deletions.
62 changes: 56 additions & 6 deletions Data/ByteString/Lazy.hs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ module Data.ByteString.Lazy (
foldl1,
foldl1',
foldr,
foldr',
foldr1,
foldr1',

-- ** Special folds
concat,
Expand All @@ -106,9 +108,9 @@ module Data.ByteString.Lazy (
-- * Building ByteStrings
-- ** Scans
scanl,
-- scanl1,
-- scanr,
-- scanr1,
scanl1,
scanr,
scanr1,

-- ** Accumulating maps
mapAccumL,
Expand Down Expand Up @@ -460,6 +462,14 @@ foldr :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr k = foldrChunks (flip (S.foldr k))
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
foldr' :: (Word8 -> a -> a) -> a -> ByteString -> a
foldr' f a = go
where
go Empty = a
go (Chunk c cs) = S.foldr' f (foldr' f a cs) c
{-# INLINE foldr' #-}

-- | 'foldl1' is a variant of 'foldl' that has no starting value
-- argument, and thus must be applied to non-empty 'ByteString's.
foldl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
Expand All @@ -479,6 +489,13 @@ foldr1 f (Chunk c0 cs0) = go c0 cs0
where go c Empty = S.foldr1 f c
go c (Chunk c' cs) = S.foldr f (go c' cs) c

-- | 'foldr1'' is like 'foldr1', but strict in the accumulator.
foldr1' :: (Word8 -> Word8 -> Word8) -> ByteString -> Word8
foldr1' _ Empty = errorEmptyList "foldr1'"
foldr1' f (Chunk c0 cs0) = go c0 cs0
where go c Empty = S.foldr1' f c
go c (Chunk c' cs) = S.foldr' f (go c' cs) c

-- ---------------------------------------------------------------------
-- Special folds

Expand Down Expand Up @@ -617,11 +634,44 @@ scanl
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanl f z = snd . foldl k (z,singleton z)
where
k (c,acc) a = let n = f c a in (n, acc `snoc` n)
scanl function = fmap (uncurry (flip snoc)) . mapAccumL (\x y -> (function x y, x))
{-# INLINE scanl #-}

-- | 'scanl1' is a variant of 'scanl' that has no starting value argument.
--
-- > scanl1 f [x1, x2, ...] == [x1, x1 `f` x2, ...]
scanl1 :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString
scanl1 function byteStream = case uncons byteStream of
Nothing -> Empty
Just (firstByte, remainingBytes) -> scanl function firstByte remainingBytes

-- | 'scanr' is similar to 'foldr', but returns a list of successive
-- reduced values from the right.
--
-- > scanr f z [..., x{n-1}, xn] == [..., x{n-1} `f` (xn `f` z), xn `f` z, z]
--
-- Note that
--
-- > head (scanr f z xs) == foldr f z xs
-- > last (scanr f z xs) == z
--
scanr
:: (Word8 -> Word8 -> Word8)
-- ^ element -> accumulator -> new accumulator
-> Word8
-- ^ starting value of accumulator
-> ByteString
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanr function = fmap (uncurry cons) . mapAccumR (\x y -> (function y x, x))

-- | 'scanr1' is a variant of 'scanr' that has no starting value argument.
scanr1 :: (Word8 -> Word8 -> Word8) -> ByteString -> ByteString
scanr1 function byteStream = case unsnoc byteStream of
Nothing -> Empty
Just (initialBytes, lastByte) -> scanr function lastByte initialBytes

-- ---------------------------------------------------------------------
-- Unfolds and replicates

Expand Down
52 changes: 48 additions & 4 deletions Data/ByteString/Lazy/Char8.hs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ module Data.ByteString.Lazy.Char8 (
foldl1,
foldl1',
foldr,
foldr',
foldr1,
foldr1',

-- ** Special folds
concat,
Expand All @@ -84,9 +86,9 @@ module Data.ByteString.Lazy.Char8 (
-- * Building ByteStrings
-- ** Scans
scanl,
-- scanl1,
-- scanr,
-- scanr1,
scanl1,
scanr,
scanr1,

-- ** Accumulating maps
mapAccumL,
Expand Down Expand Up @@ -238,7 +240,7 @@ import Foreign.Storable (peek)
import Prelude hiding
(reverse,head,tail,last,init,null,length,map,lines,foldl,foldr,unlines
,concat,any,take,drop,splitAt,takeWhile,dropWhile,span,break,elem,filter
,unwords,words,maximum,minimum,all,concatMap,scanl,scanl1,foldl1,foldr1
,unwords,words,maximum,minimum,all,concatMap,scanl,scanl1,scanr,scanr1,foldl1,foldr1
,readFile,writeFile,appendFile,replicate,getContents,getLine,putStr,putStrLn
,zip,zipWith,unzip,notElem,repeat,iterate,interact,cycle)

Expand Down Expand Up @@ -347,6 +349,10 @@ foldr :: (Char -> a -> a) -> a -> ByteString -> a
foldr f = L.foldr (f . w2c)
{-# INLINE foldr #-}

-- | 'foldr'' is like 'foldr', but strict in the accumulator.
foldr' :: (Char -> a -> a) -> a -> ByteString -> a
foldr' f = L.foldr' (f . w2c)

-- | 'foldl1' is a variant of 'foldl' that has no starting value
-- argument, and thus must be applied to non-empty 'ByteString's.
foldl1 :: (Char -> Char -> Char) -> ByteString -> Char
Expand All @@ -363,6 +369,10 @@ foldr1 :: (Char -> Char -> Char) -> ByteString -> Char
foldr1 f ps = w2c (L.foldr1 (\x y -> c2w (f (w2c x) (w2c y))) ps)
{-# INLINE foldr1 #-}

-- | 'foldr1'' is like 'foldr1', but strict in the accumulator.
foldr1' :: (Char -> Char -> Char) -> ByteString -> Char
foldr1' f ps = w2c (L.foldr1' (\x y -> c2w (f (w2c x) (w2c y))) ps)

-- | Map a function over a 'ByteString' and concatenate the results
concatMap :: (Char -> ByteString) -> ByteString -> ByteString
concatMap f = L.concatMap (f . w2c)
Expand Down Expand Up @@ -404,6 +414,40 @@ minimum = w2c . L.minimum
scanl :: (Char -> Char -> Char) -> Char -> ByteString -> ByteString
scanl f z = L.scanl (\a b -> c2w (f (w2c a) (w2c b))) (c2w z)

-- | 'scanl1' is a variant of 'scanl' that has no starting value argument.
--
-- > scanl1 f [x1, x2, ...] == [x1, x1 `f` x2, ...]
scanl1 :: (Char -> Char -> Char) -> ByteString -> ByteString
scanl1 f = L.scanl1 f'
where f' accumulator value = c2w (f (w2c accumulator) (w2c value))

-- | 'scanr' is similar to 'foldr', but returns a list of successive
-- reduced values from the right.
--
-- > scanr f z [..., x{n-1}, xn] == [..., x{n-1} `f` (xn `f` z), xn `f` z, z]
--
-- Note that
--
-- > head (scanr f z xs) == foldr f z xs
-- > last (scanr f z xs) == z
--
scanr
:: (Char -> Char -> Char)
-- ^ element -> accumulator -> new accumulator
-> Char
-- ^ starting value of accumulator
-> ByteString
-- ^ input of length n
-> ByteString
-- ^ output of length n+1
scanr f = L.scanr f' . c2w
where f' accumulator value = c2w (f (w2c accumulator) (w2c value))

-- | 'scanr1' is a variant of 'scanr' that has no starting value argument.
scanr1 :: (Char -> Char -> Char) -> ByteString -> ByteString
scanr1 f = L.scanr1 f'
where f' accumulator value = c2w (f (w2c accumulator) (w2c value))

-- | The 'mapAccumL' function behaves like a combination of 'map' and
-- 'foldl'; it applies a function to each element of a ByteString,
-- passing an accumulating parameter from left to right, and returning a
Expand Down
56 changes: 40 additions & 16 deletions bench/BenchAll.hs
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,9 @@ sortInputs = map (`S.take` S.pack [122, 121 .. 32]) [10..25]
foldInputs :: [S.ByteString]
foldInputs = map (\k -> S.pack $ if k <= 6 then take (2 ^ k) [32..95] else concat (replicate (2 ^ (k - 6)) [32..95])) [0..16]

foldInputsLazy :: [L.ByteString]
foldInputsLazy = map (\k -> L.pack $ if k <= 6 then take (2 ^ k) [32..95] else concat (replicate (2 ^ (k - 6)) [32..95])) [0..16]

zeroes :: L.ByteString
zeroes = L.replicate 10000 0

Expand Down Expand Up @@ -401,22 +404,43 @@ main = do
, bench "one huge word" $ nf S8.words byteStringData
]
, bgroup "folds"
[ bgroup "foldl'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputs
, bgroup "foldr'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputs
, bgroup "unfoldrN" $ map (\s -> bench (show $ S.length s) $
nf (S.unfoldrN (S.length s) (\a -> Just (a, a + 1))) 0) foldInputs
, bgroup "mapAccumL" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputs
, bgroup "mapAccumR" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputs
, bgroup "scanl" $ map (\s -> bench (show $ S.length s) $
nf (S.scanl (+) 0) s) foldInputs
, bgroup "scanr" $ map (\s -> bench (show $ S.length s) $
nf (S.scanr (+) 0) s) foldInputs
, bgroup "filter" $ map (\s -> bench (show $ S.length s) $
nf (S.filter odd) s) foldInputs
[ bgroup "strict"
[ bgroup "foldl'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputs
, bgroup "foldr'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputs
, bgroup "foldr1'" $ map (\s -> bench (show $ S.length s) $
nf (S.foldr1' (\x acc -> fromIntegral x + acc)) s) foldInputs
, bgroup "unfoldrN" $ map (\s -> bench (show $ S.length s) $
nf (S.unfoldrN (S.length s) (\a -> Just (a, a + 1))) 0) foldInputs
, bgroup "mapAccumL" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputs
, bgroup "mapAccumR" $ map (\s -> bench (show $ S.length s) $
nf (S.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputs
, bgroup "scanl" $ map (\s -> bench (show $ S.length s) $
nf (S.scanl (+) 0) s) foldInputs
, bgroup "scanr" $ map (\s -> bench (show $ S.length s) $
nf (S.scanr (+) 0) s) foldInputs
, bgroup "filter" $ map (\s -> bench (show $ S.length s) $
nf (S.filter odd) s) foldInputs
]
, bgroup "lazy"
[ bgroup "foldl'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldl' (\acc x -> acc + fromIntegral x) (0 :: Int)) s) foldInputsLazy
, bgroup "foldr'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldr' (\x acc -> fromIntegral x + acc) (0 :: Int)) s) foldInputsLazy
, bgroup "foldr1'" $ map (\s -> bench (show $ L.length s) $
nf (L.foldr1' (\x acc -> fromIntegral x + acc)) s) foldInputsLazy
, bgroup "mapAccumL" $ map (\s -> bench (show $ L.length s) $
nf (L.mapAccumL (\acc x -> (acc + fromIntegral x, succ x)) (0 :: Int)) s) foldInputsLazy
, bgroup "mapAccumR" $ map (\s -> bench (show $ L.length s) $
nf (L.mapAccumR (\acc x -> (fromIntegral x + acc, succ x)) (0 :: Int)) s) foldInputsLazy
, bgroup "scanl" $ map (\s -> bench (show $ L.length s) $
nf (L.scanl (+) 0) s) foldInputsLazy
, bgroup "scanr" $ map (\s -> bench (show $ L.length s) $
nf (L.scanr (+) 0) s) foldInputsLazy
]

]
, bgroup "findIndexOrLength"
[ bench "takeWhile" $ nf (L.takeWhile even) zeroes
Expand Down
40 changes: 39 additions & 1 deletion tests/Properties.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@ import Control.Concurrent
import Control.Exception
import System.Posix.Internals (c_unlink)

import qualified Data.List as List
import Data.Char
import Data.Word
import Data.Maybe
import Data.Int (Int64)
import Data.Monoid
import Data.Semigroup
import GHC.Exts (Int(..), newPinnedByteArray#, unsafeFreezeByteArray#)
import GHC.ST (ST(..), runST)
Expand Down Expand Up @@ -463,6 +463,12 @@ short_tests =
, testProperty "pinned" prop_short_pinned
]

------------------------------------------------------------------------
-- Strictness checks.

explosiveTail :: L.ByteString -> L.ByteString
explosiveTail = (`L.append` error "Tail of this byte string is undefined!")

------------------------------------------------------------------------
-- The entry point

Expand All @@ -475,6 +481,7 @@ main = defaultMain $ testGroup "All"
, testGroup "Misc" misc_tests
, testGroup "IO" io_tests
, testGroup "Short" short_tests
, testGroup "Strictness" strictness_checks
]

io_tests =
Expand Down Expand Up @@ -535,5 +542,36 @@ misc_tests =
, testProperty "readIntegerUnsafe" prop_readIntegerUnsafe
]

strictness_checks =
[ testGroup "Lazy Word8"
[ testProperty "foldr is lazy" $ \ xs ->
List.genericTake (L.length xs) (L.foldr (:) [ ] (explosiveTail xs)) === L.unpack xs
, testProperty "foldr' is strict" $ expectFailure $ \ xs ys ->
List.genericTake (L.length xs) (L.foldr' (:) [ ] (explosiveTail (xs <> ys))) === L.unpack xs
, testProperty "foldr1 is lazy" $ \ xs -> L.length xs > 0 ==>
L.foldr1 const (explosiveTail (xs <> L.singleton 1)) === L.head xs
, testProperty "foldr1' is strict" $ expectFailure $ \ xs ys -> L.length xs > 0 ==>
L.foldr1' const (explosiveTail (xs <> L.singleton 1 <> ys)) === L.head xs
, testProperty "scanl is lazy" $ \ xs ->
L.take (L.length xs + 1) (L.scanl (+) 0 (explosiveTail (xs <> L.singleton 1))) === (L.pack . fmap (L.foldr (+) 0) . L.inits) xs
, testProperty "scanl1 is lazy" $ \ xs -> L.length xs > 0 ==>
L.take (L.length xs) (L.scanl1 (+) (explosiveTail (xs <> L.singleton 1))) === (L.pack . fmap (L.foldr1 (+)) . tail . L.inits) xs
]
, testGroup "Lazy Char"
[ testProperty "foldr is lazy" $ \ xs ->
List.genericTake (D.length xs) (D.foldr (:) [ ] (explosiveTail xs)) === D.unpack xs
, testProperty "foldr' is strict" $ expectFailure $ \ xs ys ->
List.genericTake (D.length xs) (D.foldr' (:) [ ] (explosiveTail (xs <> ys))) === D.unpack xs
, testProperty "foldr1 is lazy" $ \ xs -> D.length xs > 0 ==>
D.foldr1 const (explosiveTail (xs <> D.singleton 'x')) === D.head xs
, testProperty "foldr1' is strict" $ expectFailure $ \ xs ys -> D.length xs > 0 ==>
D.foldr1' const (explosiveTail (xs <> D.singleton 'x' <> ys)) === D.head xs
, testProperty "scanl is lazy" $ \ xs -> let char1 +. char2 = toEnum (fromEnum char1 + fromEnum char2) in
D.take (D.length xs + 1) (D.scanl (+.) '\NUL' (explosiveTail (xs <> D.singleton '\SOH'))) === (D.pack . fmap (D.foldr (+.) '\NUL') . D.inits) xs
, testProperty "scanl1 is lazy" $ \ xs -> D.length xs > 0 ==> let char1 +. char2 = toEnum (fromEnum char1 + fromEnum char2) in
D.take (D.length xs) (D.scanl1 (+.) (explosiveTail (xs <> D.singleton '\SOH'))) === (D.pack . fmap (D.foldr1 (+.)) . tail . D.inits) xs
]
]

removeFile :: String -> IO ()
removeFile fn = void $ withCString fn c_unlink
7 changes: 2 additions & 5 deletions tests/Properties/ByteString.hs
Original file line number Diff line number Diff line change
Expand Up @@ -410,10 +410,8 @@ tests =
\f (toElem -> c) x -> B.foldl' ((toElem .) . f) c x === foldl' ((toElem .) . f) c (B.unpack x)
, testProperty "foldr" $
\f (toElem -> c) x -> B.foldr ((toElem .) . f) c x === foldr ((toElem .) . f) c (B.unpack x)
#ifndef BYTESTRING_LAZY
, testProperty "foldr'" $
\f (toElem -> c) x -> B.foldr' ((toElem .) . f) c x === foldr' ((toElem .) . f) c (B.unpack x)
#endif

, testProperty "foldl cons" $
\x -> B.foldl (flip B.cons) B.empty x === B.reverse x
Expand All @@ -432,10 +430,8 @@ tests =
\f x -> not (B.null x) ==> B.foldl1' ((toElem .) . f) x === List.foldl1' ((toElem .) . f) (B.unpack x)
, testProperty "foldr1" $
\f x -> not (B.null x) ==> B.foldr1 ((toElem .) . f) x === foldr1 ((toElem .) . f) (B.unpack x)
#ifndef BYTESTRING_LAZY
, testProperty "foldr1'" $ -- there is not Data.List.foldr1'
\f x -> not (B.null x) ==> B.foldr1' ((toElem .) . f) x === foldr1 ((toElem .) . f) (B.unpack x)
#endif

, testProperty "foldl1 const" $
\x -> not (B.null x) ==> B.foldl1 const x === B.head x
Expand All @@ -455,7 +451,6 @@ tests =
, testProperty "scanl foldl" $
\f (toElem -> c) x -> not (B.null x) ==> B.last (B.scanl ((toElem .) . f) c x) === B.foldl ((toElem .) . f) c x

#ifndef BYTESTRING_LAZY
, testProperty "scanr" $
\f (toElem -> c) x -> B.unpack (B.scanr ((toElem .) . f) c x) === scanr ((toElem .) . f) c (B.unpack x)
, testProperty "scanl1" $
Expand All @@ -466,6 +461,8 @@ tests =
\f x -> B.unpack (B.scanr1 ((toElem .) . f) x) === scanr1 ((toElem .) . f) (B.unpack x)
, testProperty "scanr1 empty" $
\f -> B.scanr1 f B.empty === B.empty

#ifndef BYTESTRING_LAZY
, testProperty "sort" $
\x -> B.unpack (B.sort x) === List.sort (B.unpack x)
#endif
Expand Down

0 comments on commit 99b7ff6

Please sign in to comment.