python primes palindrome

Optimizing algorithm for finding palindromic primes in Python


I've been trying to solve a LeetCode problem that takes an input number (less than 10^8) and returns the next palindromic prime. The answer is guaranteed to exist and to be less than 2 * 10^8. My approach works for most numbers, but for certain inputs (like 9989900) the runtime increases significantly and LeetCode reports that the time limit was exceeded. Is that because the gap between palindromic primes is large in that range? This is the code I've written.

    import time

    start = time.time()


    def is_prime(num: int) -> bool:
        if num < 2:
            return False
        elif num == 2 or num == 3:
            return True
        if num % 6 != 1 and num % 6 != 5:
            return False
        else:
            for i in range(3, int(num ** 0.5) + 1, 2):
                if num % i == 0:
                    return False
            else:
                return True


    def is_palindrome(num: int) -> bool:
        return str(num) == str(num)[::-1]


    class Solution:
        def primePalindrome(self, N: int):
            if N == 1:
                return 2
            elif 8 <= N < 11:
                return 11

            elif is_prime(N) and is_palindrome(N):
                return N

            # To skip even numbers, take 2 cases, i.e., when N is even and when N is odd
            elif N % 2 == 0:
                for i in range(N + 1, 2 * 10 ** 8, 2):
                    if len(str(i)) % 2 == 0:  # Because even length palindromes are divisible by 11
                        i += 2
                    elif is_palindrome(i):
                        if is_prime(i):
                            return i
                        else:
                            continue

            else:
                for i in range(N, 2 * 10 ** 8, 2):
                    if len(str(i)) % 2 == 0:
                        i += 2
                    elif is_palindrome(i):
                        if is_prime(i):
                            return i
                        else:
                            continue


    obj = Solution()
    print(obj.primePalindrome(9989900))  # 100030001
    print(time.time() - start)  # 9 seconds

Is my solution slow because of too many loops and conditional statements? How do I reduce the runtime? Solving this without using any external libraries/packages would be preferable. Thank you.


Solution

  • Given that checking primes/palindromes sequentially isn't fast enough, I thought of this "number assembly" approach:

    Apart from 2 and 5, prime numbers can only end with the digits 1, 3, 7 or 9, so a palindromic prime can only begin with those digits as well. If we generate the palindromic digits between the first and last digit, we get far fewer numbers to check for primality.

    For example: 1xxxxxx1, 3xxxxxx3, 7xxxxxx7 and 9xxxxxx9

    These middle parts must also be palindromes, so only half of the digits need to be generated: 1xxxyyy1, where yyy is the mirror of xxx. For an odd-sized middle we get xxzyy, where the two halves share the center digit z (zyy is the mirror of xxz).

    Combining this with a sequential generation of the first/last digits and digits in the middle, we can get the next number after N. By generating the most significant digits sequentially (i.e. the xxx part) we are certain that the composed numbers will be generated in an increasing sequence.
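
The mirroring step can be illustrated on its own. This is a minimal sketch, and `mirror_odd` is a hypothetical helper name used only for illustration: it takes the left half of an odd-length palindrome (center digit included) and appends the reversed digits without repeating the center.

```python
def mirror_odd(left: int) -> int:
    """Build an odd-length palindrome from its left half (center digit included)."""
    s = str(left)
    # s[-2::-1] is s reversed without its last character, so the center is not doubled
    return int(s + s[-2::-1])


# Left halves taken in increasing order yield palindromes in increasing order.
print([mirror_odd(h) for h in range(100, 105)])
# [10001, 10101, 10201, 10301, 10401]
```

Because the most significant digits come from the left half, counting the left half upward is enough to visit the palindromes in ascending order.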

    def isPrime(n):
        return n==2 if n<3 or n%2==0 else all(n%d for d in range(3,int(n**0.5)+2,2))
    
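    def nextPalPrime(N):
        # returns the smallest palindromic prime strictly greater than N
        if N < 7:                                   # 2 and 5 do not end in 1, 3, 7 or 9
            return next(p for p in (2,3,5,7) if p>N)
        digits = list(map(int,str(N)))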
        while True:
            if digits[0] not in (1,3,7,9):              # advance first/last digits
                digits[0]  = [1,1,3,3,7,7,7,7,9,9][digits[0]]  
                digits[1:] = [0]*(len(digits)-1)
            digits[-1] = digits[0]
            midSize  = (len(digits)-1)//2
            midStart = int("".join(map(str,digits[1:1+midSize] or [0])))
            for middle in range(midStart,10**midSize):            # generate middle digits
                if midSize:
                    midDigits = list(map(int,f"{middle:0{midSize}}")) # left half of middle
                    digits[1:1+midSize]   = midDigits                 # set left half
                    digits[-midSize-1:-1] = midDigits[::-1]           # set mirrored right half
                number = int("".join(map(str,digits)))
                if number>N and isPrime(number):                  # check for prime
                    return number
            digits[0] += 1                                        # next first digit
            if digits[0] > 9: digits = [1]+[0]*len(digits)        # need more digits 
    

    output:

    pp = 1000
    for _ in range(20):
        pp = nextPalPrime(pp)
        print(pp)
    
    10301
    10501
    10601
    11311
    11411
    12421
    12721
    12821
    13331
    13831
    13931
    14341
    14741
    15451
    15551
    16061
    16361
    16561
    16661
    17471
    

    Performance:

    from time import time
    start=time()
    print(nextPalPrime(9989900),time()-start)
    
    100030001 0.023847103118896484
    

    No even number of digits

    Initially I was surprised that the solutions never produced a prime number with an even number of digits. But analyzing the composition of palindromic numbers, I realized that those are always multiples of 11, so the only prime among them is 11 itself:

    abba     = a*1001   + b*110   
             = a*11*91  + b*11*10
             = 11*(a*91 + b*10)
    
    abccba   = a*100001   + b*10010   + c*1100  
             = a*11*9091  + b*11*910  + c*11*100 
             = 11*(a*9091 + b*910     + c*100)
    
    abcddcba = a*10000001   + b*1000010   + c*100100  + d*11000
             = a*11*909091  + b*11*90910  + c*11*9100 + d*11*1000
             = 11*(a*909091 + b*90910     + c*9100    + d*1000)
    
    abcdeedcba = a*1000000001   + b*100000010  + c*10000100  + d*1001000  + e*110000
               = a*11*90909091  + b*11*9090910 + c*11*909100 + d*11*91000 + e*11*10000
               = 11*(a*90909091 + b*9090910    + c*909100    + d*91000    + e*10000)
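
The pattern holds for every even width: since 10 ≡ -1 (mod 11), two mirrored positions in an even-length number always have opposite parity, so each digit's pair of place values sums to a multiple of 11. As a sanity check, the sketch below enumerates the even-length palindromes up to 8 digits and confirms the divisibility. `even_length_palindromes` is a hypothetical helper name introduced for this check.

```python
def even_length_palindromes(width: int):
    """Yield every palindrome with the given even number of digits, in increasing order."""
    half = width // 2
    for left in range(10 ** (half - 1), 10 ** half):
        s = str(left)
        yield int(s + s[::-1])   # left half followed by its full mirror


for width in (2, 4, 6, 8):
    assert all(p % 11 == 0 for p in even_length_palindromes(width))

print(list(even_length_palindromes(2))[:3])
# [11, 22, 33]
```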
    

    Using this observation and a more numerical approach, we get a nice performance boost:

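    def nextPalPrime(N):
        if N < 11:                                  # 2, 5 and 11 are not produced by the loops below
            return next(p for p in (2,3,5,7,11) if p>N)
        for width in range(len(str(N)),10):
            if width%2==0: continue                 # even widths are multiples of 11
            size = width//2
            # factors[n] is the combined place value of digit n and its mirror; the last entry is the center digit
            factors = [(100**(size-n)+1)*10**n for n in range(size)]+[10**size]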
            for firstDigit in (1,3,7,9):
                if (firstDigit+1)*factors[0]<N: continue
                for middle in range(10**size):
                    digits = [firstDigit]+[*map(int,f"{middle:0{size}}")]
                    number = sum(f*d for f,d in zip(factors,digits))
                    if number>N and isPrime(number):
                        return number
    
    from time import time
    start=time()
    print(nextPalPrime(9989900),time()-start)
    
    100030001 0.004210948944091797
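
Note that `nextPalPrime` returns a result strictly greater than N, whereas the code in the question returns N itself when N is already a palindromic prime. The following is a rough sketch of a variant with the "greater than or equal" behaviour, built from the same observations (odd widths only, left half mirrored around the center). The names `is_prime` and `prime_palindrome_at_least` are assumptions made for this sketch, and it assumes N is a positive integer in the range stated in the question.

```python
def is_prime(n: int) -> bool:
    """Trial division by odd numbers up to sqrt(n)."""
    if n < 2:
        return False
    if n % 2 == 0:
        return n == 2
    d = 3
    while d * d <= n:
        if n % d == 0:
            return False
        d += 2
    return True


def prime_palindrome_at_least(n: int) -> int:
    """Smallest palindromic prime >= n (sketch; assumes a positive n as in the question)."""
    for p in (2, 3, 5, 7, 11):           # the only palindromic primes below 101
        if p >= n:
            return p
    s = str(n)
    if len(s) % 2:                       # odd width: start from n's own left half
        left = int(s[: len(s) // 2 + 1])
    else:                                # even width: jump to the next odd width
        left = 10 ** (len(s) // 2)
    while True:
        h = str(left)
        pal = int(h + h[-2::-1])         # mirror the left half around its last digit
        if pal >= n and is_prime(pal):
            return pal
        left += 1                        # rolling over to one more digit adds two to the width


print(prime_palindrome_at_least(9989900))   # 100030001
```

Incrementing the left half past 99...9 naturally moves to the next odd width, so no separate loop over widths is needed.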