Tags: python, python-3.x, numpy, if-statement, cpu-speed

Rewriting an if condition to speed up Python code


I have the following piece of code with an if statement inside a function. It takes a long time to run. Is there a way to rewrite the if condition, or some other way to speed up this sample code?

from time import time
import numpy as np

def func(S, R, H):
    ST = S * R
    if ST <= -H:
        result = -H
    elif ST >= -H and ST < 0:
        result = ST
    else:
        result = min(ST, H)
    return result

y = []
t1 = time()
# call func once per element (10 million Python-level calls)
for x in np.arange(0, 10000, 0.001):
    y.append(func(3, 5, x))
t2 = time()
print("time with numpy arange:", t2 - t1)

time taken to run the code:

 10 s

This is a reduced sample of the real code. In the real code ST takes both negative and positive values. The conditions should stay the same, but replacing the if statement with something else may make the task run faster.


Solution

  • If you want to keep your function's parameters, you can replace the per-element function call with NumPy boolean indexing applied to the whole array at once:

    from time import time
    import numpy as np
    
    ran = np.arange(-10, 10, 1)
    s = 2
    r = 3
    
    st =  s * r
    
    def func(S, R, H):
        ST = S * R
        if ST <= -H:
            result = -H
        elif ST >= -H and ST < 0:
            result = ST
        else:
            result = min(ST, H)
        return result
    
    # calculate with function
    a = []
    t1 = time()
    for x in ran:
        a.append(func(s, r, x))
    t2 = time()
    print("time with function:", t2 - t1)
    a = np.array(a)
    
    # calculate with numpy
    y = np.copy(ran)   # holds H; overwritten in place with the results
    neg_y = -y         # holds -H
    
    # creative boolean indexing
    t1 = time()
    # ST <= -H  ->  result = -H
    y[st <= neg_y] = neg_y[st <= neg_y]
    if st < 0:
        # -H <= ST < 0  ->  result = ST
        y[st >= neg_y] = st
    else:
        # ST >= 0 and ST > -H  ->  result = min(ST, H)
        y[st > neg_y] = np.where(y[st > neg_y] > st, st, y[st > neg_y])
    t2 = time()
    
    print(a)
    print(y)
    print("time with numpy indexing:", t2 - t1)
    

    This gives (timings omitted):

    # s=-2, r=3
    [10  9  8  7  6  5  4  3  2  1  0 -1 -2 -3 -4 -5 -6 -6 -6 -6] # function
    [10  9  8  7  6  5  4  3  2  1  0 -1 -2 -3 -4 -5 -6 -6 -6 -6] # numpy
    
    # s=2, r=3
    [10  9  8  7  6 -5 -4 -3 -2 -1  0  1  2  3  4  5  6  6  6  6] # function
    [10  9  8  7  6 -5 -4 -3 -2 -1  0  1  2  3  4  5  6  6  6  6] # numpy
    

    You might need to tweak it a bit more.

    Using a larger range:

    ran = np.arange(-1000, 1000, 0.001)
    

    I get timings (s=3,r=5) of:

    time with function: 5.606577634811401
    [1000.     999.999  999.998 ...   15.      15.      15.   ]
    [1000.     999.999  999.998 ...   15.      15.      15.   ]
    time with numpy indexing: 0.06600046157836914
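
As a possible simplification of the indexing approach above, the same branches can be sketched as a nested `np.where`, which mirrors the if/elif/else one-to-one. The names `func_where` and `func_clamp` are made up for this illustration and are not from the original post. `func_clamp` additionally assumes every H is non-negative (as in the question's `np.arange(0, 10000, 0.001)`), in which case the three branches reduce to clamping ST into [-H, H].

```python
import numpy as np

def func(S, R, H):
    # Scalar reference version from the question.
    ST = S * R
    if ST <= -H:
        return -H
    elif ST < 0:
        return ST
    else:
        return min(ST, H)

def func_where(S, R, H):
    # Hypothetical vectorized version: S and R are scalars, H is array-like.
    H = np.asarray(H)
    ST = S * R
    return np.where(ST <= -H, -H,
                    np.where(ST < 0, ST, np.minimum(ST, H)))

def func_clamp(S, R, H):
    # Hypothetical shortcut, valid only under the assumption that all H >= 0:
    # clamp ST to the interval [-H, H].
    H = np.asarray(H)
    return np.minimum(np.maximum(S * R, -H), H)

# Compare against the scalar function on small ranges.
h_any = np.arange(-10, 10, 1)     # H may be negative here
h_pos = np.arange(0, 10, 0.5)     # H >= 0 only

ref_pos_st = np.array([func(2, 3, x) for x in h_any])
ref_neg_st = np.array([func(-2, 3, x) for x in h_any])
ref_clamp = np.array([func(3, 5, x) for x in h_pos])

print(np.array_equal(func_where(2, 3, h_any), ref_pos_st))
print(np.array_equal(func_where(-2, 3, h_any), ref_neg_st))
print(np.array_equal(func_clamp(3, 5, h_pos), ref_clamp))
```

Both variants build a new array instead of overwriting a copy of H in place, so an integer H array does not force the result's dtype. Whether they are faster than the boolean-indexing version would need to be measured on the real data.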