Tags: python, numpy, sorting, argmax

How to detect a tie in a numpy array when using argmax


If I have an array like the one below, how can I detect a tie of at least 3 values when using np.argmax()?

import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

np.argmax(examp, axis=1)

which gives an output:

array([0, 0, 3, 1, 1])

Taking the first row as an example, there is a "3-way tie": three values of 4. np.argmax returns the first index that has the max value. But how can I detect that a "3-way tie" is going on and have a custom function decide the tie-breaker (on the condition that at least a "3-way tie" is occurring)?

First row: there is a "3-way tie" of 4s, so the custom function runs to decide the tie-breaker.

Second row: a "4-way tie", so the same thing happens.

Third row: only a "2-way tie", which is below the threshold of at least a "3-way tie", so it can default to np.argmax.


Solution

  • One way to find the n-th largest value is np.partition (or np.argpartition). In this case you can do something like this:

    >>> n = 3  # Size of tie
    >>> i = examp.argpartition([-n, -1], axis=-1)
    

    With n = 3, the index columns i[:, -n] and i[:, -1] are guaranteed to point at the third-largest and largest values of each row (so the second-to-last column is in sorted position as well, but only in this limited case). If those two values are equal, the row has at least a 3-way tie for the maximum:

    >>> r = np.arange(examp.shape[0])
    >>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
    array([ True,  True, False, False, False])
    

    You can also use np.diff to compute the mask:

    >>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
    array([[ True],
           [ True],
           [False],
           [False],
           [False]])
    

    You can get a similar result by using np.take_along_axis instead of the row index r:

    >>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
    array([[ True],
           [ True],
           [False],
           [False],
           [False]])
    

    In all these cases, i[:, -1] is the index of a maximum value in each row. When the maximum is tied, though, it is not guaranteed to be the first occurrence, which is what np.argmax returns.

    Since you are already using numpy, I highly recommend that you vectorize the custom tie-breaking function as well. I've provided the output as a mask here so that you can do exactly that as efficiently as possible.
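
As a rough sketch of how the mask could feed a vectorized tie-breaker, the snippet below wraps the argpartition check in a helper. The names argmax_with_tiebreak and last_occurrence are hypothetical, and "pick the last maximal column" is only a placeholder rule; the question does not say what the custom function should do. The sketch assumes a 2-D array with at least n columns.

```python
import numpy as np


def argmax_with_tiebreak(a, n, tiebreak):
    """Row-wise argmax; rows whose maximum occurs at least n times use `tiebreak`.

    `tiebreak` is a hypothetical user-supplied function. It receives a boolean
    array of shape (k, m) marking where the maximum sits in each of the k tied
    rows, and must return k column indices.
    """
    a = np.asarray(a)
    # Columns -n and -1 of the partition index the n-th largest and largest values.
    i = np.argpartition(a, [-n, -1], axis=-1)
    rows = np.arange(a.shape[0])
    tied = a[rows, i[:, -n]] == a[rows, i[:, -1]]

    # Default for every row: np.argmax, i.e. the first occurrence of the maximum.
    result = np.argmax(a, axis=-1)
    if tied.any():
        is_max = a[tied] == a[tied].max(axis=-1, keepdims=True)
        result[tied] = tiebreak(is_max)
    return result


def last_occurrence(is_max):
    # Placeholder tie-breaker (an arbitrary choice): the last maximal column.
    return is_max.shape[1] - 1 - np.argmax(is_max[:, ::-1], axis=1)


examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

result = argmax_with_tiebreak(examp, 3, last_occurrence)
print(result)  # [4 4 3 1 1]
```

Only the first two rows reach the tie-breaker; the other three keep the plain np.argmax result. Passing the tied rows as one boolean block lets the tie-breaker stay vectorized instead of looping over rows.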