Tags: python-3.x, optimization, bit-manipulation, primes, sieve-of-eratosthenes

Speed up bitstring/bit operations in Python?


I wrote a prime number generator using the Sieve of Eratosthenes in Python 3.1. The code runs correctly and takes 0.32 seconds on ideone.com to generate the prime numbers up to 1,000,000.

# from bitstring import BitString

def prime_numbers(limit=1000000):
    '''Prime number generator. Yields the series
    2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...    
    using Sieve of Eratosthenes.
    '''
    yield 2
    sub_limit = int(limit**0.5) 
    flags = [False, False] + [True] * (limit - 2)   
#     flags = BitString(limit)
    # Step through all the odd numbers
    for i in range(3, limit, 2):       
        if flags[i] is False:
#        if flags[i] is True:
            continue
        yield i
        # Exclude further multiples of the current prime number
        if i <= sub_limit:
            for j in range(i*3, limit, i<<1):
                flags[j] = False
#                flags[j] = True

The problem is, I run out of memory when I try to generate numbers up to 1,000,000,000.

    flags = [False, False] + [True] * (limit - 2)   
MemoryError

As you can imagine, allocating 1 billion boolean values is really not feasible: each list entry in Python takes 4 or 8 bytes (see comment), not 1 byte. So I looked into bitstring, figuring that using 1 bit for each flag would be much more memory-efficient. However, the program's performance dropped drastically: the runtime went up to 24 seconds for primes up to 1,000,000. This is probably due to the internal implementation of bitstring.
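The 1-bit-per-flag idea can be sketched with only the standard library by packing flags into a `bytearray` and selecting bits with shifts and masks. `BitFlags` is a hypothetical helper written for illustration; it is not part of bitstring. Per-bit access from Python code is still slow, so this demonstrates the memory saving (about 125 MB for 1,000,000,000 flags), not a speedup.

```python
class BitFlags:
    """Hypothetical fixed-size bit array: 8 flags per byte of a bytearray."""

    def __init__(self, size):
        self.size = size
        # (size + 7) // 8 bytes hold `size` bits, all initially 0 (False).
        self.buf = bytearray((size + 7) // 8)

    def __getitem__(self, i):
        # Byte i >> 3 holds flag i; i & 7 selects the bit inside that byte.
        return (self.buf[i >> 3] >> (i & 7)) & 1 == 1

    def __setitem__(self, i, value):
        mask = 1 << (i & 7)
        if value:
            self.buf[i >> 3] |= mask
        else:
            # Clear the bit; & 0xFF keeps the operand inside the byte range.
            self.buf[i >> 3] &= ~mask & 0xFF


flags = BitFlags(1000)
flags[10] = True
flags[999] = True
print(flags[10], flags[11], flags[999])  # True False True
print(len(flags.buf))                     # 125 bytes for 1000 flags
```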

You can see what I changed to use BitString by uncommenting the import and swapping each of the three other commented-out lines in the snippet above for the line directly above it.

My question is, is there a way to speed up my program, with or without bitstring?

Edit: Please test the code yourself before posting. I can't accept answers that run slower than my existing code, naturally.

Edit again:

I've compiled a list of benchmarks on my machine.


Solution

  • There are a couple of small optimizations for your version. By reversing the roles of True and False, you can change "if flags[i] is False:" to "if flags[i]:". And the starting value for the second range statement can be i*i instead of i*3. Your original version takes 0.166 seconds on my system. With those changes, the version below takes 0.156 seconds on my system.

    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        yield 2
        sub_limit = int(limit**0.5)
        flags = [True, True] + [False] * (limit - 2)
        # Step through all the odd numbers
        for i in range(3, limit, 2):
            if flags[i]:
                continue
            yield i
            # Exclude further multiples of the current prime number
            if i <= sub_limit:
                for j in range(i*i, limit, i<<1):
                    flags[j] = True
    

    This doesn't help your memory issue, though.
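The claim that crossing off can start at i*i instead of i*3 can be checked directly. `sieve_flags` is a hypothetical harness that reruns the question's sieve with the inverted flags and switches only the first multiple crossed off. Both variants should agree, because every odd multiple i*k with 3 <= k < i has a prime factor smaller than i and was therefore already crossed off.

```python
def sieve_flags(limit, start_at_square):
    """Question's sieve with inverted flags (True = composite); primes < limit.

    start_at_square selects the first multiple crossed off: i*i or i*3.
    """
    flags = [True, True] + [False] * (limit - 2)
    sub_limit = int(limit ** 0.5)
    primes = [2]
    for i in range(3, limit, 2):
        if flags[i]:
            continue
        primes.append(i)
        if i <= sub_limit:
            first = i * i if start_at_square else i * 3
            # A step of 2*i skips even multiples, which the loop never visits.
            for j in range(first, limit, i << 1):
                flags[j] = True
    return primes


assert sieve_flags(10000, True) == sieve_flags(10000, False)
print(sieve_flags(50, True))
# [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47]
```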

    Moving into the world of C extensions, I used the development version of gmpy. (Disclaimer: I'm one of the maintainers.) The development version is called gmpy2 and supports mutable integers called xmpz. Using gmpy2 and the following code, I have a running time of 0.140 seconds. Running time for a limit of 1,000,000,000 is 158 seconds.

    import gmpy2
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        yield 2
        sub_limit = int(limit**0.5)
        # Actual number is 2*bit_position + 1.
        oddnums = gmpy2.xmpz(1)
        current = 0
        while True:
            current += 1
            current = oddnums.bit_scan0(current)
            prime = 2 * current + 1
            if prime > limit:
                break
            yield prime
            # Exclude further multiples of the current prime number
            if prime <= sub_limit:
                for j in range(2*current*(current+1), limit>>1, prime):
                    oddnums.bit_set(j)
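The xmpz versions store only odd numbers: position c stands for 2*c + 1. The start index 2*current*(current+1) is p*p translated into that index space, since for p = 2c + 1 we have (p*p - 1) / 2 = 2c(c + 1), and advancing p positions advances 2p in value, so even multiples are skipped. The sketch below checks that mapping with a plain `bytearray` so it runs without gmpy2; `odd_only_primes` is a hypothetical name, and unlike the code above it includes `limit` itself.

```python
def odd_only_primes(limit):
    """Primes <= limit, one byte per odd number (index c <-> value 2*c + 1)."""
    if limit < 2:
        return []
    size = (limit - 1) // 2 + 1      # indices for the odd numbers 1, 3, ..., <= limit
    composite = bytearray(size)
    composite[0] = 1                  # index 0 is the number 1, which is not prime
    primes = [2]
    for c in range(1, size):
        if composite[c]:
            continue
        p = 2 * c + 1
        primes.append(p)
        start = 2 * c * (c + 1)       # index of p*p
        assert 2 * start + 1 == p * p
        # Moving p indices forward moves 2*p in value: odd multiples only.
        for j in range(start, size, p):
            composite[j] = 1
    return primes


print(odd_only_primes(30))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```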
    

    Pushing optimizations, and sacrificing clarity, I get running times of 0.107 and 123 seconds with the following code:

    import gmpy2
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        yield 2
        sub_limit = int(limit**0.5)
        # Actual number is 2*bit_position + 1.
        oddnums = gmpy2.xmpz(1)
        f_set = oddnums.bit_set
        f_scan0 = oddnums.bit_scan0
        current = 0
        while True:
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
            if prime > limit:
                break
            yield prime
            # Exclude further multiples of the current prime number
            if prime <= sub_limit:
                list(map(f_set,range(2*current*(current+1), limit>>1, prime)))
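Two general CPython tricks are at work in this version: binding a method such as `oddnums.bit_set` to a local name once, so the attribute lookup is not repeated, and letting `map` drive the calls from C. `list(map(...))` builds a throwaway list of `None` results; `collections.deque(..., maxlen=0)` is a common idiom for draining an iterator without storing anything. The sketch shows both on a plain `set` so gmpy2 is not needed; `mark_loop` and `mark_map` are hypothetical names.

```python
from collections import deque


def mark_loop(marked, start, stop, step):
    # marked.add is looked up again on every iteration.
    for j in range(start, stop, step):
        marked.add(j)


def mark_map(marked, start, stop, step):
    # Look up the bound method once; map then calls it for each value.
    add = marked.add
    # A deque with maxlen=0 consumes the iterator and discards the None results.
    deque(map(add, range(start, stop, step)), maxlen=0)


a, b = set(), set()
mark_loop(a, 9, 100, 6)
mark_map(b, 9, 100, 6)
print(a == b, sorted(b)[:4])  # True [9, 15, 21, 27]
```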
    

    Edit: Based on this exercise, I modified gmpy2 to accept xmpz.bit_set(iterator). Using the following code, the run time for all primes less than 1,000,000,000 is 56 seconds with Python 2.7 and 74 seconds with Python 3.2. (As noted in the comments, xrange is faster than range in Python 2.)

    import gmpy2
    
    try:
        range = xrange
    except NameError:
        pass
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        yield 2
        sub_limit = int(limit**0.5)
        oddnums = gmpy2.xmpz(1)
        f_scan0 = oddnums.bit_scan0
        current = 0
        while True:
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
            if prime > limit:
                break
            yield prime
            if prime <= sub_limit:
                oddnums.bit_set(iter(range(2*current*(current+1), limit>>1, prime)))
    

    Edit #2: One more try! I modified gmpy2 to accept xmpz.bit_set(slice). Using the following code, the run time for all primes less than 1,000,000,000 is about 40 seconds with both Python 2.7 and Python 3.2.

    from __future__ import print_function
    import time
    import gmpy2
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        yield 2
        sub_limit = int(limit**0.5)
        flags = gmpy2.xmpz(1)
        # pre-allocate the total length
        flags.bit_set((limit>>1)+1)
        f_scan0 = flags.bit_scan0
        current = 0
        while True:
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
            if prime > limit:
                break
            yield prime
            if prime <= sub_limit:
                flags.bit_set(slice(2*current*(current+1), limit>>1, prime))
    
    start = time.time()
    result = list(prime_numbers(1000000000))
    print(time.time() - start)
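A comparable effect is available from the standard library alone: slice assignment on a `bytearray` crosses off a whole arithmetic progression in one C-level operation. This sketch keeps the odd-only layout (index c stands for 2*c + 1) but spends one byte per odd number instead of one bit, so a limit of 1,000,000,000 still needs about 500 MB for the flags. It has not been timed against the numbers above; `primes_bytearray` is a hypothetical name.

```python
def primes_bytearray(limit):
    """Return a list of primes <= limit using bytearray slice assignment.

    Index c represents the odd number 2*c + 1.
    """
    if limit < 2:
        return []
    size = (limit - 1) // 2 + 1
    is_prime = bytearray([1]) * size
    is_prime[0] = 0                       # the number 1
    c = 1
    while (2 * c + 1) ** 2 <= limit:
        if is_prime[c]:
            p = 2 * c + 1
            start = 2 * c * (c + 1)       # index of p*p
            # len(range(...)) counts the slots of the extended slice exactly.
            is_prime[start::p] = bytes(len(range(start, size, p)))
        c += 1
    # Convert surviving indices back to values.
    return [2] + [2 * i + 1 for i, flag in enumerate(is_prime) if flag]


print(primes_bytearray(30))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```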
    

    Edit #3: I've updated gmpy2 to properly support slicing at the bit level of an xmpz. There's no change in performance, but the API is much nicer. With a little tweaking, I've got the time down to about 37 seconds. (See Edit #4 for changes in gmpy2 2.0.0b1.)

    from __future__ import print_function
    import time
    import gmpy2
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        sub_limit = int(limit**0.5)
        flags = gmpy2.xmpz(1)
        flags[(limit>>1)+1] = True
        f_scan0 = flags.bit_scan0
        current = 0
        prime = 2
        while prime <= sub_limit:
            yield prime
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
            flags[2*current*(current+1):limit>>1:prime] = True
        while prime <= limit:
            yield prime
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
    
    start = time.time()
    result = list(prime_numbers(1000000000))
    print(time.time() - start)
    

    Edit #4: I made some changes in gmpy2 2.0.0b1 that break the previous example. gmpy2 no longer treats True as a special value that provides an infinite source of 1-bits; use -1 instead.

    from __future__ import print_function
    import time
    import gmpy2
    
    def prime_numbers(limit=1000000):
        '''Prime number generator. Yields the series
        2, 3, 5, 7, 11, 13, 17, 19, 23, 29 ...
        using Sieve of Eratosthenes.
        '''
        sub_limit = int(limit**0.5)
        flags = gmpy2.xmpz(1)
        flags[(limit>>1)+1] = 1
        f_scan0 = flags.bit_scan0
        current = 0
        prime = 2
        while prime <= sub_limit:
            yield prime
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
            flags[2*current*(current+1):limit>>1:prime] = -1
        while prime <= limit:
            yield prime
            current += 1
            current = f_scan0(current)
            prime = 2 * current + 1
    
    start = time.time()
    result = list(prime_numbers(1000000000))
    print(time.time() - start)
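Why -1 works as "all ones" follows from Python's integer model: integers behave like two's-complement bit strings with infinitely many leading sign bits, so every bit of -1 is 1, while True is just the integer 1, with only bit 0 set. A quick check in plain Python, with no gmpy2 involved:

```python
# Python ints act like infinite two's-complement bit strings.
print(all((-1 >> k) & 1 for k in range(10000)))  # True: every bit of -1 is 1
print(int(True), bin(True))                       # 1 0b1: only bit 0 is set

# Masking -1 down to n bits leaves n one-bits.
n = 8
print(bin(-1 & ((1 << n) - 1)))                   # 0b11111111
```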
    

    Edit #5: I've made some enhancements to gmpy2 2.0.0b2. You can now iterate over all the bits that are either set or clear. Running time has improved by ~30%.

    from __future__ import print_function
    import time
    import gmpy2
    
    def sieve(limit=1000000):
        '''Returns a generator that yields the prime numbers up to limit.'''
    
        # Increment by 1 to account for the fact that slices do not include
        # the last index value but we do want to include the last value for
        # calculating a list of primes.
        sieve_limit = gmpy2.isqrt(limit) + 1
        limit += 1
    
        # Mark bit positions 0 and 1 as not prime.
        bitmap = gmpy2.xmpz(3)
    
        # Process 2 separately. This allows us to use p+p for the step size
        # when sieving the remaining primes.
        bitmap[4 : limit : 2] = -1
    
        # Sieve the remaining primes.
        for p in bitmap.iter_clear(3, sieve_limit):
            bitmap[p*p : limit : p+p] = -1
    
        return bitmap.iter_clear(2, limit)
    
    if __name__ == "__main__":
        start = time.time()
        result = list(sieve(1000000000))
        print(time.time() - start)
        print(len(result))
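`iter_clear(start, stop)` yields the positions of the clear bits in a range, which the earlier versions did by hand with repeated `bit_scan0` calls. A rough standard-library analog of that scan is `bytearray.find`, which accepts an integer byte value (Python 3.3+) and searches in C. In this sketch `iter_clear` is a hypothetical pure-Python helper, not the gmpy2 method, and `sieve_bytes` mirrors the structure of `sieve()` (mark 0 and 1, handle 2 separately, step `p+p` from `p*p`, include `limit`) using one byte per number.

```python
from math import isqrt  # Python 3.8+; plays the role of gmpy2.isqrt here


def iter_clear(buf, start, stop):
    """Yield indices in [start, stop) where buf holds 0 (cf. xmpz.iter_clear)."""
    pos = buf.find(0, start, stop)
    while pos != -1:
        yield pos
        pos = buf.find(0, pos + 1, stop)


def sieve_bytes(limit=1000000):
    '''Return a generator over the primes <= limit (0 = prime, 1 = composite).'''
    sieve_limit = isqrt(limit) + 1
    limit += 1                           # slices exclude the stop index
    bitmap = bytearray(limit)
    bitmap[0:2] = b'\x01\x01'            # 0 and 1 are not prime
    # Process 2 separately so the remaining primes can step by p+p.
    bitmap[4:limit:2] = b'\x01' * len(range(4, limit, 2))
    for p in iter_clear(bitmap, 3, sieve_limit):
        bitmap[p*p:limit:p+p] = b'\x01' * len(range(p*p, limit, p+p))
    return iter_clear(bitmap, 2, limit)


print(list(sieve_bytes(30)))  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```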