Tags: python, numpy, loops, optimization

How can I make this nested for loop execute faster in Python?


Here is my script:

for a in range(-100, 101):
    for b in range(-100, 101):
        for c in range(-100, 101):
            for d in range(-100, 101):
                if abs(2**a*3**b*5**c*7**d-0.3048) <= 10**(-6):
                    print('a=',a, ', b=', b, ', c=', c,', d=', d,', the number=', 2**a*3**b*5**c*7**d, ', error=', abs(2**a*3**b*5**c*7**d-.3048))

It took 27 minutes and 15 seconds to execute the above script in Python. I know that it goes through 201^4 (about 1.63 billion) iterations, but I need to run these kinds of calculations faster, because I want to try range(-200, 201) and larger ranges.

I'm wondering whether it is possible to make the above code execute faster. I think using NumPy arrays would help, but I'm not sure how to apply them or whether that would actually be effective.
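One way NumPy arrays can help is a meet-in-the-middle search, which is a different technique from the nested loops above. Split the product into left = 2^a * 3^b and right = 5^c * 7^d. Each half has only 201^2 = 40,401 values, so you can sort the right half once and, for every left value, binary-search it with np.searchsorted for the right values that put the product within 1e-6 of 0.3048. This is a sketch under stated assumptions: find_exponents and its parameters are illustrative names, hi is inclusive (lo=-100, hi=100 matches range(-100, 101)), and it assumes target - tol > 0. The product is formed as (2^a * 3^b) * (5^c * 7^d), so its last bits can differ slightly from a left-to-right evaluation.

```python
import numpy as np


def find_exponents(target=0.3048, tol=1e-6, lo=-100, hi=100):
    """Find all (a, b, c, d) in [lo, hi] with |2**a * 3**b * 5**c * 7**d - target| <= tol.

    Meet in the middle: left = 2**a * 3**b and right = 5**c * 7**d each have
    only (hi - lo + 1)**2 values; sort the right half once and binary-search it
    for every left value. Assumes target - tol > 0 (hypothetical helper).
    """
    e = np.arange(lo, hi + 1)
    n = e.size
    # Flattened grids: left[i] with i = ia * n + ib, right[j] with j = ic * n + id.
    left = ((2.0 ** e)[:, None] * (3.0 ** e)[None, :]).ravel()
    right = ((5.0 ** e)[:, None] * (7.0 ** e)[None, :]).ravel()
    order = np.argsort(right)
    right_sorted = right[order]
    # For x > 0: |x*y - target| <= tol  <=>  (target - tol)/x <= y <= (target + tol)/x.
    # Widen the bounds by a relative 1e-12 to absorb rounding, then re-check exactly below.
    start = np.searchsorted(right_sorted, (target - tol) / left * (1 - 1e-12), side="left")
    stop = np.searchsorted(right_sorted, (target + tol) / left * (1 + 1e-12), side="right")
    matches = []
    for i in np.flatnonzero(stop > start):
        for j in order[start[i]:stop[i]]:
            value = left[i] * right[j]
            if abs(value - target) <= tol:
                ia, ib = divmod(int(i), n)
                ic, id_ = divmod(int(j), n)
                matches.append((lo + ia, lo + ib, lo + ic, lo + id_, float(value)))
    return sorted(matches)


for a, b, c, d, value in find_exponents():
    print(f"a={a}, b={b}, c={c}, d={d}, the number={value}, error={abs(value - 0.3048)}")
```

Because the work grows with (hi - lo + 1)^2 instead of (hi - lo + 1)^4, find_exponents(lo=-200, hi=200) only needs 160,801 values per half. NumPy may warn about overflow for the largest 5^c * 7^d products at that range; those become inf, which can never be within tol of 0.3048.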


Solution

  • For this kind of computation, you can try Numba's JIT compiler (@njit):

    from numba import njit


    @njit
    def fn():
        for a in range(-100, 101):
            for b in range(-100, 101):
                for c in range(-100, 101):
                    for d in range(-100, 101):
                        # Compute the product once per (a, b, c, d) and reuse it.
                        n = (2.0**a) * (3.0**b) * (5.0**c) * (7.0**d)
                        v = n - 0.3048
                        if abs(v) <= 1e-06:
                            # abs(v) is the distance from the target 0.3048, i.e. the error.
                            print("a=", a, ", b=", b, ", c=", c, ", d=", d,
                                  ", the number=", n, ", error=", abs(v))


    fn()
    

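    Running this code on my machine (AMD 5700X) takes ~57 seconds, including the compilation step. Without @njit (plain Python), the same function takes 4 minutes. Output below; note that its error= column was computed with abs(n - 3.048), a typo for abs(n - 0.3048), so the true error for each row is abs(printed error - 2.7432), which is at most 1e-6 in every row: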

    a= -78 , b= -89 , c= -14 , d= 89 , the number= 0.3047994427888104 , error= 2.7432005572111895
    a= -78 , b= -57 , c= 50 , d= 18 , the number= 0.30479915330101043 , error= 2.7432008466989894
    a= -69 , b= -85 , c= 87 , d= 0 , the number= 0.3047993420932106 , error= 2.7432006579067894
    a= -63 , b= 42 , c= -99 , d= 80 , the number= 0.3048005478488736 , error= 2.7431994521511265
    a= -63 , b= 74 , c= -35 , d= 9 , the number= 0.3048002583600241 , error= 2.743199741639976
    a= -54 , b= 14 , c= -62 , d= 62 , the number= 0.3048007366419375 , error= 2.7431992633580626
    a= -54 , b= 46 , c= 2 , d= -9 , the number= 0.30480044715290866 , error= 2.7431995528470914
    a= -54 , b= 78 , c= 66 , d= -80 , the number= 0.3048001576641548 , error= 2.7431998423358452
    a= -45 , b= -14 , c= -25 , d= 44 , the number= 0.30480092543511833 , error= 2.7431990745648815
    a= -45 , b= 18 , c= 39 , d= -27 , the number= 0.3048006359459102 , error= 2.7431993640540897
    a= -36 , b= -10 , c= 76 , d= -45 , the number= 0.30480082473902875 , error= 2.7431991752609712
    a= 5 , b= -44 , c= -72 , d= 82 , the number= 0.30479914163960603 , error= 2.743200858360394
    a= 14 , b= -72 , c= -35 , d= 64 , the number= 0.304799330431799 , error= 2.743200669568201
    a= 14 , b= -40 , c= 29 , d= -7 , the number= 0.3047990409441057 , error= 2.743200959055894
    a= 23 , b= -100 , c= 2 , d= 46 , the number= 0.30479951922410875 , error= 2.7432004807758914
    a= 23 , b= -68 , c= 66 , d= -25 , the number= 0.30479922973623635 , error= 2.7432007702637637
    a= 29 , b= 91 , c= -56 , d= -16 , the number= 0.30480014600271205 , error= 2.743199853997288
    a= 38 , b= 31 , c= -83 , d= 37 , the number= 0.30480062428444915 , error= 2.743199375715551
    a= 38 , b= 63 , c= -19 , d= -34 , the number= 0.30480033479552704 , error= 2.743199665204473
    a= 47 , b= 3 , c= -46 , d= 19 , the number= 0.30480081307756046 , error= 2.7431991869224395
    a= 47 , b= 35 , c= 18 , d= -52 , the number= 0.30480052358845894 , error= 2.743199476411541
    a= 56 , b= 7 , c= 55 , d= -70 , the number= 0.3048007123815079 , error= 2.7431992876184923
    a= 65 , b= -21 , c= 92 , d= -88 , the number= 0.3048009011746738 , error= 2.7431990988253263
    a= 97 , b= -27 , c= -93 , d= 57 , the number= 0.3047990292827057 , error= 2.7432009707172944
    
    real    0m57,939s
    user    0m0,009s
    sys     0m0,009s
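
The ~57 seconds above include JIT compilation. One way to see how much of that is compilation is a warm-up call on a tiny range before timing the real search. The sketch below also collects matches in a list instead of printing inside the loop. The names search and time_search are illustrative, and the fallback that runs the same code as plain Python when Numba is not installed is an assumption added so the sketch runs anywhere (slowly).

```python
import time

try:
    from numba import njit
except ImportError:
    # Assumption for this sketch: without Numba, run the same function as plain Python.
    def njit(func):
        return func


@njit
def search(lo, hi, target, tol):
    """Return (a, b, c, d, value) for every exponent tuple in [lo, hi] within tol of target."""
    out = []
    for a in range(lo, hi + 1):
        for b in range(lo, hi + 1):
            for c in range(lo, hi + 1):
                for d in range(lo, hi + 1):
                    n = (2.0**a) * (3.0**b) * (5.0**c) * (7.0**d)
                    if abs(n - target) <= tol:
                        out.append((a, b, c, d, n))  # collect instead of printing
    return out


def time_search(lo, hi, target=0.3048, tol=1e-6):
    """Time a tiny warm-up call (which triggers JIT compilation) separately from the real run."""
    t0 = time.perf_counter()
    search(-1, 1, target, tol)  # same argument types as the real call, so it compiles once
    t1 = time.perf_counter()
    matches = search(lo, hi, target, tol)
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1, matches
```

time_search(-100, 100) returns (compile_seconds, run_seconds, matches); because the warm-up range is tiny, the first number is roughly the compilation overhead.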
    

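    Because every (a, b, c, d) combination is checked independently, you can use Numba's parallel range (prange) with @njit(parallel=True) to speed things up further, and hoist the partial products out of the inner loops so fewer powers are computed: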

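    from numba import njit, prange


    @njit(parallel=True)
    def fn():
        # Only the outermost prange runs in parallel; Numba treats nested prange
        # loops as ordinary range loops, so the inner loops use range directly.
        for a in prange(-100, 101):
            i_a = 2.0**a  # partial products hoisted out of the inner loops
            for b in range(-100, 101):
                i_b = i_a * 3.0**b
                for c in range(-100, 101):
                    i_c = i_b * 5.0**c
                    for d in range(-100, 101):
                        n = i_c * (7.0**d)
                        v = n - 0.3048
                        if abs(v) <= 1e-06:
                            # Threads print as they find matches, so line order can vary between runs.
                            print("a=", a, ", b=", b, ", c=", c, ", d=", d,
                                  ", the number=", n, ", error=", abs(v))


    fn()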
    

    On my 8-core/16-thread machine this takes only ~2.7 seconds.

    EDIT: Added storing of intermediate results (thanks @yotheguitou).
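
The parallel version prints from several threads, so the order of the lines can change between runs. If you want to post-process results instead of printing them, each parallel iteration should write to its own location, because writes to shared state from different iterations can race. A sketch under that assumption: give every value of a its own slot in a NumPy array and count matches there. count_matches is an illustrative name, and the serial fallback when Numba is missing is again an assumption so the sketch runs without it.

```python
import numpy as np

try:
    from numba import njit, prange
except ImportError:
    # Assumption for this sketch: without Numba, run serially as plain Python.
    prange = range

    def njit(*args, **kwargs):
        if args and callable(args[0]):
            return args[0]  # used as bare @njit
        return lambda func: func  # used as @njit(parallel=True)


@njit(parallel=True)
def count_matches(lo, hi, target, tol):
    """Return an array whose entry a - lo counts the (b, c, d) within tol of target for that a."""
    counts = np.zeros(hi - lo + 1, dtype=np.int64)
    for a in prange(lo, hi + 1):
        local = 0  # assigned inside the loop body, so private to this iteration
        i_a = 2.0**a
        for b in range(lo, hi + 1):
            i_b = i_a * 3.0**b
            for c in range(lo, hi + 1):
                i_c = i_b * 5.0**c
                for d in range(lo, hi + 1):
                    if abs(i_c * 7.0**d - target) <= tol:
                        local += 1
        counts[a - lo] = local  # each iteration writes a different slot, so no race
    return counts
```

count_matches(-100, 100, 0.3048, 1e-6).sum() gives the total number of matches, and the nonzero entries show which values of a are worth rerunning serially to print the full tuples.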