
I have some 4-dimensional numpy arrays that are easiest to visualise as a matrix of arbitrary size (not necessarily square) in which each element is a 2x2 matrix. I would like to multiply two such large matrices elementwise, where each elementwise product is a standard matrix multiplication (@) of the corresponding 2x2 matrices. The result should be another matrix of the same dimensions whose elements are 2x2 matrices. The eventual hope is to parallelize this process using CuPy, so I want to do this without resorting to looping over every element of the bigger matrix.

Any help would be appreciated.

Example of what I mean:

import numpy as np

x = np.array([[  [[1,0], [0, 1]], [[2,2], [2, 1]]  ]])
y = np.array([[  [[1,3], [0, 1]], [[2,0], [0, 2]]  ]])
xy = np.array([[  [[1,3], [0, 1]], [[4,4], [4, 2]]  ]])  # expected result


[[ [[1, 0],     [[2, 2],        [[ [[1, 3],     [[2, 0],
    [0, 1]]  ,   [2, 1]] ]]  x      [0, 1]]  ,   [0, 2]] ]]

=> [[ [[1, 3],     [[4, 4],
       [0, 1]]  ,   [4, 2]] ]]

In this example the 2 'large' matrices are 1x2 matrices where each of the 2 elements is a 2x2 matrix. I have tried to lay it out both in a way that makes clear what is going on and as standard 4d numpy arrays.
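Under that reading, `x` and `y` have shape `(1, 2, 2, 2)`: a 1 x 2 grid whose entries are 2 x 2 matrices. The sketch below spells out the intended operation as an explicit double loop (the approach the question wants to avoid) and checks that a single `np.matmul` call, and the equivalent `np.einsum` subscript string, give the same result. The helper name `blockwise_matmul_loop` is hypothetical and exists only for comparison.

```python
import numpy as np

x = np.array([[[[1, 0], [0, 1]], [[2, 2], [2, 1]]]])
y = np.array([[[[1, 3], [0, 1]], [[2, 0], [0, 2]]]])
xy = np.array([[[[1, 3], [0, 1]], [[4, 4], [4, 2]]]])  # expected result


def blockwise_matmul_loop(a, b):
    """Reference version (hypothetical helper): multiply the 2x2 blocks one grid cell at a time."""
    rows, cols = a.shape[:2]
    out = np.empty((rows, cols, a.shape[2], b.shape[3]), dtype=np.result_type(a, b))
    for i in range(rows):
        for j in range(cols):
            out[i, j] = a[i, j] @ b[i, j]
    return out


# Vectorized: matmul treats the last two axes as matrices, the rest as the grid.
vectorized = np.matmul(x, y)  # same as x @ y

# Equivalent einsum: "..." stands for the grid axes, ij,jk->ik is the 2x2 product.
via_einsum = np.einsum("...ij,...jk->...ik", x, y)

print(x.shape)                                          # (1, 2, 2, 2)
print(np.array_equal(vectorized, xy))                   # True
print(np.array_equal(via_einsum, xy))                   # True
print(np.array_equal(blockwise_matmul_loop(x, y), xy))  # True
```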

Edited in line with comments.

  • Please use named, typed and copyable *standard 4d numpy arrays* as examples. – Michael Szczesny Jul 17 '22 at 20:31
  • Just use `np.matmul`. As the documentation reads: "If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly." – Homer512 Jul 17 '22 at 20:36
  • For someone who is used to creating and viewing 4d arrays, in the normal `numpy` manner, your layout is confusing. Also it doesn't help recreate your problem, since it can't be copy-n-pasted. – hpaulj Jul 17 '22 at 20:44
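The "broadcast accordingly" part of the quoted documentation means the leading (grid) axes follow the usual NumPy broadcasting rules. A small sketch with made-up shapes and random integer data, chosen only for illustration: one 2 x 2 matrix is applied to every cell of a 3 x 4 grid, and a 3 x 1 column of matrices is combined with a 1 x 4 row of matrices.

```python
import numpy as np

rng = np.random.default_rng(0)

grid = rng.integers(0, 5, size=(3, 4, 2, 2))  # 3x4 grid of 2x2 matrices
single = rng.integers(0, 5, size=(2, 2))      # one 2x2 matrix

# A 2-D operand is broadcast against every 2x2 block of the grid.
left_applied = single @ grid                  # shape (3, 4, 2, 2)

# Grid axes broadcast: (3, 1) grid with (1, 4) grid -> (3, 4) grid.
col = rng.integers(0, 5, size=(3, 1, 2, 2))
row = rng.integers(0, 5, size=(1, 4, 2, 2))
outer = col @ row                             # shape (3, 4, 2, 2)

print(left_applied.shape, outer.shape)        # (3, 4, 2, 2) (3, 4, 2, 2)

# Spot-check one cell of each against an explicit 2x2 product.
assert np.array_equal(left_applied[2, 3], single @ grid[2, 3])
assert np.array_equal(outer[2, 3], col[2, 0] @ row[0, 3])
```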

1 Answer


As Homer512 stated in a comment, np.matmul, aka the @ operator, will handle this scenario (see the numpy docs). You will need to make sure your 2 x 2 matrices are in the last two dimensions.
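If the data happened to be stored with the 2 x 2 axes first, for example as an array of shape `(2, 2, rows, cols)`, one option is to move those axes to the end with `np.moveaxis` before multiplying, and move them back afterwards if the original layout is needed. The layout and the names `p`, `q` here are assumptions for illustration, not something from the question.

```python
import numpy as np

rows, cols = 3, 5
rng = np.random.default_rng(1)

# Hypothetical layout: matrix axes first, grid axes last -> shape (2, 2, rows, cols)
p = rng.integers(0, 5, size=(2, 2, rows, cols))
q = rng.integers(0, 5, size=(2, 2, rows, cols))

# Move the two matrix axes to the end so matmul sees a (rows, cols) stack of 2x2 matrices.
p_last = np.moveaxis(p, (0, 1), (-2, -1))  # shape (rows, cols, 2, 2)
q_last = np.moveaxis(q, (0, 1), (-2, -1))
pq_last = p_last @ q_last

# Optionally restore the original (2, 2, rows, cols) layout.
pq = np.moveaxis(pq_last, (-2, -1), (0, 1))

print(pq_last.shape, pq.shape)  # (3, 5, 2, 2) (2, 2, 3, 5)
```

`np.moveaxis` returns a view, so the axis shuffle itself does not copy data; only the matmul allocates a new array.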

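import numpy as np

# Left operand: a stack of two 2x2 matrices, shape (2, 2, 2)
a1 = np.array([[1, 0], [0, 1]])
a2 = np.array([[2, 2], [2, 1]])
a = np.array([a1, a2])

# Right operand: a stack of two 2x2 matrices, shape (2, 2, 2)
b1 = [[1, 3], [0, 1]]
b2 = [[2, 0], [0, 2]]
b = np.array([b1, b2])

# matmul treats the last two axes as matrices and computes a[i] @ b[i]
# for each i; equivalent to np.matmul(a, b)
print(a @ b)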

Output:

[[[1 3]
  [0 1]]

 [[4 4]
  [4 2]]]
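Since the question mentions moving this to the GPU with CuPy: `cupy.matmul` (and `@` on `cupy.ndarray`) is intended to mirror NumPy's stacked-matrix semantics, so the same call should carry over. A minimal sketch, assuming CuPy and a CUDA device are available; the helper name `batched_matmul` is hypothetical, and the NumPy fallback exists only so the snippet also runs on a machine without CuPy or a GPU.

```python
import numpy as np

try:
    import cupy as cp
    cp.cuda.runtime.getDeviceCount()  # raises if no usable CUDA device
    xp = cp
except Exception:  # CuPy missing or no GPU: fall back to NumPy (assumption for this sketch)
    cp = None
    xp = np


def batched_matmul(a, b):
    """Hypothetical helper: multiply matching 2x2 blocks of two (..., 2, 2) arrays."""
    a_dev = xp.asarray(a)  # copies to GPU memory when xp is CuPy
    b_dev = xp.asarray(b)
    out = xp.matmul(a_dev, b_dev)  # one vectorized call, no Python loop over the grid
    # Return a NumPy array either way so callers do not depend on the backend.
    return cp.asnumpy(out) if xp is not np else out


x = np.array([[[[1, 0], [0, 1]], [[2, 2], [2, 1]]]])
y = np.array([[[[1, 3], [0, 1]], [[2, 0], [0, 2]]]])
print(batched_matmul(x, y))
```

For real GPU workloads it is usually better to keep the arrays on the device between calls instead of copying back and forth on every multiplication.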
ogdenkev