You can do that pretty easily with Numba:
import numba
import numpy as np

@numba.njit('float64[:,:,:,::1](float64[:,:,:,::1], float64[:,:,::1])', fastmath=True, parallel=True)
def compute(x, y):
    na, nb, nd, ne = x.shape
    nc = y.shape[2]
    assert y.shape == (na, nb, nc)
    out = np.zeros((nb, nc, nd, ne))
    # Parallelize over b: each thread writes to a disjoint out[b, ...]
    # slice, so no synchronization is needed.
    for b in numba.prange(nb):
        for a in range(na):
            for c in range(nc):
                yVal = y[a, b, c]
                # Skip the whole inner 2D accumulation when y is zero
                # (y is assumed to be mostly zeros).
                if np.abs(yVal) != 0:
                    for d in range(nd):
                        for e in range(ne):
                            out[b, c, d, e] += x[a, b, d, e] * yVal
    return out
Note that it is faster to iterate over a and then b in sequential code. That being said, for the code to be parallel, the loops have been swapped and the parallelization is performed over b (which is a small axis). A parallel reduction over the a axis would be more efficient, but this is unfortunately not easy to do with Numba: one needs to split the matrices into multiple chunks, since there is no simple way to create thread-local matrices.
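For reference, here is a minimal sketch of such a chunked reduction, assuming one private accumulator per chunk that is summed at the end (compute_reduce_a is a hypothetical name, and the partial array makes this approach memory-hungry):

@numba.njit('float64[:,:,:,::1](float64[:,:,:,::1], float64[:,:,::1])', fastmath=True, parallel=True)
def compute_reduce_a(x, y):
    na, nb, nd, ne = x.shape
    nc = y.shape[2]
    nchunks = numba.get_num_threads()
    chunkSize = (na + nchunks - 1) // nchunks
    # One private accumulator per chunk: this emulates thread-local
    # matrices at the cost of extra memory.
    partial = np.zeros((nchunks, nb, nc, nd, ne))
    for i in numba.prange(nchunks):
        for a in range(i * chunkSize, min((i + 1) * chunkSize, na)):
            for b in range(nb):
                for c in range(nc):
                    yVal = y[a, b, c]
                    if yVal != 0.0:
                        for d in range(nd):
                            for e in range(ne):
                                partial[i, b, c, d, e] += x[a, b, d, e] * yVal
    # Sequential final reduction of the per-chunk partial sums.
    out = np.zeros((nb, nc, nd, ne))
    for i in range(nchunks):
        out += partial[i]
    return out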
Note that you can replace values like nd and ne with the actual value (i.e. 30) so that the compiler generates faster code specialized for this matrix size.
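For instance, here is a sketch of that specialization for the 30x30 trailing dimensions used in the test below (compute_30x30 is a hypothetical name):

# Same kernel as above, but with nd and ne hard-coded to 30, so the
# compiler knows the inner loop bounds and can unroll/vectorize better.
@numba.njit('float64[:,:,:,::1](float64[:,:,:,::1], float64[:,:,::1])', fastmath=True, parallel=True)
def compute_30x30(x, y):
    na, nb = x.shape[0], x.shape[1]
    nc = y.shape[2]
    assert x.shape == (na, nb, 30, 30) and y.shape == (na, nb, nc)
    out = np.zeros((nb, nc, 30, 30))
    for b in numba.prange(nb):
        for a in range(na):
            for c in range(nc):
                yVal = y[a, b, c]
                if yVal != 0.0:
                    for d in range(30):      # compile-time-known bounds
                        for e in range(30):
                            out[b, c, d, e] += x[a, b, d, e] * yVal
    return out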
Here is the testing code:
np.random.seed(0)
x = np.random.rand(1000, 5, 30, 30)
y = np.random.rand(1000, 5, 300)
y[np.random.rand(*y.shape) > 0.1] = 0.0 # Make it sparse (90% zeros)
%time res = np.einsum('abde,abc->bcde', x, y) # 2.350 s
%time res2 = compute(x, y) # 0.074 s (0.061 s with hand-written sizes)
print(np.allclose(res, res2))
This is about 32 times faster on a 10-core Intel Skylake Xeon processor. It reaches a 38x speed-up with hand-written sizes. It does not scale very well due to the parallelization over the b axis, but using other axes would cause less efficient memory accesses.
If this is not enough, it may be a good idea to transpose x and y first so as to improve data locality (thanks to a more contiguous access pattern along the a axis) and get better scaling (by parallelizing over both the b and c axes). That being said, transpositions are generally expensive, so one certainly needs to optimize them to get an even better speed-up.
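Here is a rough, unbenchmarked sketch of that idea (compute_transposed and _kernel are hypothetical names; the transpositions below are done with plain NumPy copies, which is precisely the part one would need to optimize):

@numba.njit('float64[:,:,:,::1](float64[:,:,:,::1], float64[:,:,::1])', fastmath=True, parallel=True)
def _kernel(xT, yT):
    # xT has shape (nb, na, nd, ne) and yT has shape (nb, nc, na), so the
    # reduction over a reads contiguous memory in both arrays.
    nb, na, nd, ne = xT.shape
    nc = yT.shape[1]
    out = np.zeros((nb, nc, nd, ne))
    # Parallelize over the combined (b, c) space for better scaling.
    for bc in numba.prange(nb * nc):
        b = bc // nc
        c = bc % nc
        for a in range(na):
            yVal = yT[b, c, a]
            if yVal != 0.0:
                for d in range(nd):
                    for e in range(ne):
                        out[b, c, d, e] += xT[b, a, d, e] * yVal
    return out

def compute_transposed(x, y):
    # Naive NumPy transpositions (these copies are the expensive part).
    xT = np.ascontiguousarray(x.transpose(1, 0, 2, 3))  # (nb, na, nd, ne)
    yT = np.ascontiguousarray(y.transpose(1, 2, 0))     # (nb, nc, na)
    return _kernel(xT, yT)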