python numpy data-processing

How to select and threshold x% of values from a multi-dim NumPy array?


I have a multi-dimensional NumPy array of shape (32, 128, 128). For each entry in this array (each of shape (128, 128)), I would like to check whether more than 80% of its values are greater than a threshold, say 0.5.

Currently, I am doing something like so:

for entry in entries: # entries: (32, 128, 128)
    raveled = np.ravel(entry) # entry: (128, 128)
    total_sum = (raveled > 0.5).sum()
    proportion = total_sum/len(raveled)

    if proportion > 0.8:
        ...

I cannot seem to figure out an efficient way to do this. Any help would be appreciated.


Solution

  • import numpy as np

    x = np.random.rand(32, 128, 128)
    # For each of the 32 entries, check whether more than 80% of its values exceed 0.5
    np.sum(x > 0.5, axis=(1, 2)) > 0.8 * 128 * 128
    

    x > 0.5 returns a boolean array with the same shape as x (32 * 128 * 128), True wherever a value exceeds the threshold. Summing over axes 1 and 2 (the two 128-sized dimensions) counts the True values in each of the 32 entries, and comparing that count with 0.8 * 128 * 128 checks whether more than 80% of the values in that entry exceed 0.5. The result is a boolean array of shape (32,), one flag per entry.
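
An equivalent way to write the same check is to take the mean of the boolean array, which gives the fraction of values above the threshold directly and avoids hard-coding 128 * 128. The sketch below is only an illustration: the names threshold, min_fraction, fraction_above, passes and selected are made up for this example, and the first four entries are shifted upward so that some entries actually pass (purely random data would almost never reach 80%).

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen only to make the example reproducible
x = rng.random((32, 128, 128))
x[:4] += 0.5  # assumption for the demo: force the first four entries above the threshold

threshold = 0.5      # hypothetical name for the value threshold
min_fraction = 0.8   # hypothetical name for the required proportion

# Fraction of values above the threshold in each (128, 128) entry -> shape (32,)
fraction_above = (x > threshold).mean(axis=(1, 2))
passes = fraction_above > min_fraction

# Reference result using the explicit loop from the question
passes_loop = np.array(
    [(entry > threshold).sum() / entry.size > min_fraction for entry in x]
)

# Boolean indexing keeps only the entries that satisfy the condition
selected = x[passes]
```

Here passes can be used as a mask, as in x[passes], to pick out the qualifying entries without a Python-level loop.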