Using State Monad turns all of my functions into monadic functions


I'm writing a cryptography library in Haskell to learn about cryptography and monads. (Not for real-world use!) The type of my primality-testing function is

prime :: (Integral a, Random a, RandomGen g) => a -> State g Bool

So as you can see, I use the State monad so that I don't have to thread the generator through all the time. Internally the prime function uses the Miller-Rabin test, which relies on random numbers, which is why the prime function must also rely on random numbers. It makes sense in a way, since the prime function only does a probabilistic test.
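For illustration, here is a minimal sketch of the same two random draws written with explicit generator threading and with State. The definition of randomR_st is an assumption (the question does not show it), and the dice-roll functions are made-up examples; the sketch assumes the random and mtl packages.

```haskell
import Control.Monad.State (State, state, evalState)
import System.Random (Random, RandomGen, randomR, mkStdGen)

-- Assumed definition of the question's helper: lift randomR into State.
randomR_st :: (Random a, RandomGen g) => (a, a) -> State g a
randomR_st range = state (randomR range)

-- Explicit threading: each call hands the next generator to the following one.
twoRollsExplicit :: RandomGen g => g -> ((Int, Int), g)
twoRollsExplicit g0 =
  let (a, g1) = randomR (1, 6) g0
      (b, g2) = randomR (1, 6) g1
  in ((a, b), g2)

-- The same computation with State doing the threading.
twoRollsState :: RandomGen g => State g (Int, Int)
twoRollsState = do
  a <- randomR_st (1, 6)
  b <- randomR_st (1, 6)
  return (a, b)

main :: IO ()
main = do
  let g = mkStdGen 123
  -- Both lines print the same pair, since the generator is threaded identically.
  print (fst (twoRollsExplicit g))
  print (evalState twoRollsState g)
```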

Just for reference, the entire prime function is below, but I don't think you need to read it.

-- | findDS n, for odd n, gives odd d and s >= 0 s.t. n-1 = 2^s*d.
findDS :: Integral a => a -> (a, a)
findDS n = findDS' (n-1) 0
  where
    findDS' q s
      | even q = findDS' (q `div` 2) (s+1)
      | odd  q = (q,s)

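-- | millerRabinOnce n d s a does one MR round test on
-- n using a, where n-1 = 2^s*d (see findDS). True means
-- "probably prime". powerModulo b e m is assumed to
-- compute b^e mod m (it is defined elsewhere in the library).
millerRabinOnce :: Integral a => a -> a -> a -> a -> Bool
millerRabinOnce n d s a
  | even n           = False
  | otherwise        = not (test1 && test2)
  where
    -- test1: a^d is not 1 (mod n)
    test1 = powerModulo a d n /= 1
    -- test2: no a^(2^t*d) is -1 (mod n) for 0 <= t < s
    test2 = all (\t -> powerModulo a ((2^t)*d) n /= n-1)
                [0..s-1]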

-- | millerRabin k n does k MR rounds testing n for primality.
millerRabin :: (RandomGen g, Random a, Integral a) =>
  a -> a -> State g Bool
millerRabin k n = go k
  where
    (d, s) = findDS n
    go 0 = return True
    go i = do
      rest <- go (i - 1)
      test <- randomR_st (1, n - 1)
      let this = millerRabinOnce n d s test
      return $ this && rest

-- | primeK k n. Probabilistic primality test of n
-- using k Miller-Rabin rounds.
primeK :: (Integral a, Random a, RandomGen g) => 
  a -> a -> State g Bool
primeK k n
  | n < 2            = return False
  | n == 2 || n == 3 = return True
  | otherwise        = millerRabin (min n k) n

-- | Probabilistic primality test with 64 Miller-Rabin rounds.
prime :: (Integral a, Random a, RandomGen g) => 
  a -> State g Bool
prime = primeK 64

The thing is, everywhere I need to use prime numbers, I have to turn that function into a monadic function too, even where there is seemingly no randomness involved. For example, below is my former function for recovering a secret in Shamir's Secret Sharing Scheme. A deterministic operation, right?

recover :: Integral a => [a] -> [a] -> a -> a
recover pi_s si_s q = sum prods `mod` q
  where
    bi_s  = map (beta pi_s q) pi_s
    prods = zipWith (*) bi_s si_s

Well, that was when I used a naive, deterministic primality test. I haven't rewritten the recover function yet, but I already know that the beta function relies on prime numbers, and hence it, and recover too, will need the monad. Both will have to go from simple non-monadic functions to monadic functions, even though the reason they use the State monad / randomness is buried really deep down.

I can't help but think that all the code becomes more complex now that it has to be monadic. Am I missing something or is this always the case in situations like these in Haskell?
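As a rough illustration of what that conversion costs, here is a sketch of a monadic recover. betaSt is a hypothetical placeholder, not the real beta: it only draws one random value (as a probabilistic primality check would) and returns its last argument unchanged, so the numbers it produces carry no cryptographic meaning. The sketch assumes the random and mtl packages.

```haskell
import Control.Monad.State (State, state, evalState)
import System.Random (RandomGen, randomR, mkStdGen)

-- Hypothetical stand-in for the question's beta: it uses the generator
-- once and then returns p itself as the "coefficient".
betaSt :: RandomGen g => [Integer] -> Integer -> Integer -> State g Integer
betaSt _ _ p = do
  _ <- state (randomR (0, 1 :: Int))
  return p

-- Monadic counterpart of recover: map becomes mapM, and the pure
-- arithmetic stays pure inside the final return.
recoverSt :: RandomGen g => [Integer] -> [Integer] -> Integer -> State g Integer
recoverSt pi_s si_s q = do
  bi_s <- mapM (betaSt pi_s q) pi_s
  return (sum (zipWith (*) bi_s si_s) `mod` q)

main :: IO ()
main = print (evalState (recoverSt [1, 2, 3] [4, 5, 6] 7) (mkStdGen 0))
-- prints 4, because (1*4 + 2*5 + 3*6) `mod` 7 == 4 with this placeholder
```

Only the line that calls the effectful helper changes shape; the rest of the function body is the same expression as before.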

One solution I could think of is

prime' n = evalState (prime n) (mkStdGen 123)

and use prime' instead. This solution raises two questions.

  1. Is this a bad idea? I don't think it's very elegant.
  2. Where should this "cut" from monadic to non-monadic code be? Because I also have functions like this genPrime:

genPrime :: (RandomGen g, Random a, Integral a) => a -> State g a
genPrime b = do
  n  <- randomR_st (2^(b-1),2^b-1)
  ps <- filterM prime [n..]
  return $ head ps

The question becomes whether to have the "cut" before or after genPrime and the like.
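To make the two placements of the "cut" concrete, here is a self-contained sketch. It does not reuse the library code above: powMod, decompose, mrRound and primeSt are simplified stand-ins written for this example, fixed to Integer, and the sketch assumes the random and mtl packages. isPrimeFixed cuts right at the primality test with a hard-coded seed, so every call reuses the same witnesses; the second line of main cuts at the top instead, where the caller supplies one generator for a whole batch.

```haskell
import Control.Monad (replicateM)
import Control.Monad.State (State, state, evalState)
import System.Random (RandomGen, randomR, mkStdGen)

-- Modular exponentiation by squaring (stand-in for powerModulo).
powMod :: Integer -> Integer -> Integer -> Integer
powMod _ 0 m = 1 `mod` m
powMod b e m
  | even e    = half * half `mod` m
  | otherwise = b * powMod b (e - 1) m `mod` m
  where half = powMod b (e `div` 2) m

-- Write n - 1 as 2^s * d with d odd.
decompose :: Integer -> (Integer, Integer)
decompose n = go (n - 1) 0
  where
    go q s
      | even q    = go (q `div` 2) (s + 1)
      | otherwise = (q, s)

-- One Miller-Rabin round with witness a; True means "probably prime".
mrRound :: Integer -> Integer -> Bool
mrRound n a = x == 1 || any (== n - 1) (take (fromIntegral s) (iterate sq x))
  where
    (d, s) = decompose n
    x      = powMod a d n
    sq y   = y * y `mod` n

-- k rounds with random witnesses, inside State.
primeSt :: RandomGen g => Int -> Integer -> State g Bool
primeSt k n
  | n < 2     = return False
  | n < 4     = return True
  | even n    = return False
  | otherwise = and <$> replicateM k (mrRound n <$> state (randomR (2, n - 2)))

-- Cut at the test itself: pure interface, fixed seed.
isPrimeFixed :: Integer -> Bool
isPrimeFixed n = evalState (primeSt 20 n) (mkStdGen 123)

main :: IO ()
main = do
  print (filter isPrimeFixed [1 .. 30])
  -- Cut at the top: one generator threaded through the whole batch.
  print (evalState (mapM (primeSt 20) [97, 99, 101]) (mkStdGen 42))
```

With the fixed seed, callers such as recover stay pure, at the price that the test is no longer randomized between calls. With the cut at the top, everything in between stays in State.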


Solution

  • That is indeed a valid criticism of monads as they are implemented in Haskell. In the short term, I don't see a better solution than the ones you mention. Switching all the code to monadic style is probably the most robust option, even though it is more heavyweight than the natural style. Porting a large codebase can indeed be a pain, although it may pay off later if you want to add more external effects.

    I think algebraic effects can solve this elegantly.

    In languages with algebraic effects, all functions are annotated with their effects, a -> eff b. However, contrary to Haskell, they can all be composed simply like pure functions a -> b (which are thus a special case of effectful functions, with an empty effect signature). The language then ensures that effects form a semilattice so that functions with different effects can be composed.

    It seems difficult to have such a system in Haskell. Free(r) monad libraries allow composing types of effects in a similar way, but still require the explicit monadic style at the term level. One interesting idea would be to overload function application, so that it can be implicitly changed to (>>=), but a principled way to do so eludes me. The main issue is that a function a -> m b can be seen both as an effectful function with effects in m and codomain b, and as a pure function with codomain m b. How can we infer when to use ($) or (>>=)?
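    A small sketch of that ambiguity, using Maybe as the effect and a made-up function halve: both ways of applying it type-check, so the types alone do not say which one an overloaded application should pick.

```haskell
-- halve has the shape a -> m b with m = Maybe.
halve :: Int -> Maybe Int
halve n
  | even n    = Just (n `div` 2)
  | otherwise = Nothing

-- Read as a pure function with codomain Maybe Int: the result is nested.
asPure :: Maybe (Maybe Int)
asPure = fmap halve (Just 8)

-- Read as an effectful function with codomain Int: the effects are merged.
asEffect :: Maybe Int
asEffect = Just 8 >>= halve

main :: IO ()
main = do
  print asPure    -- Just (Just 4)
  print asEffect  -- Just 4
```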

    In the particular case of randomness, I once had a somewhat related idea involving splittable random generators (shameless plug): https://blog.poisson.chat/posts/2017-03-04-splittable-generators.html
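    As a hedged illustration of the general idea of splitting (not necessarily the approach taken in the linked post), the sketch below uses split from the random package, which newer versions of the package mark as deprecated. pairOfRolls is a made-up example: it takes a generator and gives each subcomputation its own half, so it never has to return an updated generator.

```haskell
import System.Random (StdGen, mkStdGen, randomR, split)

-- Each draw gets an independent generator obtained by splitting the input;
-- no state has to be threaded out of the function.
pairOfRolls :: StdGen -> (Int, Int)
pairOfRolls g = (fst (randomR (1, 6) g1), fst (randomR (1, 6) g2))
  where
    (g1, g2) = split g

main :: IO ()
main = print (pairOfRolls (mkStdGen 123))
```

    The function still needs a generator argument, but its result type stays free of State.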