Hello, I have the following data:
import numpy as np

ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))
Now I want a NumPy array containing only the logits whose ids are 1.0, and I want it to have shape (4, 9, 256).
I tried logits[ids == 1.0, :], but the result has shape (36, 256).
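When a boolean mask indexes the leading axes of an array, NumPy replaces all of those axes with a single axis whose length equals the number of True entries in the mask. Here the mask ids == 1.0 has shape (4, 12) with 4 × 9 = 36 True entries, so the two leading axes collapse into one axis of length 36. A minimal reproduction, re-creating the arrays from the question:

```python
import numpy as np

ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))

mask = ids == 1.0         # shape (4, 12), dtype bool
n_true = int(mask.sum())  # 36 True entries in total
selected = logits[mask]   # the 2-D mask consumes axes 0 and 1 of logits
print(n_true, selected.shape)  # 36 (36, 256)
```

logits[mask] and logits[mask, :] behave identically here, because the trailing ":" only keeps the last axis as it is.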
How can I do this slicing without merging the first two dimensions?
The current dimensions are only examples; I am looking for a generic solution.
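One possible generic approach is sketched below. It assumes that every row of the mask (every slice along the mask's last axis) contains the same number k of True entries; otherwise the result cannot be a regular array. Boolean indexing returns the selected elements in row-major (C) order, so the elements picked from the same row stay contiguous and can be reshaped back into mask.shape[:-1] + (k,) + the remaining axes. The helper name select_keep_rows is hypothetical and not a NumPy function.

```python
import numpy as np


def select_keep_rows(values, mask):
    """Hypothetical helper: values[mask], but keeping every mask axis except the last.

    Assumes each slice of `mask` along its last axis has the same number of
    True entries; raises ValueError otherwise.
    """
    counts = mask.sum(axis=-1)
    k = int(counts.flat[0]) if counts.size else 0
    if not np.all(counts == k):
        raise ValueError("every row of mask must contain the same number of True entries")
    # Boolean indexing flattens the masked axes in row-major order,
    # so the k selections from each row are adjacent and can be regrouped.
    selected = values[mask]
    return selected.reshape(mask.shape[:-1] + (k,) + values.shape[mask.ndim:])


ids = np.concatenate([1.0 * np.ones(shape=(4, 9)),
                      2.0 * np.ones(shape=(4, 3))], axis=1)
logits = np.random.normal(size=(4, 9 + 3, 256))

out = select_keep_rows(logits, ids == 1.0)
print(out.shape)  # (4, 9, 256)

# If the selected columns are identical in every row, selecting along axis 1 works too.
cols = (ids == 1.0).all(axis=0)  # shape (12,)
out_cols = logits[:, cols]       # shape (4, 9, 256)
```

If different rows can contain different numbers of matches, a regular array is not possible; a list of per-row arrays such as [logits[i][ids[i] == 1.0] for i in range(ids.shape[0])], or padding to a common length, would be needed instead.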