
How to efficiently generate all interleavings of lists in Haskell?


What's the most efficient way to generate all lists obtained by interleaving the elements of a given list of lists in Haskell, preserving the relative order of the elements within each list?

https://stackoverflow.com/a/41929156 proposes the following code:

interleavings :: [[a]] -> [[a]]
interleavings = go . filter (not . null)
  where
    go [] = [[]]
    go xss = do
      (xssl, x : xs, xssr) <- zippers xss
      (x :) <$> interleavings ([xs | not (null xs)] ++ xssl ++ xssr)
    zippers :: [a] -> [([a], a, [a])]
    zippers = go' []
      where
        go' l (h : r) = (l, h, r) : go' (h : l) r
        go' _ [] = []

ghci> interleavings [[1,2,3],[4,5],[6]]
[[1,2,3,4,5,6],[1,2,3,4,6,5],[1,2,3,6,4,5],[1,2,4,5,3,6],[1,2,4,5,6,3],[1,2,4,3,5,6],[1,2,4,3,6,5],[1,2,4,6,3,5],[1,2,4,6,5,3],[1,2,6,4,5,3],[1,2,6,4,3,5],[1,2,6,3,4,5],[1,4,5,2,3,6],[1,4,5,2,6,3],[1,4,5,6,2,3],[1,4,2,3,5,6],[1,4,2,3,6,5],[1,4,2,5,3,6],[1,4,2,5,6,3],[1,4,2,6,5,3],[1,4,2,6,3,5],[1,4,6,2,3,5],[1,4,6,2,5,3],[1,4,6,5,2,3],[1,6,4,5,2,3],[1,6,4,2,3,5],[1,6,4,2,5,3],[1,6,2,3,4,5],[1,6,2,4,5,3],[1,6,2,4,3,5],[4,5,1,2,3,6],[4,5,1,2,6,3],[4,5,1,6,2,3],[4,5,6,1,2,3],[4,1,2,3,5,6],[4,1,2,3,6,5],[4,1,2,5,3,6],[4,1,2,5,6,3],[4,1,2,6,5,3],[4,1,2,6,3,5],[4,1,5,2,3,6],[4,1,5,2,6,3],[4,1,5,6,2,3],[4,1,6,5,2,3],[4,1,6,2,3,5],[4,1,6,2,5,3],[4,6,1,2,3,5],[4,6,1,2,5,3],[4,6,1,5,2,3],[4,6,5,1,2,3],[6,4,5,1,2,3],[6,4,1,2,3,5],[6,4,1,2,5,3],[6,4,1,5,2,3],[6,1,2,3,4,5],[6,1,2,4,5,3],[6,1,2,4,3,5],[6,1,4,5,2,3],[6,1,4,2,3,5],[6,1,4,2,5,3]]
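As a sanity check on the output size: the number of interleavings of lists with lengths n1, ..., nk is the multinomial coefficient (n1 + ... + nk)! / (n1! * ... * nk!), so the example above should give 6!/(3!*2!*1!) = 60 lists. The following sketch can be loaded in GHCi to check this and to see what the zipper helper produces. The name countInterleavings is illustrative, and zippers is just a top-level copy of the local helper from the code above.

```haskell
-- Expected number of interleavings: the multinomial coefficient
--   (n1 + ... + nk)! / (n1! * ... * nk!)
-- where n1, ..., nk are the lengths of the input lists.
countInterleavings :: [[a]] -> Integer
countInterleavings xss = factorial (sum ns) `div` product (map factorial ns)
  where
    ns = map (toInteger . length) xss
    factorial m = product [1 .. m]

-- Copy of the local 'zippers' helper from the question: each element paired
-- with the elements before it (in reverse order) and the elements after it.
zippers :: [a] -> [([a], a, [a])]
zippers = go []
  where
    go l (h : r) = (l, h, r) : go (h : l) r
    go _ [] = []
```

For example, zippers [1,2,3] evaluates to [([],1,[2,3]),([1],2,[3]),([2,1],3,[])]. The prefix is kept reversed so that each step is a single cons; this only changes the order in which interleavings come out, not which ones are produced. The same formula gives 16!/(4!^4) = 63063000 for the benchmark input used in the answer.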

This is useful, for instance, for concurrency tests that try out all interleavings of program instructions.

But is there a more efficient way to do this, given Haskell's lazy evaluation and the fact that we are using singly linked lists? And what if we don't need the entire result in memory at once, but only need to evaluate a function on each interleaving?


Solution

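  • You can make it faster (about 33% in my tests) by using arrays from the primitive package instead of lists: SmallArray for the input lists and PrimArray for the per-list index counters.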

    Setup (this uses the primitive package, which must be installed):

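    -- Main.hs
    {-# LANGUAGE BangPatterns #-} -- for the bang patterns below; already enabled by default in GHC2021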
    import Data.Primitive.SmallArray
    import Data.Primitive.PrimArray
    import Data.Primitive (Prim)
    import Control.Monad
    
    smallToPrim :: Prim b => (a -> b) -> SmallArray a -> PrimArray b
    smallToPrim f xs = runPrimArray $ do
      let n = length xs
      s <- newPrimArray n
      let
        go i 
          | i < n = do
            writePrimArray s i (f (indexSmallArray xs i))
            go (i + 1)
          | otherwise = pure ()
      go 0
      pure s
    
    -- creates a new array where the ith element is one more than in the given array
    increment :: Int -> PrimArray Int -> PrimArray Int
    increment i xs = runPrimArray $ do
      let n = sizeofPrimArray xs
      s <- newPrimArray n
      copyPrimArray s 0 xs 0 n
      x <- readPrimArray s i
      writePrimArray s i (x + 1)
      return s
    

    Main function:

    interleavings :: SmallArray (SmallArray a) -> [[a]]
    interleavings inputs = go id zeros where
    
      -- To compute all interleavings for an array [xs0, xs1, ..., xsn] we start
      -- with an array of indices that are all initialised to 0. Then we
      -- iteratively pick an index, add the next element from that input array to
      -- the current interleaving and increment the index. We repeat that until we
      -- have added all the elements from the input lists to the interleaving.  To
      -- get all possible interleavings we pick the index nondeterministically
      -- using the list monad.
    
      n = length inputs
      zeros = replicatePrimArray n 0
      end = smallToPrim length inputs
    
      -- acc is the particular interleaving we are working on right now
      go !acc !indices
        | indices == end = [acc []] 
        | otherwise = do -- list monad for nondeterminism
          -- pick one of the indices
          i <- [0 .. n - 1]
          let j = indexPrimArray indices i
          -- make sure that its index is within bounds
          guard (j < indexPrimArray end i)
          -- select that element from the input
          let !x = indexSmallArray (indexSmallArray inputs i) j
          -- add the element to the current interleaving and increment the corresponding index
          go (acc . (x :)) (increment i indices)
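The index-counter algorithm described in the comments can also be written with plain lists, which may make it easier to follow without the primitive package. This is only an illustrative sketch and not part of Main.hs: interleavingsIdx and bump are hypothetical names, and indexing lists with !! costs O(n), so it is expected to be slower than the array version and is not meant for benchmarking.

```haskell
import Control.Monad (guard)

-- List-based sketch of the index-counter algorithm: keep one counter per input
-- list, and at each step nondeterministically (list monad) pick an input whose
-- counter is still in bounds, emit that element, and increment its counter.
interleavingsIdx :: [[a]] -> [[a]]
interleavingsIdx inputs = go id (map (const 0) inputs)
  where
    n = length inputs
    end = map length inputs -- counter value at which each input is exhausted

    -- acc is a difference list holding the interleaving built so far
    go acc indices
      | indices == end = [acc []]
      | otherwise = do
          i <- [0 .. n - 1]
          let j = indices !! i
          guard (j < end !! i)
          let x = (inputs !! i) !! j
          go (acc . (x :)) (bump i indices)

    -- hypothetical helper: copy of the counters with the i-th one incremented
    bump :: Int -> [Int] -> [Int]
    bump i xs = [if k == i then v + 1 else v | (k, v) <- zip [0 ..] xs]
```

Because both versions try the inputs in index order 0 .. n - 1, this sketch should enumerate the interleavings in the same order as the array version, starting with [1,2,3,4,5,6] and ending with [6,4,5,1,2,3] for the test input [[1,2,3],[4,5],[6]].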
    

    Test and benchmark:

    test :: [[Int]]
    test = [[1,2,3],[4,5],[6]] 
    
    bench :: [[Int]]
    bench = replicate 4 [1 :: Int .. 4]
    
    main :: IO ()
    main = print $ length $ interleavings $ smallArrayFromList $ map smallArrayFromList bench
    

    Compile with ghc -O2 Main.hs

    Result:

    $ ./Main +RTS -s
    63063000
      89,190,505,480 bytes allocated in the heap
          34,217,632 bytes copied during GC
              44,328 bytes maximum residency (4 sample(s))
              29,400 bytes maximum slop
                   6 MiB total memory in use (0 MiB lost due to fragmentation)
    
                                         Tot time (elapsed)  Avg pause  Max pause
      Gen  0     21419 colls,     0 par    0.046s   0.049s     0.0000s    0.0000s
      Gen  1         4 colls,     0 par    0.000s   0.000s     0.0001s    0.0001s
    
      INIT    time    0.000s  (  0.000s elapsed)
      MUT     time    8.872s  (  8.864s elapsed)
      GC      time    0.046s  (  0.049s elapsed)
      EXIT    time    0.000s  (  0.006s elapsed)
      Total   time    8.918s  (  8.920s elapsed)
    
      %GC     time       0.0%  (0.0% elapsed)
    
      Alloc rate    10,053,049,984 bytes per MUT second
    
      Productivity  99.5% of total user, 99.4% of total elapsed
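On the last part of the question (running a function on each interleaving without keeping them all in memory): since interleavings returns a lazy list, a consumer such as length or mapM_ can process each interleaving as it is produced, which is consistent with the small maximum residency reported above. If you would rather avoid the intermediate list altogether, a callback-style traversal is one option. The sketch below is an assumption on my part rather than something benchmarked here: forEachInterleaving is a hypothetical name, and it reuses the zipper recursion from the question but calls a monadic visit function on each complete interleaving instead of returning a list.

```haskell
-- Calls 'visit' on every interleaving of the input lists, in the same
-- recursive style as the question's code, but without returning a list
-- of all results. The current interleaving is accumulated in reverse.
forEachInterleaving :: Monad m => [[a]] -> ([a] -> m ()) -> m ()
forEachInterleaving xss0 visit = go [] (filter (not . null) xss0)
  where
    -- all inputs exhausted: hand the finished interleaving to 'visit'
    go acc [] = visit (reverse acc)
    -- otherwise try taking the next element from each non-empty input in turn
    go acc xss = mapM_ step (zippers xss)
      where
        step (l, x : xs, r) = go (x : acc) ([xs | not (null xs)] ++ l ++ r)
        step (_, [], _) = pure () -- unreachable: empty inputs are filtered out

    -- same helper as in the question
    zippers :: [b] -> [([b], b, [b])]
    zippers = go' []
      where
        go' l (h : r) = (l, h, r) : go' (h : l) r
        go' _ [] = []
```

For example, forEachInterleaving [[1,2],[3]] print prints [1,2,3], [1,3,2] and [3,1,2], one per line. Using the writer-like monad ((,) w) instead of IO, as in forEachInterleaving xss (\r -> ([r], ())), collects the results back into a list, which is handy for testing.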