python · numpy · vectorization · broadcasting

Count "greater" occurrences without loop


Let's assume that we have an array A of shape (100,) and an array B of shape (10,). Both contain values in [0, 1].

How do we get the count of elements in A greater than each value in B? I expect an output of shape (10,), where the first element is "how many in A are greater than B[0]", the second is "how many in A are greater than B[1]", and so on.

Without using loops.
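
For reference, a straightforward loop version that gives the result I want (and which I'm trying to eliminate) would be:

import numpy as np
import numpy.random as rdm

A = rdm.rand(100)
B = np.linspace(0, 1, 10)

# one count per element of B -- this is the loop I want to get rid of
expected = np.array([np.count_nonzero(A > b) for b in B])  # shape (10,)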

I tried the following, but it didn't work:

import numpy as np
import numpy.random as rdm

A = rdm.rand(100)
B = np.linspace(0, 1, 10)

def occ(z: float) -> float:
    return np.count_nonzero(A > z)

occ(B)

Python won't apply my function element-wise (as a scalar function) over B, which is why I get:

operands could not be broadcast together with shapes (10,) (100,) 

I've also tried np.greater, but I got the same issue.


Solution

    Slow But Simple

    The error message is cryptic if you don't know how to read it, but it is telling you exactly what to do: array dimensions are broadcast together by lining them up starting from the rightmost edge. The fix is easiest to see if you split the operation into two parts:

    1. Create a (100, 10) mask showing which elements of A are greater than which elements of B:

       mask = A[:, None] > B
      
    2. Sum the result of the previous operation along the axis corresponding to A:

       result = np.count_nonzero(mask, axis=0)
      

      OR

       result = np.sum(mask, axis=0)
      

    This can be written as a one-liner:

    (A[:, None] > B).sum(0)
    

    OR

    np.count_nonzero(A[:, None] > B, axis=0)
    

    You can switch the dimensions and place B in the first axis to get the same result:

    (A > B[:, None]).sum(1)
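
    Putting it together, here is a minimal self-contained sketch (the variable names and the seed are just for illustration) that checks all three variants agree:

    import numpy as np

    rng = np.random.default_rng(0)  # seeded so the run is reproducible
    A = rng.random(100)             # shape (100,)
    B = np.linspace(0, 1, 10)       # shape (10,)

    # A[:, None] has shape (100, 1); broadcasting against B's (10,)
    # lines the axes up from the right, yielding a (100, 10) mask.
    mask = A[:, None] > B

    r0 = mask.sum(0)                     # sum over the A axis
    r1 = np.count_nonzero(mask, axis=0)  # same result
    r2 = (A > B[:, None]).sum(1)         # B on the first axis instead

    assert r0.shape == (10,)
    assert np.array_equal(r0, r1) and np.array_equal(r1, r2)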
    

    Fast and Elegant

    Taking a totally different (but likely much more efficient) approach, you can use np.searchsorted:

    A.sort()
    result = A.size - np.searchsorted(A, B)
    

    By default (side='left'), searchsorted returns the leftmost index at which each element of B could be inserted into A while keeping it sorted. Subtracting that index from A.size immediately tells you how many elements of A are greater than that value. (Strictly speaking, side='left' counts ties as "greater"; use side='right' if B may contain values exactly equal to elements of A. With random floats, ties are vanishingly unlikely, so it rarely matters.)
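
    A short runnable sketch of this approach (using np.sort to leave the original A intact, and side='right' so the comparison is strictly "greater than" even at ties; both choices are mine, not part of the code above):

    import numpy as np

    A = np.random.rand(100)
    B = np.linspace(0, 1, 10)

    A_sorted = np.sort(A)  # sorted copy; A.sort() would sort in place

    # side='right' returns the insertion point after any equal elements,
    # so A_sorted.size - index counts elements strictly greater than b.
    result = A_sorted.size - np.searchsorted(A_sorted, B, side='right')

    # agrees with the obvious loop version
    assert np.array_equal(result, np.array([np.count_nonzero(A > b) for b in B]))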

    Benchmarks

    Here, the algos are labeled as follows:

    • B0: (A[:, None] > B).sum(0)
    • B1: (A > B[:, None]).sum(1)
    • HH: np.cumsum(np.histogram(A, bins=B)[0][::-1])[::-1]
    • SS: A.sort(); A.size - np.searchsorted(A, B)
    +--------+--------+----------------------------------------+
    | A.size | B.size |        Time (B0 / B1 / HH / SS)        |
    +--------+--------+----------------------------------------+
    |    100 |     10 |  20.9 µs / 15.7 µs / 68.3 µs / 8.87 µs |
    +--------+--------+----------------------------------------+
    |   1000 |     10 |   118 µs / 57.2 µs /  139 µs / 17.8 µs |
    +--------+--------+----------------------------------------+
    |  10000 |     10 |   987 µs /  288 µs / 1.23 ms /  131 µs |
    +--------+--------+----------------------------------------+
    | 100000 |     10 |  9.48 ms / 2.77 ms / 13.4 ms / 1.42 ms |
    +--------+--------+----------------------------------------+
    |    100 |    100 |  70.7 µs / 63.8 µs /   71 µs / 11.4 µs |
    +--------+--------+----------------------------------------+
    |   1000 |    100 |   518 µs /  388 µs /  148 µs / 21.6 µs |
    +--------+--------+----------------------------------------+
    |  10000 |    100 |  4.91 ms / 2.67 ms / 1.22 ms /  137 µs |
    +--------+--------+----------------------------------------+
    | 100000 |    100 |  57.4 ms / 35.6 ms / 13.5 ms / 1.42 ms |
    +--------+--------+----------------------------------------+
    

    Memory layout matters: B1 is consistently faster than B0 because summing contiguous (cached) elements along the last axis in C order is faster than skipping across rows to reach the next element. Broadcasting performs well when B is small, but keep in mind that both the time and space complexity of B0 and B1 is O(A.size * B.size).

    The complexity of HH (histogram-based) and SS (sort-based) should be about O(A.size * log(A.size)), but SS is implemented much more efficiently than HH because it can assume more about the data.
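
    For anyone who wants to reproduce timings like these, here is a sketch using timeit (the array sizes and repeat counts are illustrative, not the exact setup used for the table above):

    import timeit

    def bench(stmt, setup, number=100):
        # best of 5 runs to reduce noise from other processes
        return min(timeit.repeat(stmt, setup, repeat=5, number=number)) / number

    setup = (
        "import numpy as np; "
        "A = np.random.rand(100_000); "
        "B = np.linspace(0, 1, 100)"
    )

    print("B0:", bench("(A[:, None] > B).sum(0)", setup))
    print("B1:", bench("(A > B[:, None]).sum(1)", setup))
    print("HH:", bench("np.cumsum(np.histogram(A, bins=B)[0][::-1])[::-1]", setup))
    print("SS:", bench("A2 = np.sort(A); A2.size - np.searchsorted(A2, B)", setup))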