"dot product ... along the first dimension" is a bit unclear. Is the first dimension a 'batch' dimension, with 3 dot products on the remaining dimensions? Or something else?
In [103]: a = np.random.randn(30).reshape(3, 5, 2)
...: b = np.random.randn(30).reshape(3, 2, 5)
In [104]: (a@b).shape
Out[104]: (3, 5, 5)
In [105]: np.einsum('ijk,ikl->ijl',a,b).shape
Out[105]: (3, 5, 5)
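As an aside, matmul/@ treats every leading dimension as a batch dimension, so the same contraction can also be written with an einsum ellipsis. A minimal sketch (my addition, not part of the session above):

import numpy as np

a = np.random.randn(30).reshape(3, 5, 2)
b = np.random.randn(30).reshape(3, 2, 5)

# '...' stands for any run of leading batch dimensions; only the
# last two axes take part in the matrix product, matching a@b.
print(np.allclose(np.einsum('...jk,...kl->...jl', a, b), a @ b))   # True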
@Ivan's answer is different:
In [106]: np.einsum('ijk,ilm->ijm', a, b).shape
Out[106]: (3, 5, 5)
In [107]: np.allclose(np.einsum('ijk,ilm->ijm', a, b), a@b)
Out[107]: False
In [108]: np.allclose(np.einsum('ijk,ikl->ijl', a, b), a@b)
Out[108]: True
Ivan's einsum sums the k dimension of one operand and the l dimension of the other, and then takes a broadcasted elementwise product. That is not matrix multiplication:
In [109]: (a.sum(axis=-1,keepdims=True)* b.sum(axis=1,keepdims=True)).shape
Out[109]: (3, 5, 5)
In [110]: np.allclose((a.sum(axis=-1,keepdims=True) * b.sum(axis=1,keepdims=True)),
     ...:             np.einsum('ijk,ilm->ijm', a, b))
Out[110]: True
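To spell out why those match: in 'ijk,ilm->ijm', neither k nor l appears in the output or in the other operand, so each is summed away independently before the broadcasted product. A hypothetical loop rendering (my sketch, not from the original session):

import numpy as np

a = np.random.randn(30).reshape(3, 5, 2)
b = np.random.randn(30).reshape(3, 2, 5)

# k and l are contracted separately; the partial sums then multiply.
res = np.zeros((3, 5, 5))
for i in range(3):
    for j in range(5):
        for m in range(5):
            res[i, j, m] = a[i, j, :].sum() * b[i, :, m].sum()
print(np.allclose(res, np.einsum('ijk,ilm->ijm', a, b)))   # True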
Another test of the batch interpretation: an explicit loop over the first dimension reproduces a@b:
In [112]: res = np.zeros((3,5,5))
     ...: for i in range(3):
     ...:     res[i] = a[i]@b[i]
     ...: np.allclose(res, a@b)
Out[112]: True
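One more reason the question's wording matters: np.dot does not batch over the first dimension the way matmul does. For these 3-d inputs it sum-products the last axis of a with the second-to-last axis of b and keeps all the other axes. A short sketch to illustrate:

import numpy as np

a = np.random.randn(30).reshape(3, 5, 2)
b = np.random.randn(30).reshape(3, 2, 5)

# dot result shape is a.shape[:-1] + b.shape[:-2] + b.shape[-1:]
print(np.dot(a, b).shape)   # (3, 5, 3, 5), not the batched (3, 5, 5)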