
How can I write this simple code using the state monad?


I'm a beginner at Haskell and I've come across a situation where I would like to use the state monad. (Or at least, I think that's what I'd like to use.) There are a million tutorials for the state monad, but all of them seem to assume that my main goal is to understand it on a deep conceptual level, and consequently they stop just before the part where they say how to actually develop software with it. So I'm looking for help with a simplified practical example.

Below is a very simple version of what my current code looks like. As you can see, I'm threading state through my functions by hand, and my question is simply how to rewrite the code using do notation so that I won't have to do that.

data Machine = Register Int

addToState :: Machine -> Int -> Machine
addToState (Register s) a = Register $ s+a

subtractFromState :: Machine -> Int -> Machine
subtractFromState (Register s) a = Register (s-a)

getValue :: Machine -> Int
getValue (Register s) = s

initialState :: Machine
initialState = Register 0

runProgram :: Int
runProgram = getValue (subtractFromState (addToState initialState 6) 4)

The code simulates a simple abstract machine that has a single register, and instructions to add to the register, subtract from it, and get its value. The "program" at the end initialises the register to 0, adds 6 to it, subtracts 4 and returns the result, which of course is 2.

I understand the purpose of the state monad (or at least think I do), and I expect that it will allow me to re-write this so that I end up with something like

runProgram :: ???????
runProgram = do
    put 0
    addToState 6
    subtractFromState 4
    value <- getValue
    return value

However, despite all the tutorials I've read I still don't quite know how to transform my code into this form.

Of course, my actual machine's state is much more complicated, and I'm also passing around its output (which will be passed to another machine) and various other things, so I'm quite keen to simplify it. Knowing how to do it for this simplified example would be a great help.

Update: after Lee's great answer I now know how to do this, but I'm stuck on how to write code in the same elegant form when I have multiple interacting machines. I've asked about that in a new question.


Solution

  • First, convert your existing functions so that they return State Machine a values:

    import Control.Monad.State.Lazy
    
    data Machine = Register Int
    
    addToState :: Int -> State Machine ()
    addToState i = do
            (Register x) <- get
            put $ Register (x + i)
    
    subtractFromState :: Int -> State Machine ()
    subtractFromState i = do
            (Register x) <- get
            put $ Register (x - i)
    
    getValue :: State Machine Int
    getValue = do
            (Register i) <- get
            pure i
    

    Then you can combine them into a single stateful computation:

    program :: State Machine Int
    program = do
      addToState 6
      subtractFromState 4
      getValue
    

    Finally, run this computation with evalState, which returns the final result and discards the final state:

    runProgram :: Int
    runProgram = evalState program (Register 0)
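As a possible refinement (a sketch, not part of the original answer), the explicit get/put pairs can be replaced by mtl's modify and gets, which apply a function to the state and read a projection of it, respectively. The helper registerValue is a hypothetical name introduced here for illustration. The main function also shows runState and execState alongside evalState, in case you want the final machine state as well as, or instead of, the result.

```haskell
import Control.Monad.State (State, evalState, execState, gets, modify, runState)

-- Same single-register machine as above; Show/Eq are derived for printing and comparison.
data Machine = Register Int
  deriving (Show, Eq)

-- Hypothetical helper (not from mtl): unwrap the register's value.
registerValue :: Machine -> Int
registerValue (Register x) = x

-- modify applies a pure function to the current state.
addToState :: Int -> State Machine ()
addToState i = modify (\(Register x) -> Register (x + i))

subtractFromState :: Int -> State Machine ()
subtractFromState i = modify (\(Register x) -> Register (x - i))

-- gets reads the current state and applies a projection to it.
getValue :: State Machine Int
getValue = gets registerValue

program :: State Machine Int
program = do
  addToState 6
  subtractFromState 4
  getValue

main :: IO ()
main = do
  print (evalState program (Register 0)) -- result only: 2
  print (execState program (Register 0)) -- final state only: Register 2
  print (runState program (Register 0))  -- both: (2,Register 2)
```

Because the initial state is passed to evalState/execState/runState, there is no need for the put 0 from the question's sketch; put is only needed if the program itself wants to overwrite the whole state partway through.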