Tags: haskell, recursion, monads, monad-transformers, state-monad

Refactor impure recursion with the State monad?


I've been dissecting this one-liner solution to Advent of Code 2022, Day 14, and came across an elegant, impure recursive function:

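# h and m are globals in the full solution: h is the depth limit, m the set of occupied (x, y) cells
def s(x,y):
    if y > h: return True         # below the depth limit: this grain falls out
    if (x, y) in m: return False  # cell already occupied
    # try down, down-left, down-right and stop at the first call that returns True;
    # otherwise the grain rests here (m.add returns None, which is falsy)
    return next((r for d in (0,-1,1) if (r:=s(x+d,y+1))), None) or m.add((x,y))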

Full solution on Godbolt.

One way you could make this pure is by explicitly passing and returning the set m from the function s (i.e. s :: int -> int -> set -> (bool, set)).
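For comparison, here is a minimal sketch of that explicit state-passing version in Haskell. It assumes, as the Python code suggests, that the set holds both rock cells and settled sand; the names `Cells` and `tryMoves` are made up for this sketch. Note how every recursive call has to receive the current set and hand back the updated one:

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

type Cells = Set (Int, Int)

-- s h x y m: the Python s with the set passed in and the updated set passed out.
-- Returns True if the grain dropped at (x, y) falls below row h.
s :: Int -> Int -> Int -> Cells -> (Bool, Cells)
s h x y m
  | y > h = (True, m)
  | Set.member (x, y) m = (False, m)
  | otherwise = tryMoves [0, -1, 1] m
  where
    -- Mirrors next(... for d in (0,-1,1) if ...): stop at the first True.
    -- If no move returns True, the grain rests at (x, y).
    tryMoves [] m' = (False, Set.insert (x, y) m')
    tryMoves (d : ds) m' =
      -- thread the set by hand: m' goes in, m'' comes out
      case s h (x + d) (y + 1) m' of
        (True, m'') -> (True, m'')
        (False, m'') -> tryMoves ds m''
```

The tuple plumbing in `tryMoves` is exactly what a state monad hides.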

However, I've also read about how the Reader/Writer/State monads save you from having to pass the extra parameter and unpack the tuple result, and I'm interested in porting this recursion to Haskell.
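To make that concrete, here is a minimal hand-rolled State monad, a sketch for illustration only. In practice you would import `Control.Monad.State` from mtl, where `State s` is a synonym for the transformer `StateT s Identity`, but it behaves like this. A `State s a` is just a function `s -> (a, s)`, and `>>=` is the one place that threads the state from one step to the next; `twoVisits` is a made-up usage example:

```haskell
-- A State s a wraps a function from an input state to (result, output state).
newtype State s a = State { runState :: s -> (a, s) }

instance Functor (State s) where
  fmap f (State g) = State $ \st -> let (a, st') = g st in (f a, st')

instance Applicative (State s) where
  pure a = State $ \st -> (a, st)
  State sf <*> State sa = State $ \st ->
    let (f, st1) = sf st
        (a, st2) = sa st1
    in (f a, st2)

instance Monad (State s) where
  -- Run the first action, then feed its result AND its output state to the next.
  State g >>= k = State $ \st ->
    let (a, st') = g st
    in runState (k a) st'

get :: State s s
get = State $ \st -> (st, st)

modify :: (s -> s) -> State s ()
modify f = State $ \st -> ((), f st)

-- Run a program and keep only the final state (the second component).
execState :: State s a -> s -> s
execState m st = snd (runState m st)

-- Hypothetical example: the counter is never passed around explicitly.
twoVisits :: State Int Int
twoVisits = do
  modify (+ 1)
  modify (+ 1)
  get
```

`runState twoVisits 0` evaluates to `(2, 2)`: the do block never mentions the counter, yet each step sees the previous step's state.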

I found a Haskell solution on Reddit that looks like it may do the same recursion (as well as two more that don't).

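{-# LANGUAGE BlockArguments, FlexibleContexts #-}
import Control.Monad (when)
import Control.Monad.ST (ST)
import Data.Array.MArray (MArray, readArray, writeArray)
import Data.Bool (bool)
import Data.Ix (Ix)
import Data.Maybe (fromMaybe)
import Data.STRef (newSTRef, readSTRef, writeSTRef, modifySTRef')
-- blocks: True where a cell is occupied (rock or sand); result is (part 1, part 2)
fill :: (MArray a Bool (ST s), Ix i, Num i, Show i) => a (i, i) Bool -> i -> ST s (Int, Int)
fill blocks maxY = do
    counterAtMaxY <- newSTRef Nothing  -- value of counter when row maxY is first reached
    counter <- newSTRef 0              -- number of cells filled so far
    -- bool runs the do block only if (x, y) is not occupied yet
    let fill' (x, y) = readArray blocks (x, y) >>= flip bool (pure ()) do
            when (y == maxY) $ readSTRef counterAtMaxY >>= maybe
                (readSTRef counter >>= writeSTRef counterAtMaxY . Just) (const $ pure ())
            -- fill below first (down, down-left, down-right), then mark (x, y) itself
            when (y <= maxY) $ fill' (x, y + 1) >> fill' (x - 1, y + 1) >> fill' (x + 1, y + 1)
            writeArray blocks (x, y) True >> modifySTRef' counter (+ 1)
    fill' (500, 0)
    counterAtMaxY <- readSTRef counterAtMaxY
    counter <- readSTRef counter
    pure (fromMaybe counter counterAtMaxY, counter)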

Full solution on Godbolt.
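One line of `fill` that tends to trip people up is `readArray blocks (x, y) >>= flip bool (pure ()) do ...`. `Data.Bool.bool f t b` evaluates to `t` when `b` is True and to `f` otherwise, so `flip bool (pure ()) action` is `unless` with its arguments in the other order. The sketch below checks that equivalence in the State monad; `viaBool`, `viaUnless`, and `countFires` are illustrative names, not from the Reddit code:

```haskell
import Control.Monad (unless)
import Control.Monad.State (State, execState, modify)
import Data.Bool (bool)

-- Reddit style: given the "occupied" flag, run the action or a no-op.
viaBool :: Monad m => m () -> Bool -> m ()
viaBool action = flip bool (pure ()) action

-- The same thing written with unless.
viaUnless :: Monad m => m () -> Bool -> m ()
viaUnless action blocked = unless blocked action

-- Call a version once with False and once with True, counting how often the action ran.
countFires :: (State Int () -> Bool -> State Int ()) -> Int
countFires f = execState (f (modify (+ 1)) False >> f (modify (+ 1)) True) 0
```

Both versions run the action only for an unoccupied cell, so `countFires viaBool` and `countFires viaUnless` both give 1.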

Could someone confirm that this is indeed a port of the Python solution? If so, could you baby me through how the recursion happens?

I'm still not Haskell-literate. I can kind of make out that the line fill' (500, 0) desugars to m >>= \_ -> fill' (500, 0), which I read as discarding the current state and creating a new monad independently (something gets preserved, but I'm confused about what). I also don't understand monad transformers at all.
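To see what `>>= \_ ->` actually throws away, here is a small sketch in ST with the same shape as `fill`. `countDown` is a hypothetical stand-in for `fill'`: its only result is `()`, but every call bumps a counter in an STRef. In a do block, a statement followed by another line desugars to `stmt >>= \_ -> rest`, and the lambda ignores only that `()` result; the writes made along the way are still there when the next line reads the ref. The same holds for `fill' (500, 0)` followed by `readSTRef counterAtMaxY`:

```haskell
import Control.Monad (when)
import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef)

-- Hypothetical stand-in for fill': recursive, returns (), mutates a counter.
countDown :: STRef s Int -> Int -> ST s ()
countDown counter n = when (n > 0) $ do
  modifySTRef' counter (+ 1)
  countDown counter (n - 1)

demo :: Int
demo = runST $ do
  counter <- newSTRef 0
  -- The two lines "countDown counter 3" and "readSTRef counter" desugar to:
  -- the \_ discards the () from countDown, not the writes it made.
  countDown counter 3 >>= \_ -> readSTRef counter
```

`demo` is 3: three calls incremented the counter, and the final read sees all of them.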

The Haskell solution does part 2 of the puzzle at the same time, so maybe someone could factor that out, so there's no confusion between the Cartesian coordinates and the pair of Ints holding the two answers.
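Here is one way that factoring could look: a sketch of the same flood fill with part 1 removed, written with the State monad over a Set instead of ST with a mutable array. The names `fillFrom`, `settleOrder`, and `exampleRocks` are made up for this sketch. As in the Reddit `fill'`, recursion stops below row `maxY + 1`, so the part 2 floor sits at `maxY + 2`. Unlike the Python `s`, this fill never stops early: it visits every cell reachable from (500, 0) and marks each one only after the three cells below it have been handled.

```haskell
import Control.Monad (unless, when)
import Control.Monad.State (State, execState, gets, modify)
import Data.Set (Set)
import qualified Data.Set as Set

type Cell = (Int, Int)

-- State: (occupied cells, cells in the order they were filled, newest first)
type Fill = State (Set Cell, [Cell])

-- Part 2 only: recursion stops below row maxY + 1 (the floor is at maxY + 2).
fillFrom :: Int -> Cell -> Fill ()
fillFrom maxY (x, y) = do
  blocked <- gets (Set.member (x, y) . fst)
  unless blocked $ do
    -- same order as the Reddit fill': down, down-left, down-right
    when (y <= maxY) $ do
      fillFrom maxY (x, y + 1)
      fillFrom maxY (x - 1, y + 1)
      fillFrom maxY (x + 1, y + 1)
    -- the cells below are now full (or this is the row above the floor),
    -- so a grain can rest at (x, y)
    modify (\(m, order) -> (Set.insert (x, y) m, (x, y) : order))

-- Cells filled with sand, oldest first.
settleOrder :: Int -> Set Cell -> [Cell]
settleOrder maxY rocks =
  reverse (snd (execState (fillFrom maxY (500, 0)) (rocks, [])))

-- The example cave from the puzzle statement.
exampleRocks :: Set Cell
exampleRocks = Set.fromList $
     [(498, y) | y <- [4 .. 6]] ++ [(x, 6) | x <- [496 .. 498]]
  ++ [(x, 4) | x <- [502, 503]] ++ [(502, y) | y <- [4 .. 9]]
  ++ [(x, 9) | x <- [494 .. 502]]
```

`length (settleOrder 9 exampleRocks)` gives 93 for the example cave. On an empty cave with `maxY = 1`, `settleOrder 1 Set.empty` returns `[(500,2),(499,2),(501,2),(500,1),(498,2),(499,1),(502,2),(501,1),(500,0)]`, the post-order in which `fill'` marks cells, which in this small case is also the order in which grains come to rest.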


Solution

  • Below is a fairly close translation of your Python code to Haskell. Some remarks on the differences:

    • The global h becomes a local parameter, and m :: Set (Int, Int) gets passed implicitly in the State monad, accessed using get and modify.
    • There is no early return in Haskell: calling return/pure doesn't abort the rest of the block, so the result has to come from the last statement. On the other hand, if expressions must have an else clause, which forces you to do the right thing anyway.
    • The generator expression can be written as a higher-order function, orM, that tries each action in a list and stops as soon as one returns True.
    • The set's add method in Python returns None, which is interpreted as False in conditionals. In Haskell we don't like this kind of overloading; instead, we explicitly attach the False value to the value-less action add (x, y), writing add (x, y) *> pure False.
    • Use execState to "run the monadic program" s h0 500 0 with an initial state m0, obtaining its final state. That "program" s h0 500 0 :: M Bool is, up to newtype wrappers, a pure function Set (Int, Int) -> (Bool, Set (Int, Int)), and all execState does is apply it to the initial state and project out the second component of the output pair. The point of the "state monad" is that such a function can be defined with the syntax of an imperative language ("do-notation").
    module Main where
    
    import Control.Monad.State
    import Data.Set (Set)
    import qualified Data.Set as Set
    
    type M = State (Set (Int, Int))
    
    s :: Int -> Int -> Int -> M Bool
    s h x y =
      if y > h then pure True
      else do
        m <- get
        if Set.member (x, y) m then
          pure False
        else
          orM ([s h (x+d) (y+1) | d <- [0, -1, 1]] ++ [add (x, y) *> pure False])
    
    orM :: Monad m => [m Bool] -> m Bool
    orM [] = pure False
    orM (x : xs) = do
      b <- x
      if b then pure True
      else orM xs
    
    add :: (Int, Int) -> M ()
    add (x, y) = modify (Set.insert (x, y))
    
    -- Example from https://adventofcode.com/2022/day/14
    
    m0 :: Set (Int, Int)
    m0 = vline 498 4 6 <> hline 498 496 6 <> hline 503 502 4 <> vline 502 4 9 <> hline 494 502 9
    
    vline, hline :: Int -> Int -> Int -> Set (Int, Int)
    vline x y1 y2 | y1 > y2 = vline x y2 y1
    vline x y1 y2 = Set.fromList [(x, y) | y <- [y1 .. y2]]
    
    hline x1 x2 y | x1 > x2 = hline x2 x1 y
    hline x1 x2 y = Set.fromList [(x, y) | x <- [x1 .. x2]]
    
    h0 :: Int
    h0 = 9
    
    main :: IO ()
    main =
      print (Set.size (execState (s h0 500 0) m0) - Set.size m0)
      -- Output: 24
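A small standalone sketch of the points about pure and orM above, using mtl's State with an Int counter to show which actions actually ran. `orFoldr` is a made-up name for an alternative way to write the answer's `orM` with `foldr`; `noEarlyReturn` and `attempts` are likewise illustrative:

```haskell
import Control.Monad.State (State, evalState, execState, modify)

-- Same behaviour as orM: run actions left to right, stop at the first True.
orFoldr :: Monad m => [m Bool] -> m Bool
orFoldr = foldr (\act rest -> act >>= \b -> if b then pure True else rest) (pure False)

-- pure produces a value but does not leave the do block: both modifies run.
noEarlyReturn :: State Int ()
noEarlyReturn = do
  modify (+ 1)
  _ <- pure True
  modify (+ 1)

-- Like add (x, y) *> pure False: an effect with a Bool attached to it.
attempts :: [State Int Bool]
attempts = [modify (+ 1) *> pure False, modify (+ 10) *> pure True, modify (+ 100) *> pure True]
```

`execState noEarlyReturn 0` is 2, and `orFoldr attempts` returns True with final state 11: the third action, and its `+ 100`, never runs.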