Tags: python, python-3.x, generator, sicp, lazy-sequences

Creating a generator for primes in the style of SICP


In an upcoming course, I'll be taking a subject that uses Python with an emphasis on sequences, generators and that kind of thing.

I've been working through an exercise list to practice these topics, and I'm stuck on an exercise that asks for a prime generator. Up until now I haven't used Python very much, but I've read and done most of the exercises in SICP. There, they present the following program, which uses the sieve of Eratosthenes to generate a lazy list (a stream) of primes:

(define (sieve stream)
  (cons-stream
   (stream-car stream)
   (sieve (stream-filter
           (lambda (x)
             (not (divisible? x (stream-car stream))))
           (stream-cdr stream)))))

(define primes (sieve (integers-starting-from 2)))
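To make the Scheme semantics concrete for comparison with generators: `cons-stream` builds a pair whose second element is delayed (and, in SICP, memoized) until `stream-cdr` forces it. Below is a minimal Python sketch of that model under those assumptions; it is not a standard library facility, and the names `Stream`, `stream_filter` and `stream_take` are made up for illustration to mirror the SICP procedures.

```python
class Stream:
    """A pair whose second half is computed on demand and then cached,
    mimicking SICP's cons-stream / delay / force."""

    def __init__(self, car, cdr_thunk):
        self.car = car
        self._cdr_thunk = cdr_thunk  # zero-argument function producing the rest
        self._forced = False
        self._cdr = None

    @property
    def cdr(self):
        # Like SICP's memoized delay: evaluate the delayed tail only once.
        if not self._forced:
            self._cdr = self._cdr_thunk()
            self._forced = True
            self._cdr_thunk = None
        return self._cdr


def integers_starting_from(n):
    return Stream(n, lambda: integers_starting_from(n + 1))


def stream_filter(pred, s):
    # Skip rejected elements with a loop instead of recursion.
    while not pred(s.car):
        s = s.cdr
    return Stream(s.car, lambda: stream_filter(pred, s.cdr))


def sieve(s):
    # Keep the head, then sieve the tail with multiples of the head removed.
    return Stream(
        s.car,
        lambda: sieve(stream_filter(lambda x: x % s.car != 0, s.cdr)),
    )


def stream_take(s, k):
    """Return the first k elements of stream s as a list."""
    items = []
    for _ in range(k):
        items.append(s.car)
        s = s.cdr
    return items


primes = sieve(integers_starting_from(2))
print(stream_take(primes, 10))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

Like the generator versions discussed below, every prime adds one more filter layer that forcing has to pass through, so this model also eventually runs into Python's recursion limit.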

In Python, from what I have read, the closest thing is generators, so I tried translating it to the following:

import itertools
def sieve(seq):
    n = next(seq)
    yield n
    sieve(filter(lambda x: x % n != 0, seq))

def primes():
    return sieve(itertools.count(2))

print(list(itertools.islice(primes(),10)))

But it prints only [2]. I figure this is because the recursive call to sieve just creates a new generator that is then discarded, instead of running the function again as I first expected.
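That diagnosis can be checked directly: calling a generator function only builds a generator object, and its body does not execute until something iterates it. A minimal check (the function `noisy` is just an illustration):

```python
events = []


def noisy():
    events.append("body started")  # runs only once iteration begins
    yield 1


noisy()            # like the bare recursive call: a generator is created, then dropped
print(events)      # []
gen = noisy()      # still nothing has run
print(events)      # []
print(next(gen))   # the body runs up to the first yield and prints 1
print(events)      # ['body started']
```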

To try to remedy this, I tried using a loop instead:

def sieve(seq):
    def divisible(n):
        return lambda x: x % n != 0
    while True:
        n = next(seq)
        yield n
        seq = sieve(filter(divisible(n), seq))

This works insofar as I can generate the first 9 primes, but if I ask for the tenth, a RecursionError is raised.

So my question is: how can I improve this to be able to calculate larger primes?

PS: There is already a proposed implementation of a sieve generator in https://stackoverflow.com/a/568618/6571467, but it explicitly keeps track of the previous primes in the sieve, whereas in the lazy-list paradigm the objective is to abstract away the order in which the operations are actually executed.


Solution

  • For your first version, you can just use yield from to delegate to the recursive call:

    def sieve(seq):
        n = next(seq)
        yield n
        yield from sieve(filter(lambda x: x % n != 0, seq))
    

    (or for x in sieve(...): yield x in Python versions before 3.3, which lack yield from)
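For reference, the pre-3.3 form spelled out in full. For plain iteration like this it behaves the same as `yield from`; the difference is that `yield from` also forwards `send()`/`throw()` calls and the subgenerator's return value (PEP 380), which this loop does not.

```python
import itertools


def sieve(seq):
    n = next(seq)
    yield n
    # Plain-loop replacement for `yield from sieve(...)`: re-yield every value
    # produced by the recursive generator.
    for p in sieve(filter(lambda x: x % n != 0, seq)):
        yield p


print(list(itertools.islice(sieve(itertools.count(2)), 10)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```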

    For your looped version, remove the recursive call and just stack the filters (note the n=n default argument, which binds each lambda to the current prime):

    def sieve(seq):
        while True:
            n = next(seq)
            yield n
            seq = filter(lambda x, n=n: x % n != 0, seq)
    

    Both versions work for the first (almost) 1000 primes before they, too, hit a maximum-recursion error (even the loop, since it builds a chain of nested filter objects, one per prime). Raising the limit with sys.setrecursionlimit can postpone the error but not really prevent it; that requires not using recursion at all, or a language that supports tail-call optimization.
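The `n=n` default argument in the looped version is what makes the stacked filters correct. Without it, every lambda closes over the same variable `n`, which is only read when a filter is actually consumed (Python's late binding of closures), so all filters test against the latest prime. A small comparison sketch; the names `sieve_late`, `sieve_bound` and `take` are illustrative only:

```python
import itertools


def sieve_late(seq):
    # Same as the looped version but WITHOUT the n=n default argument.
    while True:
        n = next(seq)
        yield n
        seq = filter(lambda x: x % n != 0, seq)


def sieve_bound(seq):
    # The looped version from the answer: n=n freezes the current prime in each lambda.
    while True:
        n = next(seq)
        yield n
        seq = filter(lambda x, n=n: x % n != 0, seq)


def take(gen, k):
    return list(itertools.islice(gen, k))


print(take(sieve_late(itertools.count(2)), 10))   # [2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
print(take(sieve_bound(itertools.count(2)), 10))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```

In the late-binding version, each new candidate x is checked against n = x - 1 by every filter, and x % (x - 1) is 1, so every integer slips through.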

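    Alternatively, for a purely iterative solution, you can store the primes seen so far and check whether any of them divides the next candidate; only primes p with p*p <= n need checking, since a composite n always has a prime factor no larger than its square root. (Both recursive variants basically also store these primes, except that they hide them in the stack of nested filter calls.)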

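    from itertools import takewhile

    def sieve(seq):
        ps = []  # primes found so far
        while True:
            n = next(seq)
            # n is prime if no earlier prime p with p*p <= n divides it
            if not any(n % p == 0 for p in takewhile(lambda p: p*p <= n, ps)):
                yield n
                ps.append(n)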
    

    All three versions yield the same results, but the "recursion-less" ones are (much) faster:

    >>> %timeit primes(sieve1, 900)
    1 loop, best of 5: 297 ms per loop
    
    >>> %timeit primes(sieve2, 900)
    1 loop, best of 5: 185 ms per loop
    
    >>> %timeit primes(sieve3, 900)
    10 loops, best of 5: 35.4 ms per loop
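The timings call a `primes(sieve, k)` helper that is not shown. The sketch below assumes it simply takes the first k values of `sieve(count(2))`, and it assumes `sieve1`, `sieve2` and `sieve3` are the three versions above in order; both are guesses from the timing labels. It also checks the claim that all three agree, for the first 100 primes only, so the filter-based versions stay well within the recursion limit.

```python
from itertools import count, islice, takewhile


def sieve1(seq):
    # Recursive version using yield from.
    n = next(seq)
    yield n
    yield from sieve1(filter(lambda x: x % n != 0, seq))


def sieve2(seq):
    # Loop version with stacked filters.
    while True:
        n = next(seq)
        yield n
        seq = filter(lambda x, n=n: x % n != 0, seq)


def sieve3(seq):
    # Trial division by the primes found so far, up to sqrt(n).
    ps = []
    while True:
        n = next(seq)
        if not any(n % p == 0 for p in takewhile(lambda p: p * p <= n, ps)):
            yield n
            ps.append(n)


def primes(sieve, k):
    """Hypothetical helper matching the %timeit calls: first k primes from `sieve`."""
    return list(islice(sieve(count(2)), k))


results = [primes(s, 100) for s in (sieve1, sieve2, sieve3)]
print(results[0] == results[1] == results[2])  # True
print(results[2][-1])                          # 541, the 100th prime
```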
    

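    (Using n.__rmod__ instead of lambda x: x % n != 0 gives a nice boost to the filter-based versions, because n.__rmod__(x) computes x % n, which is truthy exactly when x is not divisible by n; they are still much slower, though.)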
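A sketch of the stacked-filter loop with the bound method as the predicate (the name `sieve2_rmod` is illustrative). Because each `n.__rmod__` is bound to that particular integer when it is created, this variant also sidesteps the late-binding issue without needing the `n=n` default.

```python
from itertools import count, islice


def sieve2_rmod(seq):
    # n.__rmod__(x) computes x % n: nonzero (truthy) exactly when n does not divide x,
    # so filter keeps only the candidates that survive division by n.
    while True:
        n = next(seq)
        yield n
        seq = filter(n.__rmod__, seq)


print(list(islice(sieve2_rmod(count(2)), 10)))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```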


    Addendum, about your second approach resulting in a recursion error even for very small values: I am still having trouble wrapping my head around this, but here is how I understand it. By doing seq = sieve(filter(divisible(n), seq)) instead of just seq = filter(divisible(n), seq), the "top-level" sieve gets the next value from the sieve one level below, and so on, and each of those adds another layer of sieves in each iteration, causing the height of the sieve stack to double in each iteration.
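To observe that growth rather than reason it through, one can instrument the question's looped version with a global counter (hypothetical, for illustration only) that records how many sieve generators start running while the first k primes are taken. Every sieve created in that version is immediately asked for a value, so "started" equals "created".

```python
import itertools

started = 0  # hypothetical instrumentation: number of sieve generators that began running


def divisible(n):
    # Same helper as in the question: despite its name, it keeps values NOT divisible by n.
    return lambda x: x % n != 0


def sieve(seq):
    # The question's looped version, unchanged apart from the counter.
    global started
    started += 1
    while True:
        n = next(seq)
        yield n
        seq = sieve(filter(divisible(n), seq))


def sieves_started(k):
    """How many sieve generators run while taking the first k primes."""
    global started
    started = 0
    list(itertools.islice(sieve(itertools.count(2)), k))
    return started


counts = [sieves_started(k) for k in range(1, 7)]
print(counts)  # begins 1, 2, 4, 8: the number of sieves doubles with each prime
```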