
I have an ndarray, and I want to set all non-maximum elements along the last dimension to zero.

import numpy as np
a = np.array([[[1,8,3,4],[6,7,10,6],[11,12,15,4]],
              [[4,2,3,4],[4,7,9,8],[41,14,15,3]],
              [[4,22,3,4],[16,7,9,8],[41,12,15,43]]
             ])
print(a.shape)
(3, 3, 4)

I can get the indices of the maximum elements with np.argmax():

b = np.argmax(a, axis=2)
b
array([[1, 2, 2],
       [0, 2, 0],
       [1, 0, 3]])

As expected, b has one fewer dimension than a. Now I want a new 3-d array, with the same shape as a, that is zero everywhere except at the positions of the maximum values. This is the array I want to get:

np.array([[[0,1,0,0],[0,0,1,0],[0,0,1,0]],
          [[1,0,0,1],[0,0,1,0],[1,0,0,0]],
          [[0,1,0,0],[1,0,0,0],[0,0,0,1]]
         ])

One way I tried to achieve this was to create these temporary arrays:

b = np.repeat(b[:,:,np.newaxis], 4, axis=2)
t = np.repeat(np.arange(4).reshape(4,1), 9, axis=1).T.reshape(b.shape)

z = np.zeros(shape=a.shape, dtype=int)
z[t == b] = 1
z
array([[[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]],

       [[1, 0, 0, 0],
        [0, 0, 1, 0],
        [1, 0, 0, 0]],

       [[0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]]])

Any idea how to get this in a more efficient way?

Vahid Mirjalili

1 Answer


Here's one way that uses broadcasting:

In [108]: (a == a.max(axis=2, keepdims=True)).astype(int)
Out[108]: 
array([[[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 0, 1, 0]],

       [[1, 0, 0, 1],
        [0, 0, 1, 0],
        [1, 0, 0, 0]],

       [[0, 1, 0, 0],
        [1, 0, 0, 0],
        [0, 0, 0, 1]]])
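
To spell out why the one-liner works, here is a minimal sketch that breaks it into steps. The names m, mask, zeroed and first_only are only illustrative and are not from the original post. It also shows two variations that may or may not be what is wanted: keeping the maximum values themselves rather than a 0/1 mask, and marking only the first maximum when a row has ties (which matches the argmax-based attempt in the question, not the mask above).

```python
import numpy as np

a = np.array([[[1, 8, 3, 4], [6, 7, 10, 6], [11, 12, 15, 4]],
              [[4, 2, 3, 4], [4, 7, 9, 8], [41, 14, 15, 3]],
              [[4, 22, 3, 4], [16, 7, 9, 8], [41, 12, 15, 43]]])

# Maximum along the last axis. keepdims=True keeps that axis with
# length 1, so m has shape (3, 3, 1) instead of (3, 3).
m = a.max(axis=-1, keepdims=True)

# Broadcasting stretches the length-1 axis of m across the last axis
# of a. Every element equal to its row maximum is True, so tied
# maxima are all marked.
mask = a == m

# Variation 1: zero out the non-maximum elements but keep the
# maximum values instead of replacing them with 1.
zeroed = np.where(mask, a, 0)

# Variation 2: mark only the first maximum in each row by comparing
# the position index 0..n-1 against argmax, again via broadcasting.
idx = a.argmax(axis=-1)[..., np.newaxis]          # shape (3, 3, 1)
first_only = (np.arange(a.shape[-1]) == idx).astype(int)

print(mask.astype(int)[1, 0])   # row [4, 2, 3, 4], both 4s marked
print(first_only[1, 0])         # same row, only the first 4 marked
print(zeroed[0, 0])             # row [1, 8, 3, 4] with only the 8 kept
```

The difference only shows up in rows with repeated maxima, such as [4, 2, 3, 4]. The desired array in the question marks both 4s, which is what the equality mask gives.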
Warren Weckesser