In this case, b[...,None] is the same as b[:,:,None], a (2,3,1) array. The ... (Ellipsis) means 'as many : as needed'.
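A minimal sketch of that equivalence, using the same b as in the session further down. The extra array c is a hypothetical example, added only to show that ... adapts to the number of dimensions.

```python
import numpy as np

b = np.arange(6).reshape(2, 3)

# For a 2-d array the Ellipsis expands to two full slices.
with_ellipsis = b[..., None]
with_slices = b[:, :, None]

print(with_ellipsis.shape)                         # (2, 3, 1)
print(np.array_equal(with_ellipsis, with_slices))  # True

# The same expression works unchanged on a 3-d array (hypothetical c),
# where ... now stands for three full slices.
c = np.zeros((4, 2, 3))
print(c[..., None].shape)                          # (4, 2, 3, 1)
```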
So np.dot sums over the last axis of a (size 3) and the second-to-last axis of b[...,None] (its middle 3). You can use squeeze to get rid of the size 1 dimension.
But with shapes (2,3,3) and (2,3), which dot product do you want? In einsum notation I can see doing any of:
'ijk,ij->ik'
'ijk,ik->ij'
'ijk,mj->imk'
etc.
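To make the difference between those candidates concrete, here is a small sketch that evaluates all three on the a and b used in the session further down. The variable names r1, r2 and r3 are only for illustration.

```python
import numpy as np

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)

# Each subscript string contracts a different pair of axes.
r1 = np.einsum('ijk,ij->ik', a, b)    # sum over j, batch over i
r2 = np.einsum('ijk,ik->ij', a, b)    # sum over k, batch over i
r3 = np.einsum('ijk,mj->imk', a, b)   # sum over j, no shared batch axis

print(r1.shape, r2.shape, r3.shape)   # (2, 3) (2, 3) (2, 2, 3)
# r1 and r2 have the same shape but hold different values.
print(np.array_equal(r1, r2))         # False
```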
A dot product of two 2d arrays is well defined. But when one of them is 3d there's some ambiguity.
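One way to see that ambiguity is to compare np.dot and np.matmul. This sketch uses hypothetical 2d arrays m1 and m2 for the well-defined case, and multiplies the 3d array a by itself for the other.

```python
import numpy as np

# 2-d with 2-d: dot and matmul agree.
m1 = np.arange(6).reshape(2, 3)
m2 = np.arange(12).reshape(3, 4)
print(np.array_equal(np.dot(m1, m2), m1 @ m2))  # True

# 3-d with 3-d: dot combines every matrix in the first argument with
# every matrix in the second; matmul pairs them along the leading axis.
a = np.arange(18).reshape(2, 3, 3)
print(np.dot(a, a).shape)     # (2, 3, 2, 3)
print(np.matmul(a, a).shape)  # (2, 3, 3)
```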
In [1]: import numpy as np
In [2]: a=np.arange(18).reshape(2,3,3)
   ...: b=np.arange(6).reshape(2,3)
   ...:
In [3]: np.einsum('ijk,ik->ij',a,b)
Out[3]:
array([[  5,  14,  23],
       [122, 158, 194]])
In [4]: np.dot(a,b)
ValueError: shapes (2,3,3) and (2,3) not aligned: 3 (dim 2) != 2 (dim 0)
In [6]: np.dot(a,b[:,:,None]).shape     # 'ijk,mkl->ijml'
Out[6]: (2, 3, 2, 1)
In [7]: np.matmul(a,b[:,:,None]).shape  # @
Out[7]: (2, 3, 1)
In [8]: np.einsum('ijk,ikm->ijm',a,b[...,None])
Out[8]:
array([[[  5],
        [ 14],
        [ 23]],

       [[122],
        [158],
        [194]]])
In [12]: np.squeeze(_)   # removing that added dimension
Out[12]:
array([[  5,  14,  23],
       [122, 158, 194]])
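As a cross-check, the batched product can be written several ways that should all agree. This is a sketch rather than a recommendation. Indexing with [..., 0] or passing axis=-1 to squeeze removes only the added axis, which is a bit safer than a bare squeeze if some other dimension happens to have size 1.

```python
import numpy as np

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)

expected = np.einsum('ijk,ik->ij', a, b)

# matmul with a trailing length-1 axis, then index that axis away.
via_index = (a @ b[..., None])[..., 0]
# Same product, but squeeze only the last axis.
via_squeeze = np.squeeze(a @ b[..., None], axis=-1)
# Broadcast multiply, then sum over the shared last axis.
via_sum = (a * b[:, None, :]).sum(axis=-1)

for result in (via_index, via_squeeze, via_sum):
    print(np.array_equal(result, expected))  # True each time
```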
The relevant notes from the matmul docs are:
If either argument is N-D, N > 2, it is treated as a stack of matrices residing in the last two indexes and broadcast accordingly.
ValueError - If the last dimension of a is not the same size as the second-to-last dimension of b.
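A short sketch of that rule in action: the 2d b is treated as a single (2,3) matrix, so its second-to-last dimension (2) does not match the last dimension of a (3). The flag failed is a hypothetical name used only to record the outcome.

```python
import numpy as np

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)

# (2,3,3) @ (2,3): last dim of a is 3, second-to-last dim of b is 2.
try:
    np.matmul(a, b)
except ValueError as err:
    failed = True
    print("matmul failed:", err)
else:
    failed = False

# Adding a trailing axis makes b a stack of two (3,1) matrices.
ok = np.matmul(a, b[..., None])
print(ok.shape)  # (2, 3, 1)
```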
An example of broadcasting in matmul is:
In [15]: a@b.T
Out[15]:
array([[[  5,  14],
        [ 14,  50],
        [ 23,  86]],

       [[ 32, 122],
        [ 41, 158],
        [ 50, 194]]])
In [16]: _.shape
Out[16]: (2, 3, 2)
In [17]: a@b.T[None,:,:]
Out[17]:
array([[[  5,  14],
        [ 14,  50],
        [ 23,  86]],

       [[ 32, 122],
        [ 41, 158],
        [ 50, 194]]])
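In einsum terms, this broadcast product appears to be 'ijk,mk->ijm': every matrix in a is multiplied by every row of b. The sketch below checks that reading and shows that the batched 'ijk,ik->ij' result is the m == i slice of it. The names full, same, batched and picked are illustrative.

```python
import numpy as np

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)

full = a @ b.T                          # (2,3,3) @ (3,2) -> (2,3,2)
same = np.einsum('ijk,mk->ijm', a, b)   # all (i, m) combinations
print(np.array_equal(full, same))       # True

# The batched product keeps only the combinations where m == i.
batched = np.einsum('ijk,ik->ij', a, b)
picked = np.stack([full[i, :, i] for i in range(a.shape[0])])
print(np.array_equal(picked, batched))  # True
```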
update
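I just learned that optimize=True had become the default for einsum in the NumPy release I was using (current NumPy documents optimize=False as the default), and that optimizing isn't always fastest. For these small arrays it is much slower: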
In [108]: %timeit np.einsum('ijk,ik->ij',a,b, optimize=False)
5.66 µs ± 63.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [109]: %timeit np.einsum('ijk,ik->ij',a,b, optimize=True)
73 µs ± 69.3 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
See: Einsum optimize fails for basic operation
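Outside IPython, the same comparison can be sketched with the standard timeit module. The repeat count n is an arbitrary assumption, and the timings will vary by machine and NumPy version, so no expected numbers are shown.

```python
import timeit
import numpy as np

a = np.arange(18).reshape(2, 3, 3)
b = np.arange(6).reshape(2, 3)

# Both settings should give the same values; only the speed differs.
plain = np.einsum('ijk,ik->ij', a, b, optimize=False)
optimized = np.einsum('ijk,ik->ij', a, b, optimize=True)
print(np.array_equal(plain, optimized))  # True

n = 2000  # arbitrary repeat count, chosen for illustration
for flag in (False, True):
    seconds = timeit.timeit(
        lambda: np.einsum('ijk,ik->ij', a, b, optimize=flag), number=n)
    print(f"optimize={flag}: {seconds / n * 1e6:.2f} µs per call")
```

Passing optimize explicitly, as above, avoids depending on whichever default the installed NumPy uses.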