1

I find myself reshaping 1D vectors way too many times. I wonder if this is because I'm doing something wrong, or because it is an inherent fault of numpy.

Why can't numpy infer that when it gets an object of shape (400,) it should transform it to (400,1)? And why do so many numpy operations result in removing the axis completely?

e.g.

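import numpy as np

def sigmoid(z):
    # Assumed standard logistic function; not shown in the original post.
    return 1 / (1 + np.exp(-z))

def predict(Theta1, Theta2, X):
    m = X.shape[0]
    X = np.c_[np.ones(m), X]  # prepend the bias column
    hidden = sigmoid(X @ Theta1.T)
    hidden = np.c_[np.ones(m), hidden]  # prepend the bias column again
    output = sigmoid(hidden @ Theta2.T)
    result = np.argmax(output, axis=1) + 1  # argmax drops axis 1: shape (400,)
    return result.reshape((-1, 1))  # re-add the axis: shape (400, 1)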

pred = predict(Theta1, Theta2, X)
print(np.mean(pred == y))

If I don't reshape the result in the last line of predict, I get funky behavior when comparing pred (400,) and y (400,1).
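
For what it's worth, the "funky behavior" looks like ordinary NumPy broadcasting: shapes are aligned from the trailing axis, so (400,) is treated as (1, 400) and paired with (400, 1), which gives a (400, 400) result rather than an element-wise comparison. A minimal sketch with made-up labels (the arrays `pred_flat` and `y` below are placeholders, not the data from the question):

```python
import numpy as np

# Hypothetical stand-ins for the real predictions and labels.
pred_flat = np.array([1, 2, 3, 1])       # shape (4,)
y = np.array([[1], [2], [3], [2]])       # shape (4, 1)

# Broadcasting aligns trailing axes: (4,) is treated as (1, 4),
# then paired with (4, 1), so every prediction meets every label.
wrong = pred_flat == y                   # shape (4, 4)

# With matching shapes the comparison is element-wise.
right = pred_flat.reshape(-1, 1) == y    # shape (4, 1)

print(wrong.shape, wrong.mean())
print(right.shape, right.mean())
```

Flattening the labels with `y.ravel()` instead of reshaping the predictions should line the shapes up just as well.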

Maverick Meerkat
  • 1
    Well, by applying `argmax` over axis=1, you are essentially getting rid of that axis, as you are aggregating over it. So given that what you need comes from the output of this function, and you need a 2D array, there's not much of a way around this other than reshaping – yatu Sep 10 '19 at 07:33
  • 3
    You can use `result[:, None]`, which does the same but may be a bit more readable to some users. – Nils Werner Sep 10 '19 at 08:15
  • 4
    I find `result[..., np.newaxis]` way more readable, but I know it's a somewhat unpopular opinion here – filippo Sep 10 '19 at 08:28
  • 1
    It would be nice if NumPy added the argument `keepdims` to `argmin` and `argmax`: https://github.com/numpy/numpy/issues/8710 – Warren Weckesser Sep 10 '19 at 11:58
  • These are nice suggestions, though I still don't understand why `pred == y` can't just work for shapes (400,) and (400,1). Seems like simple inference. – Maverick Meerkat Sep 11 '19 at 07:40
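
The three spellings mentioned in the comments above should be interchangeable for a 1D array. A small sketch, using a placeholder `result` array rather than the real predictions:

```python
import numpy as np

result = np.arange(5)              # placeholder 1D array, shape (5,)

a = result.reshape((-1, 1))        # explicit reshape
b = result[:, None]                # None is an alias for np.newaxis
c = result[..., np.newaxis]        # Ellipsis form, also works for higher dimensions

print(a.shape, b.shape, c.shape)
print(np.array_equal(a, b) and np.array_equal(b, c))
```

Which one to use is mostly a matter of readability; in each case only the shape changes, not the values.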

2 Answers

-1

You can use

np.array_split(data, s)

which splits data into s sub-arrays. Assuming data is a flat array with s * s elements, each sub-array will have length s.
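
A rough sketch of what this does, assuming `data` is a flat array with `s * s` elements (both names are hypothetical here). Note that `np.array_split` returns a list of arrays, so it does not by itself add a trailing axis the way `reshape((-1, 1))` does:

```python
import numpy as np

s = 3
data = np.arange(s * s)            # hypothetical flat array, shape (9,)

chunks = np.array_split(data, s)   # list of s arrays, each of length s
stacked = np.stack(chunks)         # shape (s, s)

print(len(chunks), [c.shape for c in chunks])
print(np.array_equal(stacked, data.reshape(s, s)))
```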

salouri
-1

As of version 1.22, NumPy has an optional keepdims argument for argmax. Source: here.
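
A minimal sketch of the difference, assuming NumPy 1.22 or later is installed; the `output` scores below are made up for illustration:

```python
import numpy as np

# Hypothetical class scores for 2 samples and 3 classes, shape (2, 3).
output = np.array([[0.1, 0.7, 0.2],
                   [0.8, 0.1, 0.1]])

flat = np.argmax(output, axis=1) + 1                  # shape (2,)
kept = np.argmax(output, axis=1, keepdims=True) + 1   # shape (2, 1), requires NumPy >= 1.22

print(flat.shape, kept.shape)
```

With `keepdims=True` the reduced axis is kept with length 1, so the final `reshape((-1, 1))` in `predict` would no longer be needed.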

Maverick Meerkat