haskell lazy-evaluation strictness

Lazy state transformer consumes lazy list eagerly in 2D recursion


I'm using a state transformer to randomly sample a dataset at every point of a 2D recursive walk, which outputs a list of 2D grids of samples that together satisfy a condition. I'd like to pull from the results lazily, but my approach instead exhausts the whole dataset at every point before I can pull the first result.

To be concrete, consider this program:

import Control.Monad ( sequence, liftM2 )
import Data.Functor.Identity
import Control.Monad.State.Lazy ( StateT(..), State, runState )

walk :: Int -> Int -> [State Int [Int]]
walk _ 0 = [return [0]]
walk 0 _ = [return [0]]
walk x y =
  let st :: [State Int Int]
      st = [StateT (\s -> Identity (s, s + 1)), undefined]
      unst :: [State Int Int] -- degenerate state tf
      unst = [return 1, undefined]
  in map (\m_z -> do
      z <- m_z
      fmap concat $ sequence [
          liftM2 (zipWith (\x y -> x + y + z)) a b -- for 1D: map (+z) <$> a
          | a <- walk x (y - 1) -- depth
          , b <- walk (x - 1) y -- breadth -- comment out for 1D
        ]
    ) st -- vs. unst

main :: IO ()
main = print $ head $ fst $ (`runState` 0) $ head $ walk 2 2

The program walks the rectangular grid from (x, y) down to (0, 0) and sums the results, adding at each point a value z drawn from one of two lists of State actions: either the non-trivial transformers st, which read and advance their state, or the trivial transformers unst. What matters is whether the algorithm explores past the heads of st and unst.

As presented, the code throws undefined. I chalked this up to a flaw in the order in which I chain the transformations, and in particular to a problem with the state handling, since using unst instead (i.e. decoupling the result from the state transitions) does produce a result. However, I then found that a 1D recursion also preserves laziness, even with the state transformer (remove the breadth step b <- walk... and swap the liftM2 block for fmap).
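For reference, the 1D variant described above could be written as the following self-contained program. It is a sketch reconstructed from that description: walk1d is a hypothetical name, and the first st element is built with mtl's state rather than by wrapping Identity by hand. With the lazy State monad it prints the head of the first result without forcing the undefined entries.

```haskell
import Control.Monad.State.Lazy (State, runState, state)

-- 1D variant of walk: the breadth step `b <- walk (x - 1) y` is removed
-- and the liftM2 block is replaced by fmap.
walk1d :: Int -> Int -> [State Int [Int]]
walk1d _ 0 = [return [0]]
walk1d 0 _ = [return [0]]
walk1d x y =
  let st :: [State Int Int]
      st = [state (\s -> (s, s + 1)), undefined]
  in map (\m_z -> do
       z <- m_z
       fmap concat $ sequence
         [ map (+ z) <$> a
         | a <- walk1d x (y - 1) -- depth only
         ]
     ) st

main :: IO ()
main = print $ head $ fst $ (`runState` 0) $ head $ walk1d 2 2 -- prints 1
```

Here each level makes a single recursive call, so the undefined tail action always runs after everything the head of the result needs.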

If we trace (show (x, y)), we also see that it walks the whole grid before hitting undefined:

$ cabal run
Build profile: -w ghc-8.6.5 -O1
...
(2,2)
(2,1)
(1,2)
(1,1)
(1,1)
sandbox: Prelude.undefined

I suspect that my use of sequence is at fault here, but since both the choice of monad and the dimensionality of the walk affect whether it succeeds, I can't say that sequencing the transformations is, by itself, the source of the strictness.

What's causing the difference in strictness between 1D and 2D recursion here, and how can I achieve the laziness I want?


Solution

  • Consider the following simplified example:

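    import Control.Monad.State.Lazy
    
    -- A defined action followed by an undefined one
    st :: [State Int Int]
    st = [state (\s -> (s, s + 1)), undefined]
    
    -- Runs the actions once; the head of the result needs only the first action
    action1d :: State Int [Int]
    action1d = do
      a <- sequence st
      return $ map (2*) a
    
    -- The second `sequence st` starts from the state left behind by the first
    action2d :: State Int [Int]
    action2d = do
      a <- sequence st
      b <- sequence st
      return $ zipWith (+) a b
    
    main :: IO ()
    main = do
      print $ head $ evalState action1d 0  -- prints 0
      print $ head $ evalState action2d 0  -- throws Prelude.undefined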
    

    Here, in both the 1D and 2D calculations, the head of the result depends explicitly only on the heads of the inputs (just head a for the 1D action, and both head a and head b for the 2D action). However, in the 2D calculation, there's an implicit dependency of b (even just its head) on the current state, and that state depends on running every action in the first sequence st, including the undefined one, not just on the head of a.
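    This can be checked directly: with the lazy State monad, the results of sequence st are produced lazily, but the final state is only defined once every action has run. The sketch below forces each value to weak head normal form and reports whether it hit undefined; isDefined is a hypothetical helper built on try and evaluate from Control.Exception.

```haskell
import Control.Exception (ErrorCall, evaluate, try)
import Control.Monad.State.Lazy (State, evalState, execState, state)

st :: [State Int Int]
st = [state (\s -> (s, s + 1)), undefined]

-- Hypothetical helper: force a value to weak head normal form and report
-- whether that succeeded or hit the ErrorCall raised by `undefined`.
isDefined :: a -> IO Bool
isDefined x = do
  r <- try (evaluate x)
  return $ case r of
    Left e  -> const False (e :: ErrorCall)
    Right _ -> True

main :: IO ()
main = do
  -- The head of the results needs only the first action.
  isDefined (head (evalState (sequence st) 0)) >>= print                 -- True
  -- The final state needs every action, including `undefined`.
  isDefined (execState (sequence st) 0) >>= print                        -- False
  -- A second `sequence st` starts from that state, so even its head fails.
  isDefined (head (evalState (sequence st >> sequence st) 0)) >>= print  -- False
```

    The third line is action2d's failure in miniature: b's head is simply the state that the first sequence st left behind.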

    You have a similar dependency in your example, though it's obscured by the use of lists of state actions.

    Let's say we wanted to run the action walk22_head = head $ walk 2 2 manually and inspect the first integer in the resulting list:

    main = print $ head $ evalState walk22_head 0
    

    Writing the elements of the state action list st explicitly:

    st1, st2 :: State Int Int
    st1 = state (\s -> (s, s+1))
    st2 = undefined
    

    we can write walk22_head as:

    walk22_head = do
      z <- st1
      a <- walk21_head
      b <- walk12_head
      return $ zipWith (\x y -> x + y + z) a b
    

    Note that this depends only on the defined state action st1 and the heads of walk 2 1 and walk 1 2. Those heads, in turn, can be written:

    walk21_head = do
      z <- st1
      a <- return [0] -- walk20_head
      b <- walk11_head
      return $ zipWith (\x y -> x + y + z) a b
    
    walk12_head = do
      z <- st1
      a <- walk11_head
      b <- return [0] -- walk02_head
      return $ zipWith (\x y -> x + y + z) a b
    

    Again, these depend only on the defined state action st1 and the head of walk 1 1.

    Now, let's try to write down a definition of walk11_head:

    walk11_head = do
      z <- st1
      a <- return [0]
      b <- return [0]
      return $ zipWith (\x y -> x + y + z) a b
    

    This depends only on the defined state action st1, so with these definitions in place, if we run main, we get a defined answer:

    > main
    10
    

    But these definitions aren't accurate! In each of walk 1 2 and walk 2 1, the head action runs a sequence of actions: it starts with the action that invokes walk11_head, but continues with actions based on walk11_tail. So, more accurate definitions would be:

    walk21_head = do
      z <- st1
      a <- return [0] -- walk20_head
      b <- walk11_head
      _ <- walk11_tail  -- side effect of the sequence
      return $ zipWith (\x y -> x + y + z) a b
    
    walk12_head = do
      z <- st1
      a <- walk11_head
      b <- return [0] -- walk02_head
      _ <- walk11_tail  -- side effect of the sequence
      return $ zipWith (\x y -> x + y + z) a b
    

    with:

    walk11_tail = do
      z <- undefined
      a <- return [0]
      b <- return [0]
      return [zipWith (\x y -> x + y + z) a b]
    

    With these definitions in place, there's no problem running walk12_head and walk21_head in isolation:

    > head $ evalState walk12_head 0
    1
    > head $ evalState walk21_head 0
    1
    

    The state side effects here are not needed to calculate the answer, so they are never forced. However, it's not possible to run the two actions one after the other:

    > head $ evalState (walk12_head >> walk21_head) 0
    *** Exception: Prelude.undefined
    CallStack (from HasCallStack):
      error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
      undefined, called at Lazy2D_2.hs:41:8 in main:Main
    

    Therefore, trying to run main fails for the same reason:

    > main
    *** Exception: Prelude.undefined
    CallStack (from HasCallStack):
      error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
      undefined, called at Lazy2D_2.hs:41:8 in main:Main
    

    because walk22_head runs walk21_head first, and even the very beginning of walk12_head's calculation then depends on the state side effect walk11_tail initiated by walk21_head (the same problem as above, with the roles of the two heads swapped).
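    The accurate mockups can also be assembled into one runnable file that probes both the results and the final states. This is a sketch: step (the do block shared by every mockup) and isDefined are hypothetical helpers, and `<* walk11_tail` stands in for the `_ <- walk11_tail` lines.

```haskell
import Control.Exception (ErrorCall, evaluate, try)
import Control.Monad.State.Lazy (State, evalState, execState, state)

st1, st2 :: State Int Int
st1 = state (\s -> (s, s + 1))
st2 = undefined

-- Hypothetical helper: the shared body of every walk step
-- (draw z, run the two neighbours in order, combine).
step :: State Int Int -> State Int [Int] -> State Int [Int] -> State Int [Int]
step m_z ma mb = do
  z <- m_z
  a <- ma
  b <- mb
  return $ zipWith (\x y -> x + y + z) a b

walk11_head, walk11_tail, walk12_head, walk21_head, walk22_head :: State Int [Int]
walk11_head = step st1 (return [0]) (return [0])
walk11_tail = step st2 (return [0]) (return [0])

-- `<* walk11_tail` runs the tail only for its state effect and
-- discards its result, like `_ <- walk11_tail` in the mockups.
walk12_head = step st1 walk11_head (return [0]) <* walk11_tail
walk21_head = step st1 (return [0]) walk11_head <* walk11_tail

-- Only the first element of walk 2 2's sequence; walk21_head runs first.
walk22_head = step st1 walk21_head walk12_head

-- Hypothetical helper: force to WHNF, report whether `undefined` was hit.
isDefined :: a -> IO Bool
isDefined x = do
  r <- try (evaluate x)
  return $ case r of
    Left e  -> const False (e :: ErrorCall)
    Right _ -> True

main :: IO ()
main = mapM_ (>>= print)
  [ isDefined (head (evalState walk12_head 0))                   -- True
  , isDefined (head (evalState walk21_head 0))                   -- True
  , isDefined (execState walk21_head 0)                          -- False
  , isDefined (head (evalState (walk21_head >> walk12_head) 0))  -- False
  , isDefined (head (evalState walk22_head 0))                   -- False
  ]
```

    The third probe is the crux: walk21_head produces its result happily, but the state it hands on is undefined, and walk12_head needs that state immediately.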

    Your original walk definition behaves the same way as these mockups:

    > head $ evalState (head $ walk 1 2) 0
    1
    > head $ evalState (head $ walk 2 1) 0
    1
    > head $ evalState (head (walk 1 2) >> head (walk 2 1)) 0
    *** Exception: Prelude.undefined
    CallStack (from HasCallStack):
      error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
      undefined, called at Lazy2D_0.hs:15:49 in main:Main
    > head $ evalState (head (walk 2 2)) 0
    *** Exception: Prelude.undefined
    CallStack (from HasCallStack):
      error, called at libraries/base/GHC/Err.hs:78:14 in base:GHC.Err
      undefined, called at Lazy2D_0.hs:15:49 in main:Main
    

    It's hard to say how to fix this. Your toy example was excellent for illustrating the problem, but it's not clear how the state is used in your "real" problem, and whether head $ walk 1 2 really has a state dependency on the sequence of walk 1 1 actions induced by head $ walk 2 1.
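    If the real problem turns out not to need that dependency, one possible direction (a hypothetical sketch under that assumption, not an established fix) is to stop threading the state from one branch into the other. The combinator forkState below is a made-up name: it runs both branches from the same starting state and keeps the state left by the second one. Using it in place of liftM2 lets head $ walk 2 2 evaluate.

```haskell
import Control.Monad.State.Lazy (State, runState, state)

-- Hypothetical combinator: run both actions from the *same* starting state,
-- so `mb` no longer waits on the state `ma` leaves behind. The resulting
-- state is the one left by `mb`; `ma`'s final state is discarded.
forkState :: State s a -> State s b -> State s (a, b)
forkState ma mb = state $ \s ->
  let (a, _)  = runState ma s
      (b, s') = runState mb s
  in ((a, b), s')

walk :: Int -> Int -> [State Int [Int]]
walk _ 0 = [return [0]]
walk 0 _ = [return [0]]
walk x y =
  let st :: [State Int Int]
      st = [state (\s -> (s, s + 1)), undefined]
  in map (\m_z -> do
       z <- m_z
       fmap concat $ sequence
         [ uncurry (zipWith (\p q -> p + q + z)) <$> forkState a b
         | a <- walk x (y - 1) -- depth
         , b <- walk (x - 1) y -- breadth
         ]
     ) st

main :: IO ()
main = print $ head $ fst $ (`runState` 0) $ head $ walk 2 2 -- prints 6
```

    The trade-off is that both branches now see the same state values (here, both copies of walk 1 1 read the counter value 2). With a real random generator you would typically want to split the generator per branch instead, which this sketch does not do.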