
Speed up NumPy's where function


I am trying to extract the indices of all values in a 1-D array of numbers that exceed some threshold. The array has on the order of 1e9 elements.

My approach is the following in NumPy:

import numpy as np
idxs = np.where(data > threshold)
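For reference, with a single argument `np.where(condition)` returns a tuple containing one index array per dimension, so for a 1-D array the indices themselves are `idxs[0]`; `np.flatnonzero` returns the same indices directly. A small sketch with made-up sample values:

```python
import numpy as np

# Small made-up sample; the real array in the question has ~1e9 elements.
data = np.array([0.1, 0.7, 0.3, 0.9, 0.5])
threshold = 0.5

idxs = np.where(data > threshold)    # tuple: one index array per dimension
print(len(idxs))                     # 1 (the input is 1-D)

# Two equivalent ways to get the 1-D index array itself.
indices = idxs[0]
same = np.flatnonzero(data > threshold)
print(indices.tolist(), same.tolist())   # [1, 3] [1, 3]
```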

This takes upwards of 20 minutes, which is unacceptable. How can I speed it up, or are there faster alternatives?

(To be specific, it takes that long on a Mac running OS X 10.6.7 with a 1.86 GHz Intel CPU and 4 GB of RAM, doing nothing else.)


Solution

  • Try indexing with a boolean mask.

    So the syntax would be:

    b = a[a > threshold]
    

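    Note that b holds the matching values, not their indices, and that it is a new array rather than a view of a: boolean indexing always returns a copy. To get the indices the question asks for, use np.flatnonzero(a > threshold) or np.where(a > threshold)[0].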
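A small sketch, with made-up values, showing that the mask yields values while np.flatnonzero yields indices, and that boolean indexing copies whereas a basic slice is a view (checked with np.shares_memory):

```python
import numpy as np

a = np.array([0.2, 0.8, 0.6, 0.1])
threshold = 0.5

mask = a > threshold           # boolean array, same shape as a
b = a[mask]                    # matching values: [0.8, 0.6]
idx = np.flatnonzero(mask)     # matching indices: [1, 2]

# Boolean indexing copies, so b does not share memory with a.
print(np.shares_memory(a, b))  # False

b[0] = -1.0                    # modifying the copy...
print(a[1])                    # ...leaves a unchanged: 0.8

# A basic slice, by contrast, is a view into a.
v = a[1:3]
print(np.shares_memory(a, v))  # True
```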

    Example:

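    import time
    import numpy as np

    # 1e9 float64 values in [0, 1): about 8 GB of memory.
    a = np.random.random_sample(int(1e9))

    t1 = time.time()
    b = a[a > 0.5]  # boolean-mask indexing (returns a copy of the values)
    print(time.time() - t1, 'seconds')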
    

    On my machine, that prints 22.389815092086792 seconds, i.e. about 22 seconds.
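To compare the two approaches directly, here is a sketch that times boolean-mask indexing against np.where on the same data. The array size (1e7), the seed, and the helper name time_it are assumptions chosen so it runs quickly; timings will vary by machine, so none are shown:

```python
import time
import numpy as np

def time_it(fn, arr):
    """Hypothetical helper: return (fn(arr), elapsed seconds) for one call."""
    start = time.perf_counter()
    result = fn(arr)
    return result, time.perf_counter() - start

# Assumed smaller size so this runs quickly on modest hardware.
n = 10_000_000
threshold = 0.5
rng = np.random.default_rng(0)
a = rng.random(n)

values, t_mask = time_it(lambda x: x[x > threshold], a)
where_result, t_where = time_it(lambda x: np.where(x > threshold), a)
idx = where_result[0]  # np.where returns a 1-tuple for 1-D input

print(f"mask:     {t_mask:.3f} s, {values.size} values")
print(f"np.where: {t_where:.3f} s, {idx.size} indices")
```

Both do one full comparison pass plus one pass to collect the matches, so similar timings are expected.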


    Edit:

    I tried the same with np.where, and it was just as fast, so I suspect something else is going on in your code. Are you deleting these values from the array?
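A related possibility not raised in the original answer: assuming float64 data (the question does not give the dtype), 1e9 elements occupy about 8 GB, twice the 4 GB of RAM on the question's machine, and data > threshold allocates another ~1 GB boolean array, so much of the time may go to swapping. One swapped-in technique is to scan the data in fixed-size chunks, for example from an np.memmap, so that only one chunk and its mask are materialized at a time. The function name indices_above and the default chunk size are hypothetical choices:

```python
import numpy as np

def indices_above(data, threshold, chunk_size=10_000_000):
    """Return indices of all elements of 1-D `data` greater than `threshold`.

    Hypothetical helper: processes `data` in chunks so that only one chunk's
    boolean mask exists at a time. `data` may be an ndarray or an np.memmap.
    """
    parts = []
    for start in range(0, data.shape[0], chunk_size):
        chunk = np.asarray(data[start:start + chunk_size])
        # flatnonzero gives chunk-local positions; shift them to global indices.
        parts.append(np.flatnonzero(chunk > threshold) + start)
    if not parts:
        return np.empty(0, dtype=np.intp)
    return np.concatenate(parts)

# Usage on a small in-memory array.
data = np.array([0.9, 0.1, 0.6, 0.4, 0.7, 0.2, 0.8])
idx = indices_above(data, 0.5, chunk_size=3)
print(idx.tolist())  # [0, 2, 4, 6]
```

For data stored on disk, something like np.memmap("data.bin", dtype=np.float64, mode="r") (file name and dtype assumed) could be passed as data instead. The returned index array can itself be large if many values exceed the threshold.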