-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathMutable.hs
304 lines (276 loc) · 9.26 KB
/
Mutable.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
{-# LANGUAGE CPP #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
#ifndef BITVEC_THREADSAFE
module Data.Bit.Mutable
#else
module Data.Bit.MutableTS
#endif
( castFromWordsM
, castToWordsM
, cloneToWordsM
, cloneToWords8M
, zipInPlace
, mapInPlace
, invertInPlace
, selectBitsInPlace
, excludeBitsInPlace
, reverseInPlace
) where
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
#ifndef BITVEC_THREADSAFE
import Data.Bit.Internal
#else
import Data.Bit.InternalTS
#endif
import Data.Bit.Utils
import Data.Bits
import Data.Primitive.ByteArray
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
import Data.Word
-- | Reinterpret a mutable vector of words as a mutable vector of bits,
-- without copying: the result shares the original byte array, with the
-- word offset and word length converted to bit units via 'mulWordSize'.
-- Cf. 'Data.Bit.castFromWords'.
castFromWordsM :: MVector s Word -> MVector s Bit
castFromWordsM (MU.MV_Word (P.MVector wordOff wordLen arr)) =
  BitMVec bitOff bitLen arr
  where
    bitOff = mulWordSize wordOff
    bitLen = mulWordSize wordLen
-- | Try to reinterpret a mutable vector of bits as a mutable vector of words,
-- without copying. This yields 'Just' only when both the bit offset and the
-- bit length are 'aligned'; otherwise the result is 'Nothing' and
-- 'cloneToWordsM' should be used instead.
-- Cf. 'Data.Bit.castToWords'.
castToWordsM :: MVector s Bit -> Maybe (MVector s Word)
castToWordsM (BitMVec bitOff bitLen arr) =
  if aligned bitOff && aligned bitLen
    then Just (MU.MV_Word (P.MVector (divWordSize bitOff) (divWordSize bitLen) arr))
    else Nothing
-- | Copy a mutable vector of bits into a freshly allocated mutable vector of
-- words. When the bit length is not a multiple of the word size, the padding
-- bits past the end of the input (which fill up the final word) are set to
-- zero.
-- Cf. 'Data.Bit.cloneToWords'.
cloneToWordsM
  :: PrimMonad m
  => MVector (PrimState m) Bit
  -> m (MVector (PrimState m) Word)
cloneToWordsM src = do
  let srcBits    = MU.length src
      dstWords   = nWords srcBits
      paddedBits = mulWordSize dstWords
      padBits    = paddedBits - srcBits
  dst@(BitMVec _ _ arr) <- MU.unsafeNew paddedBits
  -- Copy the payload first, then clear the padding up to the word boundary.
  MU.unsafeCopy (MU.slice 0 srcBits dst) src
  MU.set (MU.slice srcBits padBits dst) (Bit False)
  -- Expose the same byte array as a vector of whole words.
  pure (MU.MV_Word (P.MVector 0 dstWords arr))
{-# INLINE cloneToWordsM #-}
-- | Copy a mutable vector of bits into a freshly allocated mutable vector of
-- 'Word8'. When the bit length is not a multiple of 8, the padding bits past
-- the end of the input (which fill up the final byte) are set to zero.
-- Cf. 'Data.Bit.cloneToWords8'.
cloneToWords8M
  :: PrimMonad m
  => MVector (PrimState m) Bit
  -> m (MVector (PrimState m) Word8)
cloneToWords8M src = do
  let srcBits    = MU.length src
      dstBytes   = (srcBits + 7) `shiftR` 3  -- byte count, rounded up
      paddedBits = dstBytes `shiftL` 3
      padBits    = paddedBits - srcBits
  dst@(BitMVec _ _ arr) <- MU.unsafeNew paddedBits
  MU.unsafeCopy (MU.slice 0 srcBits dst) src
  MU.set (MU.slice srcBits padBits dst) (Bit False)
  -- Expose the same byte array as a vector of bytes.
  -- NOTE(review): this views word-backed storage byte by byte; the resulting
  -- bit order presumably assumes a little-endian layout - confirm for
  -- big-endian targets.
  pure (MU.MV_Word8 (P.MVector 0 dstBytes arr))
{-# INLINE cloneToWords8M #-}
-- | Zip two vectors bitwise with the given function,
-- rewriting the contents of the second (mutable) argument in place.
-- Only the first @min (length xs) (length ys)@ bits are combined.
-- Cf. 'Data.Bit.zipBits'.
--
-- The function is applied to whole machine words at once, so it is
-- expected to treat each bit independently (e.g. '(.&.)', '(.|.)', 'xor').
--
-- >>> :set -XOverloadedLists
-- >>> import Data.Bits
-- >>> Data.Vector.Unboxed.modify (zipInPlace (.&.) [1,1,0]) [0,1,1]
-- [0,1,0]
--
-- __Warning__: if the immutable vector is shorter than the mutable one,
-- the trailing bits of the mutable vector are left as they were, and it is
-- the caller's responsibility to trim the result (below, the last three bits
-- are untouched):
--
-- >>> :set -XOverloadedLists
-- >>> import Data.Bits
-- >>> Data.Vector.Unboxed.modify (zipInPlace (.&.) [1,1,0]) [0,1,1,1,1,1]
-- [0,1,0,1,1,1]
zipInPlace
  :: forall m.
     PrimMonad m
  => (forall a . Bits a => a -> a -> a)
  -> Vector Bit
  -> MVector (PrimState m) Bit
  -> m ()
zipInPlace f (BitVec off l xs) (BitMVec off' l' ys) =
  -- Only the common prefix of both vectors is processed.
  go (l `min` l') off off'
  where
    -- Combine @len@ bits, reading the source from bit offset @offXs@ and
    -- writing the destination at bit offset @offYs@. This step deals with a
    -- destination that does not start on a word boundary; the word-aligned
    -- bulk is delegated to go'.
    go :: Int -> Int -> Int -> m ()
    go len offXs offYs
      -- Destination already word-aligned: hand over to the bulk loop.
      | shft == 0 =
        go' len offXs (divWordSize offYs)
      -- Everything fits into a single (possibly straddling) word.
      | len <= wordSize = do
        y <- readWord vecYs 0
        writeWord vecYs 0 (f x y)
      -- Update the bits of destination word @base@ selected by
      -- @hiMask shft@, then continue from word @base + 1@, which is aligned.
      -- @wordSize - shft@ bits have been consumed at that point.
      | otherwise = do
        y <- readByteArray ys base
        modifyByteArray ys base (loMask shft) (f (x `unsafeShiftL` shft) y .&. hiMask shft)
        go' (len - wordSize + shft) (offXs + wordSize - shft) (base + 1)
      where
        vecXs = BitVec offXs len xs
        vecYs = BitMVec offYs len ys
        -- First word of source bits, starting at bit offXs.
        x = indexWord vecXs 0
        -- Misalignment of the destination within its word.
        shft = modWordSize offYs
        -- Index of the destination word containing bit offYs.
        base = divWordSize offYs

    -- Combine @len@ bits into a destination that starts at word index
    -- @offYsW@ (i.e. is word-aligned), reading the source from bit offset
    -- @offXs@. Whole words are handled by loopAligned / loop; the remaining
    -- @modWordSize len@ bits are handled separately at the end.
    go' :: Int -> Int -> Int -> m ()
    go' len offXs offYsW = do
      if shft == 0
        then loopAligned offYsW
        else loop offYsW (indexByteArray xs base)
      -- Leftover bits that do not fill a whole destination word.
      when (modWordSize len /= 0) $ do
        let ix = len - modWordSize len
        let x = indexWord vecXs ix
        y <- readWord vecYs ix
        writeWord vecYs ix (f x y)

      where
        vecXs = BitVec offXs len xs
        vecYs = BitMVec (mulWordSize offYsW) len ys
        -- Misalignment of the source within its word.
        shft = modWordSize offXs
        shft' = wordSize - shft
        base = divWordSize offXs
        -- Adding a destination word index i to base0 gives the matching
        -- source word index.
        base0 = base - offYsW
        base1 = base0 + 1
        -- One past the last full destination word to process.
        iMax = divWordSize len + offYsW

        -- Source and destination both aligned: combine word by word.
        loopAligned :: Int -> m ()
        loopAligned !i
          | i >= iMax = pure ()
          | otherwise = do
            let x = indexByteArray xs (base0 + i) :: Word
            y <- readByteArray ys i
            writeByteArray ys i (f x y)
            loopAligned (i + 1)

        -- Source misaligned by @shft@ bits: each source word is stitched
        -- together from the bits of @acc@ above position @shft@ and the low
        -- @shft'@ bits of the next array word @accNew@, which is then carried
        -- over to the next iteration.
        loop :: Int -> Word -> m ()
        loop !i !acc
          | i >= iMax = pure ()
          | otherwise = do
            let accNew = indexByteArray xs (base1 + i)
                x = (acc `unsafeShiftR` shft) .|. (accNew `unsafeShiftL` shft')
            y <- readByteArray ys i
            writeByteArray ys i (f x y)
            loop (i + 1) accNew
{-# SPECIALIZE zipInPlace :: (forall a. Bits a => a -> a -> a) -> Vector Bit -> MVector s Bit -> ST s () #-}
{-# INLINE zipInPlace #-}
-- | Apply a bitwise function to every bit of a mutable vector, overwriting
-- its contents. The function is probed on @'Bit' 'False'@ and
-- @'Bit' 'True'@, and the whole update is reduced to one of four vector-wide
-- operations: clear everything, leave unchanged, invert, or set everything.
-- Cf. 'Data.Bit.mapBits'.
--
-- >>> :set -XOverloadedLists
-- >>> import Data.Bits
-- >>> Data.Vector.Unboxed.modify (mapInPlace complement) [0,1,1]
-- [1,0,0]
mapInPlace
  :: PrimMonad m
  => (forall a . Bits a => a -> a)
  -> U.MVector (PrimState m) Bit
  -> m ()
mapInPlace f xs = dispatch (unBit (f (Bit False))) (unBit (f (Bit True)))
  where
    -- Arguments: the image of 0, then the image of 1.
    dispatch False False = MU.set xs (Bit False)
    dispatch False True  = pure ()
    dispatch True  False = invertInPlace xs
    dispatch True  True  = MU.set xs (Bit True)
{-# SPECIALIZE mapInPlace :: (forall a. Bits a => a -> a) -> MVector s Bit -> ST s () #-}
{-# INLINE mapInPlace #-}
-- | Flip every bit of a mutable vector in place, one word at a time starting
-- from bit 0.
--
-- >>> :set -XOverloadedLists
-- >>> Data.Vector.Unboxed.modify invertInPlace [0,1,0,1,0]
-- [1,0,1,0,1]
invertInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
invertInPlace xs = flipFrom 0
  where
    len = MU.length xs
    -- Visit bit offsets 0, wordSize, 2 * wordSize, ... while below len.
    flipFrom !pos
      | pos >= len = pure ()
      | otherwise = do
          w <- readWord xs pos
          writeWord xs pos (complement w)
          flipFrom (pos + wordSize)
{-# SPECIALIZE invertInPlace :: U.MVector s Bit -> ST s () #-}
-- | Same as 'Data.Bit.selectBits', but deposit
-- selected bits in-place. Returns the number of selected bits.
-- It is the caller's responsibility to trim the result to this number.
-- Only the first @min (length is) (length xs)@ bits are examined.
--
-- >>> :set -XOverloadedLists
-- >>> import Control.Monad.ST (runST)
-- >>> import qualified Data.Vector.Unboxed as U
-- >>> runST $ do { vec <- U.unsafeThaw [1,1,0,0,1]; n <- selectBitsInPlace [0,1,0,1,1] vec; U.take n <$> U.unsafeFreeze vec }
-- [1,0,1]
--
selectBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
selectBitsInPlace is xs = loop 0 0
  where
    !n = min (U.length is) (MU.length xs)
    -- i: read position, advancing one word at a time;
    -- ct: write position, i.e. the number of bits selected so far.
    loop !i !ct
      | i >= n = pure ct
      | otherwise = do
        x <- readWord xs i
        -- Restrict the selector word to the n - i bits still in range
        -- (via 'masked') before extracting the chosen bits of x.
        let !(nSet, x') = selectWord (masked (n - i) (indexWord is i)) x
        writeWord xs ct x'
        loop (i + wordSize) (ct + nSet)
-- Specialise to ST, like the other in-place operations in this module.
{-# SPECIALIZE selectBitsInPlace :: U.Vector Bit -> U.MVector s Bit -> ST s Int #-}
-- | Same as 'Data.Bit.excludeBits', but deposit
-- excluded bits in-place. Returns the number of excluded bits.
-- It is the caller's responsibility to trim the result to this number.
-- Only the first @min (length is) (length xs)@ bits are examined.
--
-- >>> :set -XOverloadedLists
-- >>> import Control.Monad.ST (runST)
-- >>> import qualified Data.Vector.Unboxed as U
-- >>> runST $ do { vec <- U.unsafeThaw [1,1,0,0,1]; n <- excludeBitsInPlace [0,1,0,1,1] vec; U.take n <$> U.unsafeFreeze vec }
-- [1,0]
--
excludeBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
excludeBitsInPlace is xs = loop 0 0
  where
    !n = min (U.length is) (MU.length xs)
    -- i: read position, advancing one word at a time;
    -- ct: write position, i.e. the number of bits kept so far.
    loop !i !ct
      | i >= n = pure ct
      | otherwise = do
        x <- readWord xs i
        -- Complement the selector so bits flagged 0 are kept; masking is
        -- applied after the complement so selector bits beyond n - i are
        -- cleared rather than turned on.
        let !(nSet, x') =
              selectWord (masked (n - i) (complement (indexWord is i))) x
        writeWord xs ct x'
        loop (i + wordSize) (ct + nSet)
-- Specialise to ST, like the other in-place operations in this module.
{-# SPECIALIZE excludeBitsInPlace :: U.Vector Bit -> U.MVector s Bit -> ST s Int #-}
-- | Reverse the order of bits in-place.
--
-- >>> :set -XOverloadedLists
-- >>> Data.Vector.Unboxed.modify reverseInPlace [1,1,0,1,0]
-- [0,1,0,1,1]
--
-- Consider using the @vector-rotcev@ package
-- to reverse vectors in O(1) time.
reverseInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
reverseInPlace xs
  | len == 0 = pure ()
  | otherwise = loop 0
  where
    len = MU.length xs

    -- The window still to be reversed is [i, j) with j = len - i; the bits
    -- outside it have already been moved to their final positions.
    loop !i
      -- At least two full words left: swap the first and the last word of
      -- the window, reversing each of them.
      | i' <= j' = do
        x <- readWord xs i
        y <- readWord xs j'
        writeWord xs i (reverseWord y)
        writeWord xs j' (reverseWord x)
        loop i'
      -- Between one and two words left: swap the lower and the upper w bits
      -- of the window, leaving a possible middle bit in place.
      | i' < j = do
        let w = (j - i) `shiftR` 1
            k = j - w
        x <- readWord xs i
        y <- readWord xs k
        writeWord xs i (meld w (reversePartialWord w y) x)
        writeWord xs k (meld w (reversePartialWord w x) y)
        loop i'
      -- Between one and wordSize bits left: reverse them within one word.
      | i < j = do
        let w = j - i
        x <- readWord xs i
        writeWord xs i (meld w (reversePartialWord w x) x)
      -- Empty window. This is reached e.g. right after the half-swap branch,
      -- where i has already moved past j; previously this fell through to a
      -- read-modify-write with a non-positive width.
      | otherwise = pure ()
      where
        !j = len - i
        !i' = i + wordSize
        !j' = j - wordSize
{-# SPECIALIZE reverseInPlace :: U.MVector s Bit -> ST s () #-}