Consider this function that generates a list for an arbitrary Monad:
generateListM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListM sz f = go 0
  where
    go i
      | i < sz = do
          x <- f i
          xs <- go (i + 1)
          return (x : xs)
      | otherwise = pure []
The implementation may not be perfect, but it is presented here solely to demonstrate the desired effect, which is pretty straightforward. For example, if the monad is a list, we'll get a list of lists:
λ> generateListM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
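For reference, the intended semantics of generateListM can be pinned down with mapM over the index range. The sketch below assumes nothing beyond base; generateListViaMapM is a hypothetical name used only here, and the Maybe calls are added examples showing that a single failing element makes the whole result fail.

```haskell
import Data.Int (Int64)

-- Hypothetical reference definition: run f on indices 0 .. sz-1, left to right,
-- and collect the results, with the same effect order as the hand-written loop.
generateListViaMapM :: Monad m => Int -> (Int -> m a) -> m [a]
generateListViaMapM sz f = mapM f [0 .. sz - 1]

main :: IO ()
main = do
  -- List monad: one result list per combination of choices.
  print (generateListViaMapM 3 (\i -> [0 :: Int64 .. fromIntegral i]))
  -- Maybe monad: every element succeeds, so we get Just a list.
  print (generateListViaMapM 3 (\i -> if i < 5 then Just (i * i) else Nothing))
  -- Maybe monad: the element at index 2 fails, so the whole result is Nothing.
  print (generateListViaMapM 3 (\i -> if i < 2 then Just i else Nothing))
```

The first line printed matches the generateListM output shown above.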
What I'd like to do is achieve the same effect, but for a ByteArray instead of a list. As it turns out, this is much trickier than I thought when I first stumbled upon this problem. The end goal is to use that generator to implement mapM in massiv, but that is beside the point.
The approach that requires the least effort is to use the generateM function from the vector package and do a bit of manual conversion. But, as it turns out, there is a way to get at least a 2x performance gain with a neat little trick: handling the state token manually and interleaving it with the monad:
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UnboxedTuples #-}
import Data.Primitive.ByteArray
import Data.Primitive.Types
import qualified Data.Vector.Primitive as VP
import GHC.Int
import GHC.Magic
import GHC.Prim

-- | Can't `return` unlifted types, so we need a wrapper for the state and MutableByteArray
data MutableByteArrayState s = MutableByteArrayState !(State# s) !(MutableByteArray# s)

generatePrimM :: forall m a . (Prim a, Monad m) => Int -> (Int -> m a) -> m (VP.Vector a)
generatePrimM (I# sz#) f =
  runRW# $ \s0# -> do
    let go i# = do
          case i# <# sz# of
            0# ->
              case newByteArray# (sz# *# sizeOf# (undefined :: a)) (noDuplicate# s0#) of
                (# s1#, mba# #) -> return (MutableByteArrayState s1# mba#)
            _ -> do
              res <- f (I# i#)
              MutableByteArrayState si# mba# <- go (i# +# 1#)
              return (MutableByteArrayState (writeByteArray# mba# i# res si#) mba#)
    MutableByteArrayState s# mba# <- go 0#
    case unsafeFreezeByteArray# mba# s# of
      (# _, ba# #) -> return (VP.Vector 0 (I# sz#) (ByteArray ba#))
We can use it in the same fashion as before, except now we'll get a primitive Vector, which is backed by a ByteArray, which is what I really need:
λ> generatePrimM 3 (\i -> [0 :: Int64 .. fromIntegral i])
[[0,0,0],[0,0,1],[0,0,2],[0,1,0],[0,1,1],[0,1,2]]
This seems to work great and performs well with GHC 8.0 and 8.2, although there is a performance regression in 8.4 and 8.6; that issue is orthogonal.
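For comparison, here is a minimal sketch of the least-effort baseline mentioned earlier, vector's generateM, applied to the same example. It assumes Data.Vector.Primitive.generateM with its (Monad m, Prim a) constraints; VP.toList is used only to print each vector as a plain list, and no performance numbers are implied by this sketch.

```haskell
import Data.Int (Int64)
import qualified Data.Vector.Primitive as VP

main :: IO ()
main = do
  -- Baseline: let vector's generateM drive the monad; each result is a
  -- primitive Vector (backed by a ByteArray), converted to a list for printing.
  print (map VP.toList (VP.generateM 3 (\i -> [0 :: Int64 .. fromIntegral i])))
  -- With Maybe, a failing element makes the whole result Nothing.
  print (fmap VP.toList (VP.generateM 3 (\i -> if i < 2 then Just i else Nothing)))
```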
Finally, I get to the actual question: is this approach really safe? Is there some edge case I am not aware of that could bite me later? Any other suggestions or opinions regarding the above function are welcome as well.
PS. m doesn't have to be restricted to a Monad; an Applicative would work just fine, but the example is a bit clearer when it is presented with do syntax.
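To illustrate the PS, here is a sketch of the list generator with only an Applicative constraint, where liftA2 (:) takes the place of the do block. generateListA is a hypothetical name; ZipList is used in the example because it is an Applicative that has no Monad instance.

```haskell
import Control.Applicative (ZipList (..), liftA2)

-- Hypothetical Applicative-only variant of generateListM: no Monad required.
generateListA :: Applicative f => Int -> (Int -> f a) -> f [a]
generateListA sz f = go 0
  where
    go i
      | i < sz = liftA2 (:) (f i) (go (i + 1))
      | otherwise = pure []

main :: IO ()
main = do
  -- Same behaviour as the monadic version for lists ...
  print (generateListA 2 (\i -> [i, i + 10]))
  -- ... and it also works for ZipList, which is not a Monad.
  print (getZipList (generateListA 2 (\i -> ZipList [i, i * 10])))
```

The first call prints [[0,1],[0,11],[10,1],[10,11]] and the second prints [[0,1],[0,10]].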
TL;DR: From what I have gathered so far, it does seem to be a safe way to generate a primitive Vector in the way I originally proposed. Moreover, the use of noDuplicate# is not really necessary, since all of the operations are idempotent and the order of operations will not affect the resulting array(s).
Disclosure: It's been over a year since I first thought about this problem, and it was only last month that I got back to it. I mention this because, looking at the primitive package now, I noticed a module that is new to me: Data.Primitive.PrimArray. As @chi mentioned in the comments, there isn't really a need to drop down to low-level primitives to get a solution, since one might already exist. That module contains the function generatePrimArrayA, which is exactly what I was looking for (a slightly simplified copy of the source code):
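{-# LANGUAGE BangPatterns, MagicHash, RankNTypes, ScopedTypeVariables #-}
import Control.Applicative (liftA2)
import Control.Monad.ST (ST, runST)
import Data.Primitive.PrimArray (MutablePrimArray (..), PrimArray (..), newPrimArray, unsafeFreezePrimArray, writePrimArray)
import Data.Primitive.Types (Prim)
import GHC.Exts (MutableByteArray#)

newtype STA a = STA {_runSTA :: forall s. MutableByteArray# s -> ST s (PrimArray a)}

-- Unwraps the underlying MutableByteArray# (used by runSTA, defined here so the snippet compiles)
unMutablePrimArray :: MutablePrimArray s a -> MutableByteArray# s
unMutablePrimArray (MutablePrimArray m) = m

runSTA :: forall a. Prim a => Int -> STA a -> PrimArray a
runSTA !sz =
  \(STA m) -> runST $ newPrimArray sz >>= \(ar :: MutablePrimArray s a) -> m (unMutablePrimArray ar)

generatePrimArrayA :: (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA len f =
  let go !i
        | i == len = pure $ STA $ \mary -> unsafeFreezePrimArray (MutablePrimArray mary)
        | otherwise =
            liftA2
              (\b (STA m) -> STA $ \mary -> writePrimArray (MutablePrimArray mary) i b >> m mary)
              (f i)
              (go (i + 1))
   in runSTA len <$> go 0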
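Using the library function directly on the original example is a one-liner. This sketch assumes a primitive version that exports generatePrimArrayA and primArrayToList from Data.Primitive.PrimArray; primArrayToList is only used to print the arrays.

```haskell
import Data.Int (Int64)
import Data.Primitive.PrimArray (generatePrimArrayA, primArrayToList)

main :: IO ()
main = do
  -- List applicative: one PrimArray per combination of choices.
  print (map primArrayToList (generatePrimArrayA 3 (\i -> [0 :: Int64 .. fromIntegral i])))
  -- Maybe applicative: a single Nothing aborts the whole generation.
  print (fmap primArrayToList (generatePrimArrayA 3 (\i -> if i < 2 then Just i else Nothing)))
```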
Just as a fun exercise, if we go through basic simplification using the usual reduction rules, we get something very similar to what I had in the first place:
generatePrimArrayA :: forall f a. (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generatePrimArrayA !(I# n#) f =
  let go i# = case i# <# n# of
        0# -> pure $ \mary s# ->
          case unsafeFreezeByteArray# mary s# of
            (# s'#, arr'# #) -> (# s'#, PrimArray arr'# #)
        _ -> liftA2
          (\b m ->
             \mary s# ->
               case writeByteArray# mary i# b s# of
                 s'# -> m mary s'#)
          (f (I# i#))
          (go (i# +# 1#))
   in (\m -> runRW# $ \s0# ->
         case newByteArray# (n# *# sizeOf# (undefined :: a)) s0# of
           (# s'#, arr# #) -> case m arr# s'# of
             (# _, a #) -> a)
        <$> go 0#
Here is my version adjusted for an Applicative instead of a Monad:
generatePrimM :: forall m a . (Prim a, Applicative m) => Int -> (Int -> m a) -> m (PrimArray a)
generatePrimM (I# sz#) f =
  let go i# = case i# <# sz# of
        0# -> runRW# $ \s0# ->
          case newByteArray# (sz# *# sizeOf# (undefined :: a)) s0# of
            (# s1#, mba# #) -> pure (MutableByteArrayState s1# mba#)
        _ -> liftA2
          (\b (MutableByteArrayState si# mba#) ->
             MutableByteArrayState (writeByteArray# mba# i# b si#) mba#)
          (f (I# i#))
          (go (i# +# 1#))
   in (\(MutableByteArrayState s# mba#) ->
         case unsafeFreezeByteArray# mba# s# of
           (# _, ba# #) -> PrimArray ba#) <$>
        (go 0#)
Functionally and performance-wise they are very close to each other, and in the end they both produce exactly the same answer. The difference is in what the inner loop go produces. In my version (the latter), go returns an applicative containing a closure that constructs the MutableByteArray#s, which are frozen afterwards. In generatePrimArrayA (the former), go returns an applicative containing an action that creates frozen ByteArray#s once a MutableByteArray# is supplied to it.
In both cases, what makes the approach safe is that every element of each array produced within the loop is written exactly once, and each MutableByteArray# that is created gets frozen before the generating function returns it, but only after all writes to it have finished.
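The invariant above (each index written once, each array frozen only after its last write) can also be illustrated without primops. This is a minimal sketch written against the public Data.Primitive.PrimArray API; Fill and generateFill are hypothetical names, not part of primitive. Like generatePrimArrayA, the loop builds a chain of write actions inside the applicative, but the allocation happens inside runST once per result the applicative produces, so no two results can share a buffer. Its performance has not been compared with the versions above.

```haskell
{-# LANGUAGE RankNTypes #-}
import Control.Applicative (liftA2)
import Control.Monad.ST (ST, runST)
import Data.Int (Int64)
import Data.Primitive.PrimArray
import Data.Primitive.Types (Prim)

-- Hypothetical: a delayed action that writes elements into an array it is
-- handed, without allocating or freezing anything itself.
newtype Fill a = Fill (forall s. MutablePrimArray s a -> ST s ())

-- Hypothetical lifted counterpart of generatePrimArrayA.
generateFill :: (Applicative f, Prim a) => Int -> (Int -> f a) -> f (PrimArray a)
generateFill n f = finish <$> go 0
  where
    -- Each index i is written by exactly one link of the chain.
    go i
      | i >= n = pure (Fill (\_ -> pure ()))
      | otherwise =
          liftA2
            (\x (Fill k) -> Fill (\marr -> writePrimArray marr i x >> k marr))
            (f i)
            (go (i + 1))
    -- A fresh array per result; frozen only after the whole chain of writes ran.
    finish (Fill k) =
      runST (newPrimArray n >>= \marr -> k marr >> unsafeFreezePrimArray marr)

main :: IO ()
main = print (map primArrayToList (generateFill 3 (\i -> [0 :: Int64 .. fromIntegral i])))
```

Keeping allocation and freezing together in finish, and leaving only writes to the applicative chain, ties each array's whole lifetime to a single runST call.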