
StateT and the nondeterminism monad: a simple example


As part of learning how to work with StateT and the nondeterminism monad, I'd like to write a function that uses them to enumerate the partitions of an integer (parts may be reused). For example, passing 4 should result in [[1,1,1,1],[1,1,2],[2,2],[1,3],[4]]. Uniqueness doesn't matter; I'm more concerned with just getting to working code.

(Also, I'm aware that there's a recursive solution for generating partitions, as well as dynamic programming and generating-function-based solutions for counting them. The purpose of this exercise is to construct a minimal working example that combines StateT and [].)

Here's my attempt, which is designed to work on any input less than or equal to 5:

{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_GHC -Wall #-}

import CorePrelude
import Control.Monad.State.Lazy

sumState :: StateT Int [] [Int]
sumState = do
  m <- lift [1..5]
  n <- get <* modify (-m+)
  case compare n 0 of
    LT -> mzero
    EQ -> return [m]
    GT -> fmap (n:) sumState

runner :: Int -> [([Int],Int)]
runner = runStateT sumState

I'm using runStateT rather than evalStateT to help with debugging (it's helpful to see the final state values). As I said, I'm not too worried about generating unique partitions, since I'd first like to understand the correct way to use these two monads together.
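Because the inner monad here is [], each of these runners returns a list with one entry per surviving branch. A minimal sketch with a hypothetical toy computation, `pickStep` (not part of the question's code), shows how runStateT, evalStateT and execStateT differ, and that every branch starts from its own copy of the state:

```haskell
import Control.Monad.State (StateT, runStateT, evalStateT, execStateT, modify)
import Control.Monad.Trans.Class (lift)

-- Hypothetical toy computation (not from the question): choose a step
-- nondeterministically, subtract it from the state, and return the step.
pickStep :: StateT Int [] Int
pickStep = do
  step <- lift [1, 2, 3]  -- one branch per list element
  modify (subtract step)  -- each branch updates its own copy of the state
  return step

-- In GHCi, with every branch starting from the state 10:
--   runStateT  pickStep 10  ==> [(1,9),(2,8),(3,7)]  -- (result, final state) pairs
--   evalStateT pickStep 10  ==> [1,2,3]              -- results only
--   execStateT pickStep 10  ==> [9,8,7]              -- final states only
```

Note that the branch that picks 2 ends with state 8, not 7: branches do not see each other's modifications.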

Loading it in GHCi and evaluating runner 4 gives the following, and I don't understand why the code above produces this output:

[([4,3,2,1,1],-1),([4,3,2,1,2],-2),([4,3,2,1,3],-3),([4,3,2,1,4],-4),([4,3,2,1,5],-5),([4,3,2,1],-1),([4,3,2,2],-2),([4,3,2,3],-3),([4,3,2,4],-4),([4,3,2,5],-5),([4,3,1,1],-1),([4,3,1,2],-2),([4,3,1,3],-3),([4,3,1,4],-4),([4,3,1,5],-5),([4,3,1],-1),([4,3,2],-2),([4,3,3],-3),([4,3,4],-4),([4,3,5],-5),([4,2,1,1],-1),([4,2,1,2],-2),([4,2,1,3],-3),([4,2,1,4],-4),([4,2,1,5],-5),([4,2,1],-1),([4,2,2],-2),([4,2,3],-3),([4,2,4],-4),([4,2,5],-5),([4,1,1],-1),([4,1,2],-2),([4,1,3],-3),([4,1,4],-4),([4,1,5],-5),([4,1],-1),([4,2],-2),([4,3],-3),([4,4],-4),([4,5],-5)]
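Part of this output comes from the ordering in `n <- get <* modify (-m+)`: `<*` keeps the result of `get`, which is the state from before the subtraction. A minimal sketch that isolates just that ordering, using plain `State Int` instead of `StateT Int []` (the names `oldThenModify` and `modifyThenNew` are hypothetical):

```haskell
import Control.Monad.State (State, runState, get, modify)

-- (<*) runs both actions but keeps the result of the left one, so this
-- returns the state as it was *before* the modification.
oldThenModify :: State Int Int
oldThenModify = get <* modify (subtract 1)

-- (>>) keeps the result of the right one, so this returns the updated state.
modifyThenNew :: State Int Int
modifyThenNew = modify (subtract 1) >> get

-- In GHCi:
--   runState oldThenModify 4  ==> (4,3)
--   runState modifyThenNew 4  ==> (3,3)
```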

What am I doing wrong? What's the correct way to combine StateT and [] in order to enumerate partitions?


Solution

  • You just have two little mistakes. The first is here:

    n <- get <* modify (-m+)
    

    This binds n to the state from before m is subtracted: <* keeps the result of get and discards the result of modify. You almost certainly want

    n <- modify (-m+) >> get
    

    instead, or

    modify (-m+)
    n <- get
    

    if you prefer that spelling. The other is that in the GT branch you're consing the current state n onto the result instead of the part m you just chose:

    GT -> fmap (n:) sumState
    

    Change that to

    GT -> fmap (m:) sumState
    

    and you're golden:

    *Main> runner 4
    [([1,1,1,1],0),([1,1,2],0),([1,2,1],0),([1,3],0),([2,1,1],0),([2,2],0),([3,1],0),([4],0)]
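For reference, a sketch of the whole program with both fixes applied, written against the standard Prelude instead of CorePrelude (an assumption made to keep it self-contained; it needs the mtl and transformers packages). It also adds a hypothetical variant, `partitionsFrom`/`partitions`, which is not from the answer: it draws candidate parts from the remaining amount instead of a fixed `[1 .. 5]`, and never picks a part smaller than the previous one, so each partition comes out exactly once.

```haskell
import Control.Monad (mzero)
import Control.Monad.State (StateT, runStateT, evalStateT, get, modify)
import Control.Monad.Trans.Class (lift)

-- The corrected sumState: subtract first, then read the remaining
-- amount, and cons the chosen part m (not the state n).
sumState :: StateT Int [] [Int]
sumState = do
  m <- lift [1 .. 5]          -- candidate parts (inputs up to 5, as in the question)
  modify (subtract m)         -- same as modify (-m+)
  n <- get                    -- remaining amount after taking m
  case compare n 0 of
    LT -> mzero               -- overshot: prune this branch
    EQ -> return [m]          -- exact: m is the last part
    GT -> fmap (m :) sumState -- keep going with the remainder

runner :: Int -> [([Int], Int)]
runner = runStateT sumState

-- Hypothetical variant (not from the answer): parts are drawn from
-- [lo .. remaining], so no branch overshoots, and each new part is at
-- least the previous one, so parts come out in non-decreasing order.
partitionsFrom :: Int -> StateT Int [] [Int]
partitionsFrom lo = do
  remaining <- get
  m <- lift [lo .. remaining] -- empty list here means no partition on this branch
  modify (subtract m)
  n <- get
  if n == 0
    then return [m]
    else fmap (m :) (partitionsFrom m)

partitions :: Int -> [[Int]]
partitions = evalStateT (partitionsFrom 1)
```

Loaded in GHCi, `runner 4` should reproduce the answer's output above, and `partitions 4` should give `[[1,1,1,1],[1,1,2],[1,3],[2,2],[4]]`, the question's target list in a different order.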