
Python - how to speed up a for loop creating a numpy array from another numpy array calculation


First off, apologies for the vague title, I couldn't think of an appropriate name for this issue.

I have 3 numpy arrays in the following formats:

N = [[13, 14, 15], [2, 5, 7], [4, 6, 8], ...] (several hundred thousand elements long)

e1 = [1, 0, 0]

e2 = [0, 1, 0]

The idea is to create a fourth array, 'v', which shall have the same dimensions as 'N', but will be given values based on an if statement. Here is what I currently have which should better explain the issue:

v = np.zeros([len(N), 3])    

for i in range(0, len(N)):
    if((N*e1)[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)

This code does what I need, but it takes much longer than anticipated (> 5 mins). Is there a list comprehension or similar technique I could use to make the code more efficient?


Solution

  • You can use numpy.where to replace the if-else and vectorize the whole computation: np.cross broadcasts e1 and e2 against every row of N, and numpy.where then picks the right result row by row. Here is one option:

    import numpy as np
    # reshape(-1, 3) lets the mask follow the length of N instead of assuming 1000 rows
    np.where(np.repeat(N[:,0] != 0, 3).reshape(-1,3), np.cross(N, e1), np.cross(N, e2))
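
The same selection can also be written with a column-shaped mask, so that no repeat or reshape is needed at all. The following is a minimal sketch, not part of the original answer: the helper name pick_cross and the three sample rows are made up for illustration, and it assumes (as in the question) that e1[0] is nonzero, so that testing (N*e1)[i,0] is equivalent to testing N[i,0].

```python
import numpy as np

def pick_cross(N, e1, e2):
    """Row-wise: cross(N[i], e1) where N[i, 0] != 0, otherwise cross(N[i], e2)."""
    N = np.asarray(N)
    # Shape (len(N), 1), so the mask broadcasts across the 3 columns.
    mask = (N[:, 0] != 0)[:, None]
    # np.cross broadcasts the single vector e1 (or e2) against every row of N.
    return np.where(mask, np.cross(N, e1), np.cross(N, e2))

# Hypothetical sample data; the second row starts with 0 to exercise the else branch.
N = np.array([[13, 14, 15], [0, 5, 7], [4, 6, 8]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

v = pick_cross(N, e1, e2)
print(v)
# [[  0  15 -14]
#  [ -7   0   0]
#  [  0   8  -6]]
```

Because the mask takes its length from N itself, this form should work unchanged for arrays of any number of rows.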
    

    Some benchmarks here:

    1) Data set up:

    N = np.array([np.random.randint(0,10,3) for i in range(1000)])
    N
    
    #array([[3, 5, 0],
    #       [5, 0, 8],
    #       [4, 6, 0],
    #       ..., 
    #       [9, 4, 2],
    #       [6, 9, 3],
    #       [2, 9, 2]])
    
    e1 = np.array([1, 0, 0])
    e2 = np.array([0, 1, 0])
    

    2) Timing:

    def forloop():
        v = np.zeros([len(N), 3])

        for i in range(0, len(N)):
            if((N*e1)[i,0] != 0):
                v[i] = np.cross(N[i],e1)
            else:
                v[i] = np.cross(N[i],e2)
        return v
    
    def forloop2():
        v = np.zeros([len(N), 3])

        # Only calculate this one time.
        my_product = N*e1

        for i in range(0, len(N)):
            if my_product[i,0] != 0:
                v[i] = np.cross(N[i],e1)
            else:
                v[i] = np.cross(N[i],e2)               
        return v
    
    %timeit forloop()
    10 loops, best of 3: 25.5 ms per loop
    
    %timeit forloop2()
    100 loops, best of 3: 12.7 ms per loop    
    
    %timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
    10000 loops, best of 3: 71.9 µs per loop
    

    3) Result checking for all methods:

    v1 = forloop()   
    
    v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
    
    v3 = forloop2()
    
    (v3 == v1).all()
    # True
    
    (v1 == v2).all()
    # True
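
The check above can be repeated outside an interactive session and on an array whose length is not 1000. This is a rough sketch under assumptions: the row count of 5000, the seed, and the function names loop_version and where_version are arbitrary choices for illustration, not values from the benchmark above.

```python
import numpy as np

rng = np.random.default_rng(0)           # arbitrary seed, only for repeatability
N = rng.integers(0, 10, size=(5000, 3))  # arbitrary row count, not the 1000 used above
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

def loop_version(N):
    # Same logic as the original loop, with the condition reduced to N[i, 0] != 0.
    v = np.zeros((len(N), 3))
    for i in range(len(N)):
        v[i] = np.cross(N[i], e1) if N[i, 0] != 0 else np.cross(N[i], e2)
    return v

def where_version(N):
    # Column-shaped mask broadcasts over the 3 components of each row.
    mask = (N[:, 0] != 0)[:, None]
    return np.where(mask, np.cross(N, e1), np.cross(N, e2))

# array_equal compares shapes and values in one call.
same = np.array_equal(loop_version(N), where_version(N))
print(same)
```

np.array_equal is used here instead of (a == b).all() because it also returns False when the two shapes differ, rather than relying on broadcasting.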