
Understanding Haskell code which applies the `tails` function to an infinite list within a list comprehension


After submitting my solution to Project Euler's problem 50 earlier today, I was scrolling through the problem's forums, taking a look at other folks' solutions/execution times.

After a while, I started to feel quite proud of my code, which solved it in ~3 seconds (it used the Primes library and was compiled with -O2)...

...and then I came across the code below, which solved it in ~0.05 seconds... in interpreted mode (i.e., in GHCi).

Can someone explain how/why the code below solves this particular problem?

The mind-twisting part is the application of the tails function to the infinite list of primes (primes) within a list comprehension. I'm having a hard time understanding how we guarantee that we look at all possible sublists of consecutive primes and not just those generated by tails.

(My usual strategy of trying bits and pieces of code in ghci doesn't work in this situation because primes is infinite...)

The problem: we're asked to find the prime below 1,000,000 that can be written as the sum of the most consecutive primes. For example, below 100 the answer is 41 = 2 + 3 + 5 + 7 + 11 + 13: no other prime below 100 is the sum of six or more consecutive primes.
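To pin the goal down, here is a deliberately slow restatement of the problem that can be used to check faster solutions on small limits. It is a sketch, not the code under discussion: bruteForce, primesTD and isPrimeTD are hypothetical names, and the trial-division primes stand in for Data.Numbers.Primes so the block needs only base. Every run of consecutive primes is a non-empty prefix of some suffix of the prime list, which is exactly what tails and inits enumerate.

```haskell
import Data.List (inits, maximumBy, tails)
import Data.Ord (comparing)

-- Trial-division stand-ins for Data.Numbers.Primes (slow, but dependency-free).
primesTD :: [Int]
primesTD = 2 : filter isPrimeTD [3, 5 ..]

isPrimeTD :: Int -> Bool
isPrimeTD k
  | k < 2     = False
  | otherwise = all (\p -> k `mod` p /= 0) (takeWhile (\p -> p * p <= k) primesTD)

-- Among all runs of consecutive primes whose sum is a prime below n, return
-- (number of terms, sum) of the longest one. Every run is checked explicitly,
-- so this is only practical for small n (and it assumes n > 2).
bruteForce :: Int -> (Int, Int)
bruteForce n = maximumBy (comparing fst)
  [ (length run, s)
  | suffix <- tails ps             -- where the run starts
  , run    <- tail (inits suffix)  -- every non-empty prefix of that suffix
  , let s = sum run
  , s < n
  , isPrimeTD s ]
  where
    ps = takeWhile (< n) primesTD  -- no term of a valid run can reach n
```

bruteForce 100 gives (6,41), matching the example above, and bruteForce 1000 gives (21,953).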

import Data.List (tails)
import Data.Numbers.Primes

under n xs = takeWhile (< n) xs

takeUntil p xs = foldr (\x r-> if p x then [x] else x:r) [] xs

res :: [((Int, Int), (Int, Int))] 
-- ((top_length, sums_to), (total_length, starting_prime))

res = [(r,(length s,x)) | (x:xs) <- tails primes
                        , let s = zip [1..]
                                $ under 100
                                $ scanl (+) x xs
                        , let r = ...] 

main = mapM_ print $ takeUntil ...

Solution

  • In general, there's no problem with taking tails of primes, because Haskell is lazy (the evaluation is on-demand).

    Here specifically there's even less of a problem, because each sublist is trimmed with under 100 = takeWhile (< 100): only a finite prefix is taken.

    (x:xs) <- tails primes just goes through all suffixes of primes, i.e. the primes starting from x = 2, then the primes starting from x = 3, then 5, 7, 11, and so on. The pattern only demands the head element x and the tail xs of each suffix, and even their values aren't requested right away: only the so-called "spine" of the primes list is forced, one cons cell at a time. (Of course, making sure primes starts with at least one element x will compute the actual value of x, but that's another matter.)

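    So (x:xs) <- tails primes operates on consecutive suffixes of the primes list. Try tails [1..10] to see what's going on there. This is also why no run of consecutive primes is missed: every such run is a prefix of the suffix that starts at its first prime, and scanl (+) x xs computes the sum of every prefix of that suffix.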
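A small sketch of what tails produces, using short lists instead of primes (the binding names are made up for illustration). The last binding shows that the generator pattern (x:_) only forces the cons cells it reaches: the list is undefined after its third element, yet nothing fails. Wrapping an infinite list in take like this is also a practical way to inspect primes in GHCi.

```haskell
import Data.List (tails)

-- Every suffix of a finite list, ending with the empty list.
finiteTails :: [[Int]]
finiteTails = tails [1 .. 4]
-- [[1,2,3,4],[2,3,4],[3,4],[4],[]]

-- An infinite list is fine as long as only a finite part is demanded.
infiniteTails :: [[Int]]
infiniteTails = map (take 3) (take 4 (tails [1 ..]))
-- [[1,2,3],[2,3,4],[3,4,5],[4,5,6]]

-- Same generator shape as (x:xs) <- tails primes; the undefined tail is never forced.
heads :: [Int]
heads = [x | (x:_) <- take 3 (tails (1 : 2 : 3 : undefined))]
-- [1,2,3]
```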


    When you type

    GHCi> primes

    at the GHCi prompt, you're actually requesting all the elements of the primes list to be printed, so the output never ends. But

    GHCi> under 100 primes

    will only request those below 100, and no more than one element after that. It uses the built-in takeWhile, which examines the elements of primes one after another until it finds one that fails the predicate (here, the first one that is not less than 100, namely 101). The terminating element is not included in the result.
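A quick check of that behaviour, using the question's definition of under (the type signature is added here; small and stopsEarly are illustrative names). The second binding shows that takeWhile inspects exactly one element past its result: the 4 is examined and rejected, and the undefined tail behind it is never touched.

```haskell
-- `under` as defined in the question, with an explicit signature.
under :: Ord a => a -> [a] -> [a]
under n xs = takeWhile (< n) xs

-- Stops at the first element that fails the test, even on an infinite list.
small :: [Int]
small = under 20 [1, 3 ..]                        -- [1,3,5,7,9,11,13,15,17,19]

-- Looks at one element past the result and no further.
stopsEarly :: [Int]
stopsEarly = under 4 (1 : 2 : 3 : 4 : undefined)  -- [1,2,3]
```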

    The user-defined takeUntil differs from that only in that it also includes the terminating element in its result (and the meaning of the predicate is flipped: it signals when to stop).
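The difference side by side, with the question's takeUntil. Because the folding function returns [x] without looking at r once p x holds, foldr never goes past that element, so takeUntil is safe on infinite lists. takeUntil' is a hypothetical alternative formulation via break, included only to make the "keep the stopping element" behaviour explicit; it should agree with takeUntil.

```haskell
-- `takeUntil` as defined in the question.
takeUntil :: (a -> Bool) -> [a] -> [a]
takeUntil p xs = foldr (\x r -> if p x then [x] else x : r) [] xs

-- takeWhile drops the element that stops it; takeUntil keeps it.
before, upTo :: [Int]
before = takeWhile (< 4) [1 ..]   -- [1,2,3]
upTo   = takeUntil (>= 4) [1 ..]  -- [1,2,3,4]

-- Same behaviour expressed with break: everything before the first element
-- satisfying p, plus that element (if there is one).
takeUntil' :: (a -> Bool) -> [a] -> [a]
takeUntil' p xs = let (a, b) = break p xs in a ++ take 1 b
```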

    scanl (+) is the usual way to compute the sequence of partial sums of a list:

    Prelude> scanl (+) 1 [2..10]
    [1,3,6,10,15,21,28,36,45,55]
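Applied to the primes starting at x = 2 (written out by hand here), the k-th partial sum is the sum of the first k primes of that suffix, and zip [1 ..] labels each sum with that k. The binding names are illustrative only.

```haskell
-- Partial sums of the consecutive primes starting at 2.
sumsFrom2 :: [Int]
sumsFrom2 = scanl (+) 2 [3, 5, 7, 11, 13]   -- [2,5,10,17,28,41]

-- Each sum paired with the number of primes in it: (6,41) means
-- 41 is the sum of 6 consecutive primes.
labelled :: [(Int, Int)]
labelled = zip [1 ..] sumsFrom2
-- [(1,2),(2,5),(3,10),(4,17),(5,28),(6,41)]

-- scanl is lazy, so it also works on an infinite list.
lazySums :: [Int]
lazySums = take 5 (scanl (+) 1 [2 ..])      -- [1,3,6,10,15]
```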
    

    The code

     [ (r,(length s,x)) | (x:xs) <- tails primes
                        , let s = zip [1..]
                                $ under 100
                                $ scanl (+) x xs
                        , let r = last $ filter (isPrime.snd) s] 
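To see what res actually contains, the comprehension can be run as-is once primes and isPrime are defined. The trial-division definitions below are stand-ins assumed to agree with Data.Numbers.Primes on these small numbers; they are not the library's code. Only a prefix of res may be forced: for a starting prime of 100 or more, s is empty and last fails.

```haskell
import Data.List (tails)

-- Trial-division stand-ins for Data.Numbers.Primes.
primes :: [Int]
primes = 2 : filter isPrime [3, 5 ..]

isPrime :: Int -> Bool
isPrime k = k > 1 && all (\p -> k `mod` p /= 0) (takeWhile (\p -> p * p <= k) primes)

under :: Int -> [Int] -> [Int]
under n xs = takeWhile (< n) xs

-- The comprehension from the answer, unchanged apart from spacing.
-- Caution: forcing r for a starting prime >= 100 calls last on an empty list.
res :: [((Int, Int), (Int, Int))]
-- ((top_length, sums_to), (total_length, starting_prime))
res = [ (r, (length s, x)) | (x:xs) <- tails primes
                           , let s = zip [1 ..]
                                   $ under 100
                                   $ scanl (+) x xs
                           , let r = last $ filter (isPrime . snd) s ]
```

take 5 res evaluates to [((6,41),(8,2)),((1,3),(8,3)),((5,53),(7,5)),((5,67),(6,7)),((5,83),(5,11))]: starting from 2, eight partial sums stay below 100 and the longest prime one is 41 after six terms.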
    

    means:

    for each suffix of primes        -- [[2,3,5,7...],[3,5,7,11...],...]
      let 
        x  = head suffix             -- for all primes from x
        xs = tail suffix             --   i.e. xs
        t1 = scanl (+) x xs          -- calculate partial sums
        t2 = under 100 t1            -- stopping when the sum reaches 100
        s  = zip [1..] t2            -- index each sum by the length of subsequence,
        t3 = filter (isPrime.snd) s  -- keep only the sums that are prime,
        r  = last t3                 -- and take the one created from the
      in                             --    longest subsequence, starting from x
        emit (r,(length s,x))        -- and collect the data
    

    So we get a list of entries ((top_length, sums_to), (total_length, starting_prime)), one for each starting_prime = 2, 3, 5, 7, 11, ...

    The takeUntil expression in main determines when it's okay to stop, as there's no possibility of improving the result anymore.
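The question elides main's stopping predicate, so here is one hypothetical stopping rule, not necessarily the one the original code uses. It relies on total_length never increasing as the starting prime grows: the k-th prime of each later suffix is at least the k-th prime of the earlier one, so its partial sums are pointwise no smaller and fewer of them stay below the limit. Once total_length is no larger than the best top_length found so far, no later entry can win. The names entries and answer are made up for this sketch, and trial division again stands in for Data.Numbers.Primes.

```haskell
import Data.List (tails)

-- Trial-division stand-ins for Data.Numbers.Primes.
primes :: [Int]
primes = 2 : filter isPrime [3, 5 ..]

isPrime :: Int -> Bool
isPrime k = k > 1 && all (\p -> k `mod` p /= 0) (takeWhile (\p -> p * p <= k) primes)

-- Same entries as res, with the limit n as a parameter.
entries :: Int -> [((Int, Int), (Int, Int))]
entries n =
  [ (r, (length s, x)) | (x:xs) <- tails primes
                       , let s = zip [1 ..] (takeWhile (< n) (scanl (+) x xs))
                       , let r = last (filter (isPrime . snd) s) ]

-- Walk the entries, keeping the best (top_length, sums_to) so far, and stop as
-- soon as total_length cannot exceed it. r is only forced for entries with
-- total_length >= 1, so last never sees an empty list.
answer :: Int -> (Int, Int)
answer n = go (0, 0) (entries n)
  where
    go best ((r, (total, _)) : rest)
      | total <= fst best = best
      | otherwise         = go (max best r) rest
    go best [] = best
```

answer 100 evaluates to (6,41) and answer 1000 to (21,953); comparing tuples with max prefers the longer run and, on a tie, the larger sum.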