First off, apologies for the vague title, I couldn't think of an appropriate name for this issue.
I have 3 numpy arrays in the following formats:
N = np.array([[13, 14, 15], [2, 5, 7], [4, 6, 8], ...])  # several hundred thousand rows long
e1 = [1, 0, 0]
e2 = [0, 1, 0]
The idea is to create a fourth array, 'v', which has the same dimensions as 'N' but whose values depend on an if statement. Here is what I currently have, which should explain the issue better:
v = np.zeros([len(N), 3])
for i in range(0, len(N)):
    if((N*e1)[i,0] != 0):
        v[i] = np.cross(N[i],e1)
    else:
        v[i] = np.cross(N[i],e2)
This code does what I need, but it takes far longer than expected (> 5 minutes). Is there a list comprehension or similar approach I could use to make it faster?
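You can use numpy.where to replace the if-else and vectorize the process with broadcasting. Since e1 = [1, 0, 0], (N*e1)[i,0] is simply N[i,0], so the condition for every row at once is N[:,0] != 0. Here is an option with numpy.where: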
import numpy as np
# mask is True in all 3 columns of each row whose first component is non-zero
np.where(np.repeat(N[:,0] != 0, 3).reshape(len(N), 3), np.cross(N, e1), np.cross(N, e2))
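The repeat/reshape step can also be replaced by true broadcasting: indexing the boolean column with [:, None] gives it shape (n, 1), and np.where stretches it across the three components. This is a minimal sketch assuming N is an (n, 3) signed integer array; the function name cross_switch and the small example N are hypothetical. Both cross products are still computed for every row, but each is a single vectorized call.

```python
import numpy as np

def cross_switch(N, e1, e2):
    """Row-wise cross product with e1 where N[:, 0] != 0, otherwise with e2."""
    # Shape (n, 1): broadcasts against the (n, 3) cross products below.
    mask = (N[:, 0] != 0)[:, None]
    return np.where(mask, np.cross(N, e1), np.cross(N, e2))

# Hypothetical small input; the second row has a zero first component.
N = np.array([[13, 14, 15], [0, 5, 7], [4, 6, 8]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
print(cross_switch(N, e1, e2))
```

Because the mask length comes from N itself, this works unchanged for the several-hundred-thousand-row array in the question.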
Some benchmarks here:
1) Data setup:
N = np.array([np.random.randint(0,10,3) for i in range(1000)])
N
#array([[3, 5, 0],
#       [5, 0, 8],
#       [4, 6, 0],
#       ...,
#       [9, 4, 2],
#       [6, 9, 3],
#       [2, 9, 2]])
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
2) Timing:
def forloop():
    v = np.zeros([len(N), 3])
    for i in range(0, len(N)):
        if((N*e1)[i,0] != 0):
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v
def forloop2():
    v = np.zeros([len(N), 3])
    # Only calculate this one time.
    my_product = N*e1
    for i in range(0, len(N)):
        if my_product[i,0] != 0:
            v[i] = np.cross(N[i],e1)
        else:
            v[i] = np.cross(N[i],e2)
    return v
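Since the question asks about list comprehensions: a comprehension removes the explicit index bookkeeping, but it still calls np.cross once per row from Python, so it is not expected to come anywhere near the vectorized version. A sketch for comparison, assuming data set up like the 1000-row N above; forloop3 is a hypothetical name and was not part of the timings below.

```python
import numpy as np

e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])
N = np.random.randint(0, 10, (1000, 3))

def forloop3():
    # Still one np.cross call per row; only the loop syntax changes.
    return np.array([np.cross(row, e1) if row[0] != 0 else np.cross(row, e2)
                     for row in N])
```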
%timeit forloop()
10 loops, best of 3: 25.5 ms per loop
%timeit forloop2()
100 loops, best of 3: 12.7 ms per loop
%timeit np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
10000 loops, best of 3: 71.9 µs per loop
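The %timeit lines are IPython magics. Outside IPython, a comparable measurement can be sketched with the standard timeit module; the function names loop and vectorized are hypothetical, and the numbers printed will depend on the machine, so none are claimed here.

```python
import timeit
import numpy as np

N = np.random.randint(0, 10, (1000, 3))
e1 = np.array([1, 0, 0])
e2 = np.array([0, 1, 0])

def vectorized():
    # np.where approach, with the row count taken from N via broadcasting.
    return np.where((N[:, 0] != 0)[:, None], np.cross(N, e1), np.cross(N, e2))

def loop():
    # Per-row version with the condition column computed once, like forloop2.
    v = np.zeros([len(N), 3])
    first = N[:, 0]
    for i in range(len(N)):
        v[i] = np.cross(N[i], e1) if first[i] != 0 else np.cross(N[i], e2)
    return v

for f in (loop, vectorized):
    # Best of 3 repeats, each timing 10 calls; reported per call in ms.
    best = min(timeit.repeat(f, number=10, repeat=3)) / 10
    print(f"{f.__name__}: {best * 1e3:.3f} ms per call")
```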
3) Result checking for all methods:
v1 = forloop()
v2 = np.where(np.repeat(N[:,0] != 0, 3).reshape(1000,3), np.cross(N, e1), np.cross(N, e2))
v3 = forloop2()
(v3 == v1).all()
# True
(v1 == v2).all()
# True
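Since e1 and e2 are unit vectors, both cross products have a closed form: for a row (x, y, z), (x, y, z) × e1 = (0, z, -y) and (x, y, z) × e2 = (-z, 0, x). The result can therefore be assembled from the columns of N without calling np.cross at all. This sketch only holds for e1 = [1, 0, 0] and e2 = [0, 1, 0] exactly as in the question and assumes a signed dtype for N; cross_switch_closed_form is a hypothetical name, and its speed relative to the np.where version has not been measured here.

```python
import numpy as np

def cross_switch_closed_form(N):
    """Assumes e1 = [1, 0, 0], e2 = [0, 1, 0] and a signed (n, 3) array N."""
    x, y, z = N[:, 0], N[:, 1], N[:, 2]
    use_e1 = x != 0
    v = np.empty_like(N)
    # Rows with x != 0: N x e1 = (0, z, -y)
    # Rows with x == 0: N x e2 = (-z, 0, x)
    v[:, 0] = np.where(use_e1, 0, -z)
    v[:, 1] = np.where(use_e1, z, 0)
    v[:, 2] = np.where(use_e1, -y, x)
    return v
```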