Tags: haskell, io, closures, do-notation, ioref

Haskell IORef - an answer vs. a function to get an answer


I'm trying to understand how IORefs are really used, and I'm having trouble following the sample code I found at https://www.seas.upenn.edu/~cis194/spring15/lectures/12-unsafe.html:

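import Data.IORef (newIORef, readIORef, writeIORef)

-- Allocate a counter, then return an action that bumps it and yields the old value.
newCounter :: IO (IO Int)
newCounter = do
  r <- newIORef 0
  return $ do
    v <- readIORef r
    writeIORef r (v + 1)
    return v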

printCounts :: IO ()
printCounts = do
  c <- newCounter
  print =<< c
  print =<< c
  print =<< c
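For experimenting in GHCi, here is a self-contained sketch of the same lecture code. countThree is a hypothetical helper (not from the lecture) with the same shape as printCounts, except that it returns the three values in a list instead of printing them, so the result can be inspected directly.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)

-- The lecture's counter: the outer IO action allocates the IORef,
-- the returned inner IO action bumps it and yields the old value.
newCounter :: IO (IO Int)
newCounter = do
  r <- newIORef 0
  return $ do
    v <- readIORef r
    writeIORef r (v + 1)
    return v

-- Hypothetical helper (not in the lecture): same shape as printCounts,
-- but collects the three values instead of printing them.
countThree :: IO [Int]
countThree = do
  c <- newCounter   -- runs the outer action once: one IORef is allocated
  a <- c            -- each of these lines runs the inner action again
  b <- c
  d <- c
  return [a, b, d]
```

Evaluating countThree in GHCi gives [0,1,2], the same numbers printCounts prints.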

When printCounts executes "c <- newCounter", why doesn't c get the result of doing the work in newCounter's "return $ do" block? It seems like that result should be the constant "IO 0" the first time it is evaluated, and should never change after that. Instead, c seems to be bound to the function defined in that "return $ do" block, which is then executed anew every time printCounts reaches another "print =<< c". The answer seems to lie in newCounter's doubly nested "IO (IO Int)" type, but I can't follow why that makes c a function that is re-executed each time it is used, instead of a constant evaluated just once.


Solution

  • You can think of IO as a type of programs. newCounter :: IO (IO Int) is a program that outputs a program. More precisely, newCounter allocates a new counter and returns a program that, when run, increments the counter and returns its old value. newCounter doesn't execute the program it returns. It would if you wrote this instead:

    newCounter :: IO (IO Int)
    newCounter = do
      r <- newIORef 0
      let p = do              -- name the counter program p
            v <- readIORef r
            writeIORef r (v + 1)
            return v
      _ <- p     -- run the counter program once (its result is discarded)
      return p   -- you can still return it to run again later
    

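Because newCounter is itself a program, it can be run more than once, and each run allocates a fresh IORef. The sketch below illustrates this together with the variant above; the names newCounterPrimed, twoCounters and primedFirst are hypothetical, and the original lecture definition is repeated so the block is self-contained.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)

-- Original lecture definition: allocate, then hand back the counter program.
newCounter :: IO (IO Int)
newCounter = do
  r <- newIORef 0
  return $ do
    v <- readIORef r
    writeIORef r (v + 1)
    return v

-- Hypothetical name for the variant above: p runs once inside
-- newCounterPrimed, and is also returned so callers can run it later.
newCounterPrimed :: IO (IO Int)
newCounterPrimed = do
  r <- newIORef 0
  let p = do
        v <- readIORef r
        writeIORef r (v + 1)
        return v
  _ <- p      -- run p once now; its result (0) is discarded
  return p

-- Run the outer action twice: each run allocates its own IORef.
twoCounters :: IO (Int, Int, Int)
twoCounters = do
  c1 <- newCounter
  c2 <- newCounter
  a <- c1   -- 0
  b <- c1   -- 1
  x <- c2   -- 0: c2 has its own IORef, untouched by c1
  return (a, b, x)

-- The caller's first run of the primed counter sees the already-bumped value.
primedFirst :: IO Int
primedFirst = do
  c <- newCounterPrimed
  c
```

In GHCi, twoCounters gives (0,1,0) and primedFirst gives 1: the two counters share no state, and the primed counter has already been run once before the caller gets it.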
    You can also use equational reasoning to unfold printCounts into a sequence of primitives. All versions of printCounts below are equivalent programs:

    -- original definition
    printCounts :: IO ()
    printCounts = do
      c <- newCounter
      print =<< c
      print =<< c
      print =<< c
    
    -- by definition of newCounter...
    
    printCounts = do
      c <- do
        r <- newIORef 0
        return $ do
          v <- readIORef r
          writeIORef r (v + 1)
          return v
      print =<< c
      print =<< c
      print =<< c
    
    -- by the associativity monad law (stated loosely for brevity):
    -- do
    --   c <- do
    --     X
    --     Y
    --   .....
    -- =
    -- do
    --   X
    --   c <- 
    --     Y
    --   .....
    --
    -- (more formally,
    --  ((m >>= \x -> k x) >>= h) = (m >>= (\x -> k x >>= h)))
    
    printCounts = do
      r <- newIORef 0
      c <-
        return $ do
          v <- readIORef r
          writeIORef r (v + 1)
          return v
      print =<< c
      print =<< c
      print =<< c
    
    -- by the left identity monad law:
    -- c <- return X
    -- =
    -- let c = X
    --
    -- (more formally, ((return X) >>= (\c -> k c)) = (k X))
    
    printCounts = do
      r <- newIORef 0
      let c = do
            v <- readIORef r
            writeIORef r (v + 1)
            return v
      print =<< c
      print =<< c
      print =<< c
    
    -- by let-substitution (inline the definition of c at each use)
    
    printCounts = do
      r <- newIORef 0
      print =<< do
            v <- readIORef r
            writeIORef r (v + 1)
            return v
      print =<< do
            v <- readIORef r
            writeIORef r (v + 1)
            return v
      print =<< do
            v <- readIORef r
            writeIORef r (v + 1)
            return v
    
    -- after many more applications of monad laws and a bit of renaming to avoid shadowing
    -- (in particular, one important step is ((return v >>= print) = (print v)))
    
    printCounts = do
      r <- newIORef 0
      v1 <- readIORef r
      writeIORef r (v1 + 1)
      print v1
      v2 <- readIORef r
      writeIORef r (v2 + 1)
      print v2
      v3 <- readIORef r
      writeIORef r (v3 + 1)
      print v3
    

    In the final version, you can see that printCounts quite literally allocates a counter and increments it three times, printing the value it read before each increment.
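The equivalence of the versions can also be checked by running them. The sketch below transcribes the original, the let-substituted, and the final versions of printCounts, with print replaced by a hypothetical record function that appends each value to a log IORef (an assumption made only so the outputs can be compared; it does not change the order of effects). The names original, substituted, final and runVersion are likewise hypothetical.

```haskell
import Data.IORef (IORef, modifyIORef, newIORef, readIORef, writeIORef)

newCounter :: IO (IO Int)
newCounter = do
  r <- newIORef 0
  return $ do
    v <- readIORef r
    writeIORef r (v + 1)
    return v

-- Hypothetical stand-in for print: append the value to a log.
record :: IORef [Int] -> Int -> IO ()
record logRef v = modifyIORef logRef (++ [v])

-- Original shape of printCounts.
original :: IORef [Int] -> IO ()
original out = do
  c <- newCounter
  record out =<< c
  record out =<< c
  record out =<< c

-- After let-substitution: the counter program is written out three times.
substituted :: IORef [Int] -> IO ()
substituted out = do
  r <- newIORef 0
  record out =<< do
    v <- readIORef r
    writeIORef r (v + 1)
    return v
  record out =<< do
    v <- readIORef r
    writeIORef r (v + 1)
    return v
  record out =<< do
    v <- readIORef r
    writeIORef r (v + 1)
    return v

-- Final straight-line version.
final :: IORef [Int] -> IO ()
final out = do
  r <- newIORef 0
  v1 <- readIORef r
  writeIORef r (v1 + 1)
  record out v1
  v2 <- readIORef r
  writeIORef r (v2 + 1)
  record out v2
  v3 <- readIORef r
  writeIORef r (v3 + 1)
  record out v3

-- Run one version against a fresh log and return what it recorded.
runVersion :: (IORef [Int] -> IO ()) -> IO [Int]
runVersion version = do
  out <- newIORef []
  version out
  readIORef out
```

Each of runVersion original, runVersion substituted and runVersion final returns [0,1,2].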

    The key step is the let-substitution, where the counter program gets duplicated; that is why it runs three times. Keep in mind that let x = p; ... is different from x <- p; ...: the latter runs p and binds x to its result, rather than to the program p itself.
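The difference between naming a program with let and running it with <- can be seen directly in the sketch below (withLet and withBind are hypothetical names). Both build the same counter program. withLet names it and runs it three times; withBind runs it once and reuses the resulting Int, which is the "constant" behaviour the question expected.

```haskell
import Data.IORef (newIORef, readIORef, writeIORef)

-- let: c names the program (c :: IO Int), so each `<- c` runs it again.
withLet :: IO [Int]
withLet = do
  r <- newIORef 0
  let c = do
        v <- readIORef r
        writeIORef r (v + 1)
        return v
  a <- c
  b <- c
  d <- c
  return [a, b, d]

-- <-: the program runs once here, and x is bound to its result (x :: Int).
withBind :: IO [Int]
withBind = do
  r <- newIORef 0
  x <- do
    v <- readIORef r
    writeIORef r (v + 1)
    return v
  return [x, x, x]
```

In GHCi, withLet gives [0,1,2] while withBind gives [0,0,0]. Since x has type Int rather than IO Int, there is no program left to run again; only c, whose type is IO Int, can be re-executed.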