Tags: haskell, recursion, probability, category-theory, recursion-schemes

How can I use a recursion scheme to express this probability distribution in Haskell?


This question is part theory / part implementation. Background assumption: I'm using the monad-bayes library to represent probability distributions as monads. A distribution p(a|b) can be represented as a function MonadDist m => b -> m a.
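To make "distributions as monads" concrete without depending on monad-bayes, here is a minimal sketch of a finite, exact distribution monad. Dist, flipBiased, probOf and twoHeads are hypothetical names introduced for illustration, not monad-bayes's API. The point is only that a conditional distribution p(a|b) is an ordinary function b -> Dist a, and that >>= chains conditionals by multiplying probabilities.

```haskell
-- Hypothetical finite distribution monad (not monad-bayes): a list of
-- outcomes paired with their probabilities.
newtype Dist a = Dist { runDist :: [(a, Double)] }

instance Functor Dist where
  fmap f (Dist xs) = Dist [(f x, p) | (x, p) <- xs]

instance Applicative Dist where
  pure x = Dist [(x, 1)]
  Dist fs <*> Dist xs = Dist [(f x, p * q) | (f, p) <- fs, (x, q) <- xs]

-- Bind weights each outcome of the continuation by the probability
-- of the outcome it was conditioned on.
instance Monad Dist where
  Dist xs >>= k = Dist [(y, p * q) | (x, p) <- xs, (y, q) <- runDist (k x)]

-- A conditional distribution p(a | b) as a function b -> Dist a:
-- given a bias, return True with that probability.
flipBiased :: Double -> Dist Bool
flipBiased p = Dist [(True, p), (False, 1 - p)]

-- Total probability of the outcomes satisfying a predicate.
probOf :: (a -> Bool) -> Dist a -> Double
probOf ev (Dist xs) = sum [p | (x, p) <- xs, ev x]

-- Chaining conditionals: flip a fair coin, and only on heads flip again.
twoHeads :: Dist Bool
twoHeads = flipBiased 0.5 >>= \b -> if b then flipBiased 0.5 else pure False

main :: IO ()
main = print (probOf id twoHeads)
```

A real library would also merge duplicate outcomes and support sampling; this sketch skips both.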

Suppose I have a conditional probability distribution s :: MonadDist m => [Char] -> m Char. I want to get a new probability distribution sUnrolled :: MonadDist m => [Char] -> m [Char], defined mathematically (I think) as:

sUnrolled(chars|st) =
              | len(chars)==1 -> s(chars[0]|st)
              | otherwise     -> s(chars[-1]|st++chars[:-1]) * sUnrolled(chars[:-1]|st)
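Read as a probability, this definition is the chain rule: the probability of a string is the product of each character's probability given st plus every character before it. A minimal sketch, assuming s is available as a hypothetical plain probability function condProb st c (rather than a monadic sampler), with the recursion peeled from the end of the string as in the definition. The names CondProb, sUnrolled and frequency are illustrative.

```haskell
-- Hypothetical representation: condProb st c = s(c | st), a plain
-- probability rather than a monadic distribution.
type CondProb = String -> Char -> Double

-- sUnrolled(chars | st): peel off the last character and condition it on
-- st ++ (everything before it). The empty string has probability 1, so a
-- one-character string reduces to condProb st c, matching the
-- len(chars)==1 base case.
sUnrolled :: CondProb -> String -> String -> Double
sUnrolled _        _  []    = 1
sUnrolled condProb st chars =
  condProb (st ++ init chars) (last chars) * sUnrolled condProb st (init chars)

-- Toy conditional: each character's probability is its relative
-- frequency in the context (uniform over positions of st).
frequency :: CondProb
frequency st c
  | null st   = 0
  | otherwise = fromIntegral (length (filter (== c) st)) / fromIntegral (length st)

main :: IO ()
main = print (sUnrolled frequency "ab" "aa")
```

Here sUnrolled frequency "ab" "aa" is s(a|ab) * s(a|aba) = 1/2 * 2/3 = 1/3, and the probabilities of all strings of a fixed length sum to 1.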

Intuitively it's the distribution you get by taking st :: [Char], sampling a new char c from s st, feeding st++[c] back into s, and so on. I believe iterateM s is more or less what I want. To make it a distribution we could actually look at, let's say that if we hit a certain character, we stop. Then iterateMaybeM works.
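The sampling process described here (draw c from s st, continue from st ++ [c], stop at a designated character) can be sketched directly for any monad, without committing to a particular loop combinator. unrollUntil and toy are hypothetical names; the list monad stands in for a distribution by enumerating every possible outcome (without weights).

```haskell
-- Sample characters one at a time, feeding the growing context back into s,
-- until the stop character is drawn. Returns only the newly generated
-- characters; the stop character itself is dropped.
unrollUntil :: Monad m => Char -> (String -> m Char) -> String -> m String
unrollUntil stop s = go
  where
    go st = do
      c <- s st
      if c == stop
        then return []
        else (c :) <$> go (st ++ [c])

-- Toy nondeterministic "sampler" in the list monad: while the context is
-- shorter than 3 characters, either emit 'a' or stop; afterwards always stop.
toy :: String -> [Char]
toy st
  | length st < 3 = "a."
  | otherwise     = "."

main :: IO ()
main = print (unrollUntil '.' toy "")
```

This prints ["aaa","aa","a",""], every string the process can produce. With a sampling monad in place of the list monad, the same function draws one string, and it terminates only if s eventually emits the stop character. Appending with st ++ [c] costs O(length st) per step, which is acceptable for a sketch.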

Theory Question: For various reasons, it would be really useful if I could express this distribution in more general terms, for instance in a way that generalized to the stochastic construction of a tree given a stochastic coalgebra. It looks like I have some sort of anamorphism here (I realize that the mathematical definition looks like a catamorphism, but in code I want to build up strings, not deconstruct them into probabilities) but I can't quite work out the details, not least because of the presence of the probability monad.
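One way to make "an anamorphism with a probability monad" precise: a stochastic coalgebra is a function seed -> m (f seed), producing one random layer of structure whose holes are new seeds, and the monadic anamorphism runs it at every layer, using traverse to sequence the children's effects. The sketch below is self-contained, with its own Fix, anaM and TreeF (illustrative definitions, not recursion-schemes' own); the list monad again stands in for a distribution by enumerating all outcomes.

```haskell
{-# LANGUAGE DeriveFunctor, DeriveFoldable, DeriveTraversable #-}

-- Least fixed point of a functor: f applied to itself "all the way down".
newtype Fix f = Fix (f (Fix f))

-- Monadic anamorphism: run the stochastic coalgebra on the seed, then
-- recursively unfold every child seed, sequencing their effects.
anaM :: (Monad m, Traversable f) => (a -> m (f a)) -> a -> m (Fix f)
anaM coalg = go
  where
    go seed = do
      layer <- coalg seed
      Fix <$> traverse go layer

-- Base functor of an unlabelled binary tree; b marks the child positions.
data TreeF b = LeafF | BranchF b b
  deriving (Functor, Foldable, Traversable)

-- Stochastic coalgebra on a depth budget: with budget left, either stop
-- or branch into two children that each get one less unit of budget.
grow :: Int -> [TreeF Int]
grow 0 = [LeafF]
grow n = [LeafF, BranchF (n - 1) (n - 1)]

-- Number of BranchF nodes in a generated tree.
branches :: Fix TreeF -> Int
branches (Fix LeafF)         = 0
branches (Fix (BranchF l r)) = 1 + branches l + branches r

main :: IO ()
main = print (map branches (anaM grow 2))
```

This prints [0,1,2,2,3]: the five trees of depth at most 2. Swapping the list monad for a sampling monad turns grow into a random tree generator, and strings come from choosing a list-shaped base functor instead. recursion-schemes-ext's anaM has the analogous type, with Base t in place of f and embed in place of Fix.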

Practical Question: it would also be useful to implement this in Haskell in a way that used the recursion schemes library, for instance.


Solution

  • I'm not smart enough to thread monads through the recursion schemes, so I relied on recursion-schemes-ext, whose anaM runs an anamorphism with a monadic coalgebra (a -> m (Base t a)), so every unfolding step can perform a monadic action such as sampling.

    I did a (really ugly) proof of concept here:

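    {-# LANGUAGE FlexibleContexts #-}
    import Data.Functor.Foldable (ListF(..), Base, Corecursive)
    import Data.Functor.Foldable.Exotic (anaM)   -- from recursion-schemes-ext
    import System.Random
    
    -- Mock of the conditional distribution: with probability 1/2001 stop
    -- (Nothing); otherwise pick a uniformly random character from st.
    s :: String -> IO (Maybe Char)
    s st = do
      continue <- getStdRandom (randomR (0, 2000 :: Int))
      if continue /= 0
        then Just . (st !!) <$> getStdRandom (randomR (0, length st - 1))
        else return Nothing
    
    -- Monadic anamorphism: unfold a seed into any Corecursive t,
    -- running the coalgebra's effects at every layer.
    result :: (Corecursive t, Traversable (Base t), Monad m) => (String -> m (Base t String)) -> String -> m t
    result f = anaM f
    
    -- List coalgebra: Nil stops; Cons c st' emits c and continues from st'.
    -- The context is extended by prepending (c:st), which is O(1);
    -- an order-sensitive s would need st ++ [c] instead.
    example :: String -> IO (Base String String)
    example st = maybe Nil (\c -> Cons c (c:st)) <$> s st
    
    final :: IO String
    final = result example "asdf"
    
    main :: IO ()
    main = final >>= print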
    

    A couple of notes

    1. I mocked out your s function, since I'm not familiar with monad-bayes
    2. Since our final list is built inside IO, anaM has to construct it strictly: nothing is returned until the whole list exists. This forces us to make a finite list, so I let s stop with probability 1/2001 at each step, which gives about 2000 characters on average.
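If laziness matters more than using IO directly, one alternative (a different design, not what the code above does) is to keep the randomness inside the seed: thread a pure generator state through a plain, non-monadic unfold, so the stream can be infinite and consumed on demand. This sketch uses Data.List.unfoldr, the list anamorphism, with a hypothetical toy linear congruential generator (step, pick) so it stays dependency-free; in practice you would thread a real generator such as StdGen through randomR instead.

```haskell
import Data.Bits (shiftR)
import Data.List (unfoldr)
import Data.Word (Word64)

-- Hypothetical toy generator (a 64-bit LCG; Word64 arithmetic wraps),
-- used only to keep the example self-contained. Not for serious sampling.
step :: Word64 -> Word64
step g = g * 6364136223846793005 + 1442695040888963407

-- Pick an index in [0, n) from the generator's high bits (requires n > 0).
pick :: Int -> Word64 -> Int
pick n g = fromIntegral ((g `shiftR` 33) `mod` fromIntegral n)

-- Pure coalgebra: the seed is (context so far, generator state). It never
-- returns Nothing, so the unfold is an infinite, lazily produced stream.
-- The context must be non-empty; like the answer, it prepends c.
coalg :: (String, Word64) -> Maybe (Char, (String, Word64))
coalg (st, g) =
  let g' = step g
      c  = st !! pick (length st) g'
  in Just (c, (c : st, g'))

lazyChars :: Word64 -> String -> String
lazyChars seed st = unfoldr coalg (st, seed)

main :: IO ()
main = putStrLn (take 50 (lazyChars 42 "asdf"))
```

Because nothing here runs in a monad, take decides how much of the stream is ever computed; the cost is threading the generator state by hand.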

    EDIT:

    Below is a modified version that confirms that other recursive structures (in this case, a binary tree) can be generated by the result function. Apart from the new Tree type and its instances, the only changes from the previous code are the type of final, the definition of example, and the stopping probability in s (1/2 per node instead of 1/2001).

    {-# LANGUAGE FlexibleContexts, TypeFamilies #-}
    import Data.Functor.Foldable (Base, Corecursive(..))
    import Data.Functor.Foldable.Exotic (anaM)   -- from recursion-schemes-ext
    import Data.Monoid ((<>))                    -- only needed on older GHCs
    import System.Random
    
    -- The recursive type and its base functor (one layer, children abstracted to b).
    data Tree a = Branch a (Tree a) (Tree a) | Leaf
      deriving (Show, Eq)
    data TreeF a b = BranchF a b b | LeafF
    
    type instance Base (Tree a) = TreeF a
    -- Not needed by anaM, but a natural instance to have.
    instance Functor Tree where
      fmap f (Branch a left right) = Branch (f a) (f <$> left) (f <$> right)
      fmap _ Leaf = Leaf
    instance Functor (TreeF a) where
      fmap f (BranchF a left right) = BranchF a (f left) (f right)
      fmap _ LeafF = LeafF
    -- embed glues one layer back onto already-built subtrees.
    instance Corecursive (Tree a) where
      embed LeafF = Leaf
      embed (BranchF a left right) = Branch a left right
    instance Foldable (TreeF a) where
      foldMap _ LeafF = mempty
      foldMap f (BranchF _ left right) = f left <> f right
    -- anaM needs Traversable to run the monadic step in each child position.
    instance Traversable (TreeF a) where
      traverse _ LeafF = pure LeafF
      traverse f (BranchF a left right) = BranchF a <$> f left <*> f right
    
    -- Mock: stop with probability 1/2 at each node. This is a critical
    -- branching process: it terminates with probability 1, but the expected
    -- tree size is unbounded, so some runs produce very large trees.
    s :: String -> IO (Maybe Char)
    s st = do
      continue <- getStdRandom (randomR (0, 1 :: Int))
      if continue /= 0
        then Just . (st !!) <$> getStdRandom (randomR (0, length st - 1))
        else return Nothing
    
    result :: (Corecursive t, Traversable (Base t), Monad m) => (String -> m (Base t String)) -> String -> m t
    result f = anaM f
    
    -- Tree coalgebra: LeafF stops; otherwise both children continue from c:st.
    example :: String -> IO (Base (Tree Char) String)
    example st = maybe LeafF (\c -> BranchF c (c:st) (c:st)) <$> s st
    
    final :: IO (Tree Char)
    final = result example "asdf"
    
    main :: IO ()
    main = final >>= print