If I have an array like the one below, how can I detect a tie of 3 or more values when using np.argmax()?
examp = np.array([[4, 0, 1, 4, 4],
[5, 5, 1, 5, 5],
[1, 2, 2, 4, 1],
[4, 6, 1, 2, 4],
[1, 4, 3, 3, 3]])
np.argmax(examp, axis=1)
which gives an output:
array([0, 0, 3, 1, 1])
Taking the first row as an example, there is a "3-way tie": three values of 4. np.argmax returns the first index that has the max value. But how can I detect that a "3-way tie" is going on and have a custom function decide the tie-breaker (on the condition that at least a "3-way tie" is occurring)?
So, first row: sees that there is a "3-way tie" of 4s. Custom function runs so that it can decide the tie-breaker.
Second row: "4-way tie" same thing happens.
Third row: only a "2-way tie", which is below the condition of at least a "3-way tie". Can default to np.argmax.
One way to find the n-th maximum is np.partition (or np.argpartition). In this case you can do something like this:
>>> n = 3 # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)
The indices in the third-to-last and last columns of i are guaranteed to point at the values that a full sort would place there (and therefore so does the second-to-last, but only in this limited case). If those two values are equal to each other, then the row has at least a 3-way tie for the maximum:
>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True, True, False, False, False])
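As a cross-check, the same mask can be obtained without partitioning, by counting how many entries in each row equal the row maximum. This is a different technique from the argpartition approach above, shown only as a minimal sketch; the names row_max and tie_size are illustrative.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
n = 3  # minimum tie size that should trigger the tie-breaker

# Count how many entries in each row equal that row's maximum.
row_max = examp.max(axis=1, keepdims=True)
tie_size = (examp == row_max).sum(axis=1)

# True where the maximum is shared by at least n entries.
mask = tie_size >= n
print(tie_size)  # [3 4 1 1 1]
print(mask)      # [ True  True False False False]
```

This variant also tells you the exact size of each tie, at the cost of comparing every element against the row maximum.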
You can also use np.diff to compute the mask:
>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
[ True],
[False],
[False],
[False]])
You can get a similar result by using np.take_along_axis instead of the first index r:
>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
[ True],
[False],
[False],
[False]])
In all these cases, i[:, -1] is an index of the maximum value in each row. When the maximum is tied, it is not necessarily the first such index, which is what np.argmax returns.
Since you are already using numpy, I highly recommend that you vectorize the custom tie-breaking function as well. I've provided the output as a mask here so that you can do exactly that as efficiently as possible.
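One way this could look, as a minimal sketch: the tie-breaking rule below (pick the last index holding the maximum) is purely a placeholder assumption, and the names break_ties and argmax_with_tiebreak are hypothetical. Substitute your own vectorized rule.

```python
import numpy as np

def break_ties(a):
    """Hypothetical tie-breaker: index of the LAST maximum in each row."""
    # argmax on the reversed rows finds the last occurrence of the maximum.
    return a.shape[1] - 1 - np.argmax(a[:, ::-1], axis=1)

def argmax_with_tiebreak(a, n=3):
    """Row-wise argmax that defers to break_ties for ties of size >= n."""
    rows = np.arange(a.shape[0])
    i = a.argpartition([-n, -1], axis=-1)
    # n-th largest equals the largest -> at least an n-way tie for the maximum.
    mask = a[rows, i[:, -n]] == a[rows, i[:, -1]]
    # Use the custom rule on tied rows, plain np.argmax everywhere else.
    return np.where(mask, break_ties(a), np.argmax(a, axis=1))

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

result = argmax_with_tiebreak(examp)
print(result)  # [4 4 3 1 1]
```

Here the tie-breaker is evaluated for every row and np.where discards the results that are not needed. If your rule is expensive, you could instead apply it only to a[mask] and write the results back into a copy of the np.argmax output.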