I want to slice a multidimensional numpy array (more than 2 dimensions) along 2 of its axes using integer index arrays. What are the rules that determine where each of its original dimensions ends up?
To illustrate my problem, let me provide an example. Say we have a 4D array:
import numpy as np
a = np.arange(2*3*4*5).reshape(2,3,4,5)
I'll create a tuple of index arrays with numpy.where for slicing along axes 1 and 3:
mask = np.where(np.random.rand(3,5) > 0.5)
This picks out random positions along axes 1 and 3 of my array a. Let's say it returned two index arrays of length 7.
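Because the mask is random, the number of selected positions changes from run to run. The sketch below uses a hand-picked 3x5 boolean mask (the name keep and its values are hypothetical, chosen to have exactly 7 True entries) to show what numpy.where returns: a tuple of two index arrays of equal length, one for rows (axis 1 of a) and one for columns (axis 3 of a), listed in row-major order.

```python
import numpy as np

# Hypothetical fixed mask standing in for np.random.rand(3, 5) > 0.5;
# it has exactly 7 True entries so the results are reproducible.
keep = np.array([[1, 0, 1, 0, 0],
                 [0, 1, 1, 0, 1],
                 [1, 0, 0, 1, 0]], dtype=bool)

# With a single argument, np.where returns one index array per dimension.
mask = np.where(keep)
print(mask[0])  # row indices:    [0 0 1 1 1 2 2]
print(mask[1])  # column indices: [0 2 1 2 4 0 3]
```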
To preserve the remaining dimensions I will use slice(None) objects:
b = a[(slice(None), mask[0], slice(None), mask[1])]
This changed the shape:
>>> a.shape
(2, 3, 4, 5)
>>> b.shape
(7, 2, 4)
The axes that were untouched (i.e. indexed with slice(None)) appear to have been preserved in their original order, whereas the two indexed axes are removed and replaced by a single new axis of length 7, which is moved to the front.
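A reproducible version of this first case is sketched below. It assumes fixed index arrays in place of mask (the names rows and cols and their values are hypothetical, valid for axes 1 and 3), and checks element by element that b[k] is the 2x4 slice taken at the k-th index pair.

```python
import numpy as np

a = np.arange(2*3*4*5).reshape(2, 3, 4, 5)

# Hypothetical stand-ins for mask[0] and mask[1].
rows = np.array([0, 0, 1, 1, 1, 2, 2])  # indices into axis 1 (length 3)
cols = np.array([0, 2, 1, 2, 4, 0, 3])  # indices into axis 3 (length 5)

# Index arrays on axes 1 and 3, separated by a full slice on axis 2.
b = a[:, rows, :, cols]
print(b.shape)  # (7, 2, 4)

# The length-7 axis comes first; axes 0 and 2 follow in their original order.
for k in range(len(rows)):
    assert np.array_equal(b[k], a[:, rows[k], :, cols[k]])
```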
However, this is not always the case. When I apply a mask to axes 1 and 2:
mask2 = np.where(np.random.rand(3,4) > 0.5)
c = a[(slice(None), mask2[0], mask2[1], slice(None))]
I observe the following (numpy.where has again returned index arrays of length 7):
>>> c.shape
(2, 7, 5)
This time, the new axis that replaces the two indexed axes did not move to the front; it stayed where those axes were.
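The second case can be reproduced the same way. Again the index arrays rows and cols are hypothetical fixed values, here valid for axes 1 (length 3) and 2 (length 4). The check confirms that the length-7 axis sits in position 1, between the untouched axes 0 and 3.

```python
import numpy as np

a = np.arange(2*3*4*5).reshape(2, 3, 4, 5)

# Hypothetical stand-ins for mask2[0] and mask2[1].
rows = np.array([0, 0, 1, 1, 2, 2, 2])  # indices into axis 1 (length 3)
cols = np.array([1, 3, 0, 2, 0, 1, 3])  # indices into axis 2 (length 4)

# Index arrays on adjacent axes 1 and 2.
c = a[:, rows, cols, :]
print(c.shape)  # (2, 7, 5)

# The length-7 axis replaces axes 1 and 2 in place.
for k in range(len(rows)):
    assert np.array_equal(c[:, k, :], a[:, rows[k], cols[k], :])
```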
My guess is that it is related to whether the sliced axes are adjacent or not, but I want to know from what rules this behavior emerges.
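One way to probe the adjacency guess is to index every pair of axes with the same index array and compare the resulting shapes. The sketch assumes an arbitrary array of shape (10, 11, 12, 13), chosen only so that every axis length is distinguishable, and a hypothetical index array idx of length 7. In this experiment the adjacent pairs keep the length-7 axis where the indexed axes were, while the pairs separated by a slice put it first, which is consistent with the guess.

```python
import numpy as np

# Distinct axis lengths make each original axis recognisable in the output.
a = np.zeros((10, 11, 12, 13))
idx = np.arange(7)  # 7 indices, valid for every axis (all lengths >= 7)

shapes = {}
for i, j in [(0, 1), (1, 2), (2, 3), (0, 2), (1, 3), (0, 3)]:
    key = [slice(None)] * a.ndim  # full slice on every axis...
    key[i] = idx                  # ...except the two indexed ones
    key[j] = idx
    shapes[(i, j)] = a[tuple(key)].shape
    print((i, j), shapes[(i, j)])

# Expected output:
# (0, 1) (7, 12, 13)   adjacent: new axis stays in place
# (1, 2) (10, 7, 13)   adjacent: new axis stays in place
# (2, 3) (10, 11, 7)   adjacent: new axis stays in place
# (0, 2) (7, 11, 13)   separated by a slice: new axis goes first
# (1, 3) (7, 10, 12)   separated by a slice: new axis goes first
# (0, 3) (7, 11, 12)   separated by slices: new axis goes first
```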