Similar to the question "Pytorch batch matrix vector outer product", I have two batched matrices and would like to compute their outer product, or in other words the pairwise elementwise product of their rows.
Shape example:
If X1 and X2 both have shape torch.Size([32, 300, 8]),
the result should have shape torch.Size([32, 300, 300, 8]).
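
To make the intended operation concrete, here is a minimal sketch. It assumes that "pairwise elementwise product" means result[b, i, j, k] = X1[b, i, k] * X2[b, j, k]; that index formula is my reading of the shapes above, and the random inputs are placeholders.

```python
import torch

# Placeholder inputs with the shapes from the example above.
X1 = torch.randn(32, 300, 8)
X2 = torch.randn(32, 300, 8)

# Assumed semantics: out[b, i, j, k] = X1[b, i, k] * X2[b, j, k].
# Broadcasting: (32, 300, 1, 8) * (32, 1, 300, 8) -> (32, 300, 300, 8)
out = X1.unsqueeze(2) * X2.unsqueeze(1)

# The same computation written with einsum (no index is summed over).
out_einsum = torch.einsum("bik,bjk->bijk", X1, X2)

print(out.shape)  # torch.Size([32, 300, 300, 8])
```

Both forms materialize the full (32, 300, 300, 8) tensor, so memory grows with the square of the second dimension.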