
Speed up calculation of partitions in Haskell


I'm trying to solve Euler problem 78, which basically asks for the first number where the partition function p(n) is divisible by 1000000.

I use Euler's recursive formula based on pentagonal numbers (calculated here in pents, together with the proper sign). Here is my code:

ps = 1 : map p [1..] where
  p n = sum $ map getP $ takeWhile ((<= n).fst) pents where
    getP (pent,sign) = sign * (ps !! (n-pent)) 

pents = zip (map (\n -> (3*n-1)*n `div` 2) $ [1..] >>= (\x -> [x,-x]))
            (cycle [1,1,-1,-1])
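
For reference, the recurrence that ps is meant to implement is Euler's pentagonal number recurrence, written here in LaTeX with the usual conventions p(0) = 1 and p(m) = 0 for m < 0:

```latex
p(n) = \sum_{k \ge 1} (-1)^{k+1}
       \left[ p\!\left(n - \tfrac{k(3k-1)}{2}\right)
            + p\!\left(n - \tfrac{k(3k+1)}{2}\right) \right]
```

Each (pent, sign) pair in pents corresponds to one term of this sum: the offsets run 1, 2, 5, 7, ... with signs +, +, -, -, ..., and takeWhile cuts the sum off once the offset exceeds n.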

While ps seems to produce the correct results, it is too slow. Is there a way to speed the calculation up, or do I need a completely different approach?


Solution

  • xs !! n has linear complexity. You should use a data structure with logarithmic or constant-time access instead.
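
    As a rough illustration of the logarithmic-access option, here is a minimal sketch that builds the table bottom-up in a Data.IntMap (from the containers package). The name partitionTable and its signature are assumptions made for this example, not something from the question:

    ```haskell
    import qualified Data.IntMap.Strict as IM
    import Data.List (foldl')

    -- Generalized pentagonal numbers paired with their signs:
    -- (1,+1), (2,+1), (5,-1), (7,-1), ...
    pents :: [(Int, Integer)]
    pents = zip (map (\k -> (3*k - 1) * k `div` 2) ([1..] >>= \x -> [x, -x]))
                (cycle [1, 1, -1, -1])

    -- Hypothetical helper: table of p(0)..p(limit), built in increasing order
    -- so that each lookup costs O(log n) instead of the O(n) of (!!).
    partitionTable :: Int -> IM.IntMap Integer
    partitionTable limit = foldl' step (IM.singleton 0 1) [1 .. limit]
      where
        step table n = IM.insert n (p table n) table
        p table n = sum [ sign * (table IM.! (n - pent))
                        | (pent, sign) <- takeWhile ((<= n) . fst) pents ]
    ```

    Looking up partitionTable 100 IM.! 100 then gives p(100) without walking a list for every term of the sum.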

    Edit: here is a quick implementation I came up with by copying a similar one by augustss:

    psOpt x = psArr x
      where psCall 0 = 1
            psCall n = sum $ map getP $ takeWhile ((<= n).fst) pents where
              getP (pent,sign) = sign * (psArr (n-pent))
            psArr n = if n > ncache then psCall n else psCache ! n
            psCache = listArray (0,ncache) $ map psCall [0..ncache]
    

    In ghci, I observe no spectacular speedup over your list version. No luck!

    Edit: Indeed, with -O2 as suggested by Chris Kuklewicz, this solution is eight times faster than yours for n=5000. Combined with Hammar's insight of doing the sums modulo 10^6, I get a solution that is fast enough (it finds the hopefully correct answer in about 10 seconds on my machine):

    import Data.List (find)
    import Data.Array 
    
    ps = 1 : map p [1..] where
      p n = sum $ map getP $ takeWhile ((<= n).fst) pents where
        getP (pent,sign) = sign * (ps !! (n-pent)) 
    
    summod li = foldl (\a b -> (a + b) `mod` 10^6) 0 li
    
    ps' = 1 : map p [1..] where
      p n = summod $ map getP $ takeWhile ((<= n).fst) pents where
        getP (pent,sign) = sign * (ps' !! (n-pent))
    
    ncache = 1000000
    
    psCall 0 = 1
    psCall n = summod $ map getP $ takeWhile ((<= n).fst) pents
      where getP (pent,sign) = sign * (psArr (n-pent))
    psArr n = if n > ncache then psCall n else psCache ! n
    psCache = listArray (0,ncache) $ map psCall [0..ncache]
    
    -- pentagonal numbers are index offsets, so they are not reduced mod 10^6
    pents = zip (map (\n -> (3*n-1)*n `div` 2) $ [1..] >>= (\x -> [x,-x]))
                (cycle [1,1,-1,-1])
    

    (I broke the psCache abstraction, so you should use psArr instead of psOpt; this ensures that different calls to psArr reuse the same memoized array. This is useful when you write find ((== 0) . ...)... well, I thought it was better not to publish the complete solution.)

    Thanks to all for the additional advice.
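
The two ideas above (constant-time lookups through an array, and sums taken modulo 10^6) can be combined in one small self-contained module. This is only a sketch under some assumptions: the names modulus and partitionsMod are made up for this example, the values are fixed to Int rather than left to default to Integer, and the table size is passed in explicitly instead of using a global ncache:

```haskell
import Data.Array (Array, listArray, (!))
import Data.List (foldl')

-- Modulus for all sums; 10^6 matches the divisibility test in the problem.
modulus :: Int
modulus = 10 ^ (6 :: Int)

-- Generalized pentagonal numbers with their signs. They are used as index
-- offsets, so they are deliberately not reduced modulo anything.
pents :: [(Int, Int)]
pents = zip (map (\k -> (3*k - 1) * k `div` 2) ([1..] >>= \x -> [x, -x]))
            (cycle [1, 1, -1, -1])

-- Hypothetical helper: p(0)..p(limit) modulo `modulus`, stored in a lazy
-- array whose elements refer back to earlier elements of the same array.
partitionsMod :: Int -> Array Int Int
partitionsMod limit = table
  where
    table = listArray (0, limit) (map p [0 .. limit])
    p 0 = 1
    p n = foldl' addMod 0 [ sign * (table ! (n - pent))
                          | (pent, sign) <- takeWhile ((<= n) . fst) pents ]
    -- `mod` with a positive divisor never returns a negative result,
    -- so negative terms are handled correctly.
    addMod acc x = (acc + x) `mod` modulus
```

For example, partitionsMod 100 ! 100 evaluates to 569292, the last six digits of p(100) = 190569292. Searching the array for the first zero entry is left out here, in keeping with the answer's choice not to publish the complete solution; as noted above, compiling with -O2 matters for speed.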