What's the most efficient way to generate all lists obtained by interleaving the elements of a given list of lists in Haskell?
https://stackoverflow.com/a/41929156 proposes the following code:
```haskell
interleavings :: [[a]] -> [[a]]
interleavings = go . filter (not . null)
  where
    go [] = [[]]
    go xss = do
      (xssl, x : xs, xssr) <- zippers xss
      (x :) <$> interleavings ([xs | not (null xs)] ++ xssl ++ xssr)

zippers :: [a] -> [([a], a, [a])]
zippers = go' []
  where
    go' l (h : r) = (l, h, r) : go' (h : l) r
    go' _ [] = []
```
```
ghci> interleavings [[1,2,3],[4,5],[6]]
[[1,2,3,4,5,6],[1,2,3,4,6,5],[1,2,3,6,4,5],[1,2,4,5,3,6],[1,2,4,5,6,3],[1,2,4,3,5,6],[1,2,4,3,6,5],[1,2,4,6,3,5],[1,2,4,6,5,3],[1,2,6,4,5,3],[1,2,6,4,3,5],[1,2,6,3,4,5],[1,4,5,2,3,6],[1,4,5,2,6,3],[1,4,5,6,2,3],[1,4,2,3,5,6],[1,4,2,3,6,5],[1,4,2,5,3,6],[1,4,2,5,6,3],[1,4,2,6,5,3],[1,4,2,6,3,5],[1,4,6,2,3,5],[1,4,6,2,5,3],[1,4,6,5,2,3],[1,6,4,5,2,3],[1,6,4,2,3,5],[1,6,4,2,5,3],[1,6,2,3,4,5],[1,6,2,4,5,3],[1,6,2,4,3,5],[4,5,1,2,3,6],[4,5,1,2,6,3],[4,5,1,6,2,3],[4,5,6,1,2,3],[4,1,2,3,5,6],[4,1,2,3,6,5],[4,1,2,5,3,6],[4,1,2,5,6,3],[4,1,2,6,5,3],[4,1,2,6,3,5],[4,1,5,2,3,6],[4,1,5,2,6,3],[4,1,5,6,2,3],[4,1,6,5,2,3],[4,1,6,2,3,5],[4,1,6,2,5,3],[4,6,1,2,3,5],[4,6,1,2,5,3],[4,6,1,5,2,3],[4,6,5,1,2,3],[6,4,5,1,2,3],[6,4,1,2,3,5],[6,4,1,2,5,3],[6,4,1,5,2,3],[6,1,2,3,4,5],[6,1,2,4,5,3],[6,1,2,4,3,5],[6,1,4,5,2,3],[6,1,4,2,3,5],[6,1,4,2,5,3]]
```
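The results above are not in lexicographic order. Part of the reason is visible in `zippers`: the inputs to the left of the chosen one (`xssl`) are accumulated in reverse, and the chosen list's tail is moved to the front, so the inputs can be reordered at every step. Every interleaving still appears exactly once. The block below simply copies `zippers` from the code above and prints what it returns for a three-element list:

```haskell
-- Copy of the helper from the code above, so this block stands alone.
-- Each triple is (elements to the left, in reverse order; focused element; elements to the right).
zippers :: [a] -> [([a], a, [a])]
zippers = go' []
  where
    go' l (h : r) = (l, h, r) : go' (h : l) r
    go' _ [] = []

main :: IO ()
main = mapM_ print (zippers "abc")
```

Running it prints `("",'a',"bc")`, `("a",'b',"c")` and `("ba",'c',"")`; note that the left context of the last triple is `"ba"`, not `"ab"`.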
This is useful, for instance, for concurrency tests that try out all interleavings of program instructions.
But is there a more efficient way to do this, given Haskell's lazy evaluation and the fact that we are working with singly linked lists? And what if we didn't need the entire result in memory at once, but only needed to evaluate a function on each interleaving?
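On the last point: as long as the result of `interleavings` is consumed once and not shared, laziness should already let each interleaving be garbage-collected after use when it is consumed by something like `length` or a strict fold. If you would rather make the "run a function on each interleaving" shape explicit, one option is a callback-style traversal. The sketch below, `forInterleavings_`, is a hypothetical helper (not from the linked answer) that walks the same tree of choices and calls an action on each complete interleaving instead of returning a list of them:

```haskell
import Data.IORef (modifyIORef', newIORef, readIORef)

-- Hypothetical helper (not from the linked answer): runs the action k on every
-- interleaving of the input lists, one at a time, without first building the
-- list of all interleavings.
forInterleavings_ :: Monad m => [[a]] -> ([a] -> m ()) -> m ()
forInterleavings_ xss0 k = go id (filter (not . null) xss0)
  where
    -- acc is a difference list holding the elements chosen so far;
    -- every list in the second argument is non-empty
    go acc [] = k (acc [])
    go acc xss = mapM_ step (picks xss)
      where
        step (x, rest) = go (acc . (x :)) rest

    -- All ways to take the head of one of the lists. The chosen list's tail
    -- stays in its position only if it is still non-empty; the other lists
    -- keep their order.
    picks [] = []
    picks ((y : ys) : yss) =
      (y, [ys | not (null ys)] ++ yss)
        : [(z, (y : ys) : rest) | (z, rest) <- picks yss]
    picks ([] : yss) = picks yss

main :: IO ()
main = do
  -- Count the interleavings of the question's example with a strict counter.
  counter <- newIORef (0 :: Int)
  forInterleavings_ [[1, 2, 3], [4, 5], [6 :: Int]] (\_ -> modifyIORef' counter (+ 1))
  readIORef counter >>= print
  -- Print each interleaving of a smaller input as soon as it is produced.
  forInterleavings_ [[1, 2], [3 :: Int]] print
```

Here `main` prints `60`, then `[1,2,3]`, `[1,3,2]` and `[3,1,2]`. Whether this is actually faster than consuming the lazy list would need measuring; its main benefit is that it does not rely on the consumer avoiding retention of the result list.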
You can make it somewhat faster (about 33% in my tests) by using arrays from the `primitive` package instead of lists.
Setup:
```haskell
-- Main.hs
{-# LANGUAGE BangPatterns #-} -- for the bang patterns in interleavings below

import Data.Primitive.SmallArray
import Data.Primitive.PrimArray
import Data.Primitive (Prim)
import Control.Monad

smallToPrim :: Prim b => (a -> b) -> SmallArray a -> PrimArray b
smallToPrim f xs = runPrimArray $ do
  let n = length xs
  s <- newPrimArray n
  let
    go i
      | i < n = do
          writePrimArray s i (f (indexSmallArray xs i))
          go (i + 1)
      | otherwise = pure ()
  go 0
  pure s

-- Returns a copy of the array in which the i-th element is one larger;
-- the input array is left unchanged.
increment :: Int -> PrimArray Int -> PrimArray Int
increment i xs = runPrimArray $ do
  let n = sizeofPrimArray xs
  s <- newPrimArray n
  copyPrimArray s 0 xs 0 n
  x <- readPrimArray s i
  writePrimArray s i (x + 1)
  return s
```
Main function:
```haskell
interleavings :: SmallArray (SmallArray a) -> [[a]]
interleavings inputs = go id zeros where
  -- To compute all interleavings for an array [xs0, xs1, ..., xsn] we start
  -- with an array of indices that are all initialised to 0. Then we
  -- iteratively pick an index, add the next element from that input array to
  -- the current interleaving and increment the index. We repeat that until we
  -- have added all the elements from the input lists to the interleaving. To
  -- get all possible interleavings we pick the index nondeterministically
  -- using the list monad.

  n = length inputs
  zeros = replicatePrimArray n 0
  end = smallToPrim length inputs

  -- acc is the particular interleaving we are working on right now
  go !acc !indices
    | indices == end = [acc []]
    | otherwise = do -- list monad for nondeterminism
        -- pick one of the indices
        i <- [0 .. n - 1]
        let j = indexPrimArray indices i
        -- make sure that its index is within bounds
        guard (j < indexPrimArray end i)
        -- select that element from the input
        let !x = indexSmallArray (indexSmallArray inputs i) j
        -- add the element to the current interleaving and increment the corresponding index
        go (acc . (x :)) (increment i indices)
```
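If you don't want a dependency on `primitive`, the same index-counter algorithm can be sketched with `Data.Array` and `Data.Array.Unboxed` from the `array` package, which ships with GHC. `interleavingsArr` is a hypothetical name, and this version has not been benchmarked. Like `increment`, the `(//)` update builds a fresh copy of the index array, so every branch of the list monad keeps its own unchanged `indices`:

```haskell
import Control.Monad (guard)
import Data.Array (listArray, (!))
import qualified Data.Array.Unboxed as U

-- Sketch of the same algorithm with the array package (an assumption:
-- boxed Array for the inputs, UArray for the index counters).
interleavingsArr :: [[a]] -> [[a]]
interleavingsArr xss = go id zeros
  where
    n = length xss
    -- inputs ! i ! j is the j-th element of the i-th input list
    inputs = listArray (0, n - 1) [listArray (0, length xs - 1) xs | xs <- xss]
    -- lengths U.! i is the number of elements in the i-th input list
    lengths :: U.UArray Int Int
    lengths = U.listArray (0, n - 1) (map length xss)
    zeros :: U.UArray Int Int
    zeros = U.listArray (0, n - 1) (replicate n 0)

    -- acc is a difference list holding the interleaving built so far
    go acc indices
      | indices == lengths = [acc []]
      | otherwise = do -- list monad picks the next input nondeterministically
          i <- [0 .. n - 1]
          let j = indices U.! i
          guard (j < lengths U.! i)
          let x = inputs ! i ! j
          -- (//) returns an updated copy; the old indices stay intact
          go (acc . (x :)) (indices U.// [(i, j + 1)])

main :: IO ()
main = do
  print (interleavingsArr [[1, 2], [3 :: Int]])
  print (length (interleavingsArr [[1, 2, 3], [4, 5], [6 :: Int]]))
```

This prints `[[1,2,3],[1,3,2],[3,1,2]]` and then `60`. Unlike the `zippers` version, it visits the inputs in a fixed order at every step, so the results come out in order of which input is chosen first.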
Test and benchmark:
```haskell
test :: [[Int]]
test = [[1,2,3],[4,5],[6]]

bench :: [[Int]]
bench = replicate 4 [1 :: Int .. 4]

main :: IO ()
main = print $ length $ interleavings $ smallArrayFromList $ map smallArrayFromList bench
```
Compile with `ghc -O2 -rtsopts Main.hs` (the `-rtsopts` flag lets the program accept the `+RTS -s` option used below).
Result:
```
$ ./Main +RTS -s
63063000
  89,190,505,480 bytes allocated in the heap
      34,217,632 bytes copied during GC
          44,328 bytes maximum residency (4 sample(s))
          29,400 bytes maximum slop
               6 MiB total memory in use (0 MiB lost due to fragmentation)

                                     Tot time (elapsed)  Avg pause  Max pause
  Gen  0     21419 colls,     0 par    0.046s   0.049s     0.0000s    0.0000s
  Gen  1         4 colls,     0 par    0.000s   0.000s     0.0001s    0.0001s

  INIT    time    0.000s  (  0.000s elapsed)
  MUT     time    8.872s  (  8.864s elapsed)
  GC      time    0.046s  (  0.049s elapsed)
  EXIT    time    0.000s  (  0.006s elapsed)
  Total   time    8.918s  (  8.920s elapsed)

  %GC     time       0.0%  (0.0% elapsed)

  Alloc rate    10,053,049,984 bytes per MUT second

  Productivity  99.5% of total user, 99.4% of total elapsed
```
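The printed 63063000 is exactly the number of interleavings of four 4-element lists, 16!/(4!)^4. In general, lists of lengths k1, ..., km have (k1 + ... + km)!/(k1! * ... * km!) interleavings (60 for the question's `[[1,2,3],[4,5],[6]]`). Any approach that enumerates all of them therefore does work at least proportional to this count, and the available speedups come from lowering the cost per interleaving. A small helper for computing the expected count, useful as a sanity check for either implementation (`countInterleavings` is a hypothetical name):

```haskell
-- Number of interleavings of lists with lengths k1, ..., km:
-- (k1 + ... + km)! / (k1! * ... * km!), the multinomial coefficient.
countInterleavings :: [Int] -> Integer
countInterleavings ks = factorial (sum ks) `div` product (map factorial ks)
  where
    factorial :: Int -> Integer
    factorial k = product [1 .. toInteger k]

main :: IO ()
main = do
  print (countInterleavings [3, 2, 1])        -- prints 60 (the question's example)
  print (countInterleavings (replicate 4 4))  -- prints 63063000 (the benchmark input)
```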