Given an index array index and, say, a matrix A, I want a matrix B with the corresponding permutation of the columns of A, i.e. column j of B is column index[j] of A.
In NumPy I would do the following:
>>> import numpy as np
>>> A = np.arange(6).reshape(2,3); A
array([[0, 1, 2],
       [3, 4, 5]])
>>> index = [2,0,1]
>>> A[:,index]
array([[2, 0, 1],
       [5, 3, 4]])
Is there a natural or efficient way to do this in MXNet? The functions pick() and take() don't seem to work this way. I managed to come up with the following, but it's not elegant:
>>> import mxnet as mx
>>> A = mx.nd.array(A)  # same 2x3 matrix, now a float32 NDArray
>>> mx.nd.take(A.T, mx.nd.array([[2],[0],[1]])).T.reshape((2,3))
[[ 2. 0. 1.]
 [ 5. 3. 4.]]
<NDArray 2x3 @cpu(0)>
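To see why the transpose trick works, here is a sketch that uses NumPy's np.take as a stand-in for mx.nd.take, so it runs without MXNet. Along axis 0 both follow the same shape rule: the output shape is the indices' shape followed by the remaining axes of the input. Taking along axis 1 of A selects the columns directly with no transposes; whether mx.nd.take accepts axis=1 depends on the MXNet version, so that part is an assumption to check against the documentation for the version in use.

```python
import numpy as np

A = np.arange(6).reshape(2, 3)
index = [2, 0, 1]
expected = A[:, index]                  # reference: fancy indexing permutes columns

# Workaround 1: take rows of A.T with (3, 1)-shaped indices.
# Output shape is indices.shape + A.T.shape[1:] = (3, 1) + (2,) = (3, 1, 2).
picked = np.take(A.T, np.array([[2], [0], [1]]), axis=0)
via_reshape = picked.T.reshape(2, 3)    # .T reverses all axes -> (2, 1, 3)

# Workaround 2: a flat index list gives a 2-D result, so no reshape is needed.
via_flat = np.take(A.T, index, axis=0).T

# Taking along axis 1 selects columns directly, with no transposes.
via_axis1 = np.take(A, index, axis=1)

assert picked.shape == (3, 1, 2)
for result in (via_reshape, via_flat, via_axis1):
    assert np.array_equal(result, expected)
```

The singleton middle axis of picked is what forces the extra reshape when the indices are passed as a column; with a flat index list the result is already 2-D.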
Finally, to throw a wrench into the works, is there a way to do this in-place?
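For the in-place part, one option is the standard cycle-following technique: a permutation splits into disjoint cycles, and each cycle can be rotated by moving whole columns while holding just one column in a temporary buffer. The sketch below is NumPy, not MXNet. permute_columns_inplace is a hypothetical helper name, and carrying it over to MXNet assumes that NDArray column slicing and slice assignment behave like NumPy's in the version in use, which is not verified here. It also loops in Python once per column, trading speed for memory.

```python
import numpy as np

def permute_columns_inplace(A, index):
    """Hypothetical helper: reorder the columns of A in place so that A ends up
    equal to the original A[:, index]. Extra memory is a single column."""
    n = A.shape[1]
    if sorted(index) != list(range(n)):
        raise ValueError("index must be a permutation of range(A.shape[1])")
    visited = [False] * n
    for start in range(n):
        if visited[start]:
            continue
        # Walk one cycle of the permutation, pulling each column into place.
        saved = A[:, start].copy()      # column overwritten first, needed last
        k = start
        while True:
            visited[k] = True
            src = index[k]              # column that belongs at position k
            if src == start:
                A[:, k] = saved         # close the cycle
                break
            A[:, k] = A[:, src]
            k = src
    return A

A = np.arange(6).reshape(2, 3)
expected = A[:, [2, 0, 1]]              # fancy indexing returns a copy
permute_columns_inplace(A, [2, 0, 1])
assert np.array_equal(A, expected)
```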
Update: here is a slightly more elegant version of the above, though presumably still not efficient because of the transpositions:
>>> mx.nd.take(A.T, mx.nd.array([2,0,1])).T
[[ 2. 0. 1.]
 [ 5. 3. 4.]]
<NDArray 2x3 @cpu(0)>