Here's my problem. I have two arrays $A$ and $B$ with complex entries, of shapes $(n,n,m,m)$ and $(n,n)$ respectively. To get a matrix $C$ of shape $(m,m)$, I perform the following operation:
```python
import numpy as np  # or: import jax.numpy as np
C = np.sum(B[:,:,None,None]*A, axis=(0,1))
```
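Written out index by index, the line above contracts $B$ against the first two axes of $A$, leaving the last two axes of $A$ as the axes of $C$:

$$
C_{kl} = \sum_{i=1}^{n}\sum_{j=1}^{n} B_{ij}\,A_{ijkl}, \qquad 1 \le k, l \le m
$$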
Computing $C$ once takes about 6-8 seconds, and since I have to compute many such $C$'s, the total time adds up quickly. Is there a faster way to do this? I'm using JAX NumPy on a multi-core CPU; plain NumPy is even slower.
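A minimal sketch, not benchmarked at the sizes in the question: because the reduction is a plain contraction, it can be written with `np.einsum`, `np.tensordot`, or a reshape followed by a single vector-matrix product. None of these needs the full $(n,n,m,m)$ temporary that NumPy allocates for `B[:,:,None,None]*A`. `jax.numpy` also provides `einsum` and `tensordot`; under `jax.jit`, XLA may already fuse the multiply and the sum, so the gain there could be smaller. The function names and the small test sizes below are hypothetical, chosen only to check that the variants agree.

```python
import numpy as np

def c_broadcast(A, B):
    # Original formulation: builds a full (n, n, m, m) product array, then sums it.
    return np.sum(B[:, :, None, None] * A, axis=(0, 1))

def c_einsum(A, B):
    # Contract B's axes (i, j) with A's first two axes; result has A's axes (k, l).
    return np.einsum("ij,ijkl->kl", B, A)

def c_tensordot(A, B):
    # Same contraction; tensordot reshapes its operands and calls np.dot.
    return np.tensordot(B, A, axes=([0, 1], [0, 1]))

def c_matvec(A, B):
    # Flatten (i, j) into one axis of length n*n and (k, l) into one of length m*m,
    # do one vector-matrix product, then restore the (m, m) shape.
    # (Reshaping a C-contiguous A is a view, so no copy of A is made.)
    n, m = A.shape[0], A.shape[2]
    return (B.reshape(n * n) @ A.reshape(n * n, m * m)).reshape(m, m)

# Small, arbitrary sizes just to check that all variants give the same C.
rng = np.random.default_rng(0)
n, m = 5, 7
A = rng.standard_normal((n, n, m, m)) + 1j * rng.standard_normal((n, n, m, m))
B = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

reference = c_broadcast(A, B)
for f in (c_einsum, c_tensordot, c_matvec):
    print(f.__name__, np.allclose(f(A, B), reference))
```

If, and this is an assumption the question does not settle, the many $C$'s share one $A$ and differ only in $B$, stacking the flattened $B$'s as rows of a $(p, n^2)$ matrix (with $p$ a hypothetical batch size) turns $p$ separate passes over $A$ into one matrix-matrix product, so $A$ is read once per batch instead of once per $C$.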
In case you're wondering, $n=77$ and $m=512$. I could parallelize, since I'm working on a cluster, but the sheer size of the arrays already takes up a lot of memory.
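For a sense of scale at these sizes: $A$ holds $n^2 m^2$ complex entries, and in NumPy the broadcast product `B[:,:,None,None]*A` allocates a second array of the same size before `np.sum` reduces it. The sketch below assumes either complex64 (8 bytes per entry; JAX's default unless 64-bit mode is enabled) or complex128 (16 bytes; NumPy's default complex type).

```python
n, m = 77, 512  # sizes from the question

elements = n * n * m * m  # entries in A, and in the broadcast temporary
sizes_gib = {}
for dtype_name, itemsize in (("complex64", 8), ("complex128", 16)):
    sizes_gib[dtype_name] = elements * itemsize / 2**30
    print(f"{dtype_name}: A = {sizes_gib[dtype_name]:.1f} GiB, "
          f"A + temporary = {2 * sizes_gib[dtype_name]:.1f} GiB")
```

This gives about 11.6 GiB for $A$ alone in complex64 and 23.2 GiB in complex128, doubled while the temporary exists. Since every $C$ has to read all of $A$ at least once, the computation is likely limited by memory bandwidth rather than arithmetic, which also means that running many copies in parallel on one node mostly competes for the same memory.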