
If I have an array like the one below, how can I detect that there is a tie of 3 or more values when using np.argmax()?

import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

np.argmax(examp, axis=1)

which gives the output:

array([0, 0, 3, 1, 1])

Taking the first row as an example, there is a "3-way tie": the value 4 appears three times. np.argmax returns the index of the first occurrence of the maximum. But how can I detect that a "3-way tie" is occurring and have a custom function decide the tie-breaker (only when there is at least a "3-way tie")?

First row: there is a "3-way tie" of 4s, so the custom function runs to decide the tie-breaker.

Second row: a "4-way tie" of 5s, so the same thing happens.

Third row: the maximum, 4, appears only once, so there is no tie of at least three. It can default to np.argmax.
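A plain-Python sketch of the behaviour described above might look like this. The tie-breaker `pick_middle` is a hypothetical placeholder for whatever custom function you have in mind (here it just picks the middle tied index); the threshold of three comes from the question.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])


def pick_middle(tied_indices):
    # Hypothetical custom tie-breaker: choose the middle of the tied positions.
    return tied_indices[len(tied_indices) // 2]


def argmax_with_tiebreak(a, min_tie=3, tie_breaker=pick_middle):
    """Row-wise argmax that calls tie_breaker when the row maximum occurs
    at least min_tie times, and falls back to np.argmax otherwise."""
    result = np.empty(a.shape[0], dtype=np.intp)
    for row_idx, row in enumerate(a):
        tied = np.flatnonzero(row == row.max())  # every position holding the max
        if tied.size >= min_tie:
            result[row_idx] = tie_breaker(tied)
        else:
            result[row_idx] = np.argmax(row)  # first occurrence, as usual
    return result


print(argmax_with_tiebreak(examp))  # [3 3 3 1 1]
```

The loop is easy to read but slow for large arrays; the answers below show vectorized ways to build the "is there a big enough tie" mask.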

Mad Physicist
Knovolt
  • You are taking the max indices for each individual row in `examp`. There is no "4-way tie" here, although there will be such a tie if you do `examp.argmax(0)`. – Kevin May 13 '21 at 18:02
  • @Kevin Apologies, I seem to have misunderstood np.argmax(). I've rewritten the question now (and you had already written a response too, sorry!). – Knovolt May 13 '21 at 18:16
  • I think the principle is still the same though, so my answer might still be useful :) – Kevin May 13 '21 at 18:18
  • @Kevin I ended up switching to `examp[np.r_[:5],indices]`, which returns `array([4, 5, 4, 6, 4])`, the max value of each row. I'm assuming your `examp ==` checks whether the original array matches any of the values in this 1D array. But whether I put `axis=0` or `axis=1`, it checks down the columns rather than row by row. Why does changing the axis not make a difference here? – Knovolt May 13 '21 at 19:13
  • You probably figured this out by now, but when you do `examp == examp[np.r_[:5],indices]` you are broadcasting the 1D array `examp[np.r_[:5],indices]` against every row of `examp`. That said, changing the axis does make a difference in the counts for me, so I'm not sure what you mean by that. – Kevin May 14 '21 at 15:14
  • Yes, but thank you nonetheless! – Knovolt May 15 '21 at 09:40
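To illustrate the broadcasting point in the comments: a 1-D array of per-row maxima broadcasts along the last axis, so element `j` is compared with column `j` of every row. Adding a trailing axis (`[:, None]`) makes each maximum compare against its own row instead. A minimal sketch using the question's array (`np.arange(5)` plays the role of `np.r_[:5]`):

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

indices = examp.argmax(axis=1)
row_max = examp[np.arange(examp.shape[0]), indices]  # shape (5,): [4 5 4 6 4]

# Shape (5,) broadcasts against (5, 5) as a row: column j is compared with row_max[j].
wrong = (examp == row_max).sum(axis=1)
# Shape (5, 1) broadcasts as a column: row k is compared with row_max[k].
right = (examp == row_max[:, None]).sum(axis=1)

print(wrong)  # [2 1 0 2 0]
print(right)  # [3 4 1 1 1]
```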

2 Answers


One way to find the n-th largest value is np.partition (or np.argpartition). In this case you can do something like this:

>>> n = 3  # Size of tie
>>> i = examp.argpartition([-n, -1], axis=-1)

After partitioning, the elements at the third-to-last and last positions of each row are guaranteed to be in their correct sorted places (and therefore so is the second-to-last, but only in this limited case where n = 3). If those two values are equal to each other, then you have at least a 3-way tie:

>>> r = np.arange(examp.shape[0])
>>> examp[r, i[:, -n]] == examp[r, i[:, -1]]
array([ True,  True, False, False, False])
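The same idea can be wrapped in a small helper for any tie size `n`. The name `max_tie_mask` is just illustrative, and the sketch assumes `2 <= n <= a.shape[1]`. It relies only on `argpartition` placing the elements at positions `-n` and `-1` exactly where a full sort would put them.

```python
import numpy as np


def max_tie_mask(a, n):
    """Boolean mask, one entry per row: True where the row maximum
    occurs at least n times. Assumes 2 <= n <= a.shape[1]."""
    i = a.argpartition([-n, -1], axis=-1)
    r = np.arange(a.shape[0])
    # a[r, i[:, -n]] is the n-th largest value and a[r, i[:, -1]] the largest;
    # they are equal only if at least n entries share the maximum.
    return a[r, i[:, -n]] == a[r, i[:, -1]]


examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

print(max_tie_mask(examp, 2))  # [ True  True False False False]
print(max_tie_mask(examp, 4))  # [False  True False False False]
```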

You can also use np.diff to compute the mask:

>>> np.diff(examp[r[:, None], i[:, [-n, -1]]], axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

You can get a similar result by using np.take_along_axis instead of the first index r:

>>> np.diff(np.take_along_axis(examp, i[:, -n::n-1], 1), axis=1) == 0
array([[ True],
       [ True],
       [False],
       [False],
       [False]])

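In all these cases, i[:, -1] gives an index of the maximum value in each row, so it can stand in for argmax. Note, however, that when the maximum is tied, argpartition does not guarantee which of the tied indices ends up in the last position, so it is not necessarily the first occurrence that np.argmax would return.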
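A small check on the question's array, as a sketch. The value at `i[:, -1]` can be compared against the row maximum; the index itself is only printed, not asserted, because which tied index `argpartition` returns is an implementation detail.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
n = 3
i = examp.argpartition([-n, -1], axis=-1)
r = np.arange(examp.shape[0])

# The value at i[:, -1] is always the row maximum...
print(np.array_equal(examp[r, i[:, -1]], examp.max(axis=1)))  # True
# ...but the index may differ from np.argmax's first occurrence when the max is tied.
print(i[:, -1], examp.argmax(axis=1))
```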

Since you are already using numpy, I highly recommend that you vectorize the custom tie-breaking function as well. I've provided the output as a mask here so that you can do exactly that as efficiently as possible.
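As an example of a vectorized tie-breaker, here is a sketch with a hypothetical rule, "take the last occurrence of the maximum", applied only where the mask reports at least a 3-way tie, with the ordinary `argmax` everywhere else. `last_max` stands in for whatever vectorized rule you actually need.

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])
n = 3

i = examp.argpartition([-n, -1], axis=-1)
r = np.arange(examp.shape[0])
mask = examp[r, i[:, -n]] == examp[r, i[:, -1]]  # at least an n-way tie at the max

first_max = examp.argmax(axis=1)
# Hypothetical tie-breaker: index of the LAST occurrence of the row maximum,
# found by running argmax on the column-reversed array.
last_max = examp.shape[1] - 1 - examp[:, ::-1].argmax(axis=1)

result = np.where(mask, last_max, first_max)
print(result)  # [4 4 3 1 1]
```

Because `np.where` evaluates both branches for every row, this only pays off when the tie-breaker itself is vectorized.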

Mad Physicist

You are correct that np.argmax will only find the first max value. However, you could count how many times that max value occurs and base your logic on that number:

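indices = examp.argmax(0)                             # row index of the max in each column
col_max = examp[indices, np.arange(examp.shape[1])]   # max value of each column
counts = (examp == col_max).sum(0)                    # how often each column's max occurs
# the same as
counts = np.count_nonzero(examp == col_max, axis=0)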

For the array in the original version of the question, this returned:

indices = array([0, 3, 2])
counts = array([4, 1, 2])
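Applied row by row to the array in the current question, the same counting idea can be written with `max(..., keepdims=True)` so the comparison lines up with each row, then thresholded to get the "at least a 3-way tie" mask. A sketch:

```python
import numpy as np

examp = np.array([[4, 0, 1, 4, 4],
                  [5, 5, 1, 5, 5],
                  [1, 2, 2, 4, 1],
                  [4, 6, 1, 2, 4],
                  [1, 4, 3, 3, 3]])

row_max = examp.max(axis=1, keepdims=True)           # shape (5, 1)
counts = np.count_nonzero(examp == row_max, axis=1)  # occurrences of each row's max
tie_mask = counts >= 3                               # rows that need the custom tie-breaker

print(counts)    # [3 4 1 1 1]
print(tie_mask)  # [ True  True False False False]
```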
Kevin