Tags: haskell, monad-transformers, non-deterministic, monadplus

Building a nondeterministic monad transformer in Haskell


I would like to build a nondeterministic monad transformer in Haskell that, I believe, behaves differently from ListT and from the alternative ListT proposed at http://www.haskell.org/haskellwiki/ListT_done_right. The first of these associates a monad with a whole list of items; the second associates a monad with individual items, but monadic actions in a given element influence the monadic actions in subsequent slots of the list. The goal is to build a monad transformer of the form

data Amb m a = Cons (m a) (Amb m a) | Empty

so that every element of the list has its own monadic action and successive elements run independently of one another. At the end of this post I have a small demonstration of the kind of behavior this monad should give. If you know how to get some variant of ListT to give this behavior, that would be helpful too.

Below is my attempt. It is incomplete because the unpack function is undefined. How can I define it? Here is one partial attempt at defining it; it does not handle the case where the action m yields an Empty Amb list:

unpack :: (Monad m) => m (Amb m a) -> Amb m a                                                                                                                 
unpack m = let first = join $ do (Cons x ys) <- m                                                                                                             
                                 return x                                                                                                                     
               rest =  do (Cons x ys) <- m                                                                                                                    
                          return ys                                                                                                                           
           in Cons first (unpack rest)   

Full (incomplete) code:

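import Prelude hiding (map, concat)
import Control.Applicative (Alternative (..))
import Control.Monad
import Control.Monad.Trans

data Amb m a = Cons (m a) (Amb m a) | Empty

infixr 4 <:>
(<:>) :: m a -> Amb m a -> Amb m a
(<:>) = Cons

-- Apply a pure function inside every element's monadic action.
map :: Monad m => (a -> b) -> Amb m a -> Amb m b
map f (Cons m xs) = Cons (liftM f m) (map f xs)
map _ Empty = Empty

-- The missing piece: turn an action that yields a list into a list.
unpack :: m (Amb m a) -> Amb m a
unpack m = undefined

concat :: (Monad m) => Amb m (Amb m a) -> Amb m a
concat (Cons m xs) = unpack m `mplus` concat xs
concat Empty = Empty

-- Current GHC requires Functor, Applicative and Alternative as superclasses.
instance Monad m => Functor (Amb m) where
    fmap = map

instance Monad m => Applicative (Amb m) where
    pure x = Cons (return x) Empty
    (<*>) = ap

instance Monad m => Monad (Amb m) where
    xs >>= f = concat (map f xs)

instance Monad m => Alternative (Amb m) where
    empty = Empty
    (Cons m xs) <|> ys = Cons m (xs <|> ys)
    Empty <|> ys = ys

-- mzero and mplus default to empty and (<|>).
instance Monad m => MonadPlus (Amb m)

instance MonadTrans Amb where
    lift m = Cons m Empty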

Examples of desired behavior

Here, the base monad is State Int:

-- Needs FlexibleInstances and Control.Monad.State; incr is the State Int
-- action defined in the LogicT example below.
instance Show a => Show (Amb (State Int) a) where
    show = show . toList

-- Run each element's action from an initial state of 0 and keep its result.
toList :: Amb (State Int) a -> [a]
toList Empty = []
toList (n `Cons` xs) = evalState n 0 : toList xs

x = lift incr >> (incr <:> incr <:> Empty)
y = lift incr >> (incr <:> (incr >> incr) <:> Empty)

main = do
  print x -- should be [2, 2]
  print y -- should be [2, 3]

Thanks.

Update: An example of why LogicT doesn't do what I want.

Here's what LogicT does on the simple example above:

import Control.Monad                                                                                                                                          
import Control.Monad.Logic                                                                                                                                    
import Control.Monad.State                                                                                                                                    

type LogicState = LogicT (State Int)                                                                                                                          


incr :: State Int Int                                                                                                                                         
incr = do i <- get                                                                                                                                            
          put (i + 1)                                                                                                                                         
          i' <- get                                                                                                                                           
          return i'                                                                                                                                           

incr' = lift incr                                                                                                                                             
y =  incr' >> (incr' `mplus` incr')                                                                                                                           

main = do                                                                                                                                                     
  putStrLn $ show (fst $ runState (observeAllT y) 0)   -- returns [2,3], not [2,2]

Solution

  • I believe you can just use StateT over the list monad. For example:

    import Control.Monad.State
    
    incr = modify (+1)
    sample1 = incr `mplus` incr
    sample2 = incr `mplus` (incr >> incr)
    
    monomorphicExecStateT :: StateT Int [] a -> Int -> [Int]
    monomorphicExecStateT = execStateT
    
    main = do
        print (monomorphicExecStateT sample1 0) -- [1, 1]
        print (monomorphicExecStateT sample2 0) -- [1, 2]
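
To connect that suggestion to the question's own example, here is a minimal sketch that rebuilds x and y on top of StateT Int []. It assumes the mtl package; the type alias Amb is a hypothetical stand-in for the transformer the question asks for, and incr is rewritten to return the new counter value as in the question. Because the list monad sits underneath the state, each branch of mplus starts from its own copy of the state.

```haskell
import Control.Monad (mplus)
import Control.Monad.State (StateT, evalStateT, get, modify)

-- Hypothetical alias (not from the original post): state layered over lists.
type Amb = StateT Int []

-- Increment the counter and return its new value, like the question's incr.
incr :: Amb Int
incr = modify (+ 1) >> get

-- One shared increment, then two independent branches.
x :: Amb Int
x = incr >> (incr `mplus` incr)

y :: Amb Int
y = incr >> (incr `mplus` (incr >> incr))

main :: IO ()
main = do
  print (evalStateT x 0) -- [2,2]
  print (evalStateT y 0) -- [2,3]
```

Swapping the layers, as in LogicT (State Int), threads one state through all branches instead, which is why the LogicT version in the question gives [2,3] for x.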