This is a question about the internal workings of torch.einsum on the GPU; I already know how to use einsum. Does it perform all possible matrix multiplications and then pick out the relevant entries, or does it perform only the required computation?
For example, consider two tensors a and b, both of shape (N, P), where I wish to find the dot product of each corresponding pair of rows a[i] and b[i], each of shape (1, P).
Using einsum, the code is:
torch.einsum('ij,ij->i',a,b)
Without using einsum, another way to obtain the output is:
torch.diag(a @ b.t())
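(For completeness, the row-wise dot product can also be written with an elementwise multiply followed by a sum, which only does the O(N*P) work that is actually needed; this is a small sketch comparing all three formulations on the CPU:)

```python
import torch

a = torch.rand(4, 3)
b = torch.rand(4, 3)

# einsum formulation: dot product of corresponding rows
v_einsum = torch.einsum('ij,ij->i', a, b)

# full N x N matrix product, then keep only the diagonal
v_diag = torch.diag(a @ b.t())

# elementwise multiply and sum: only the required O(N*P) work
v_mul = (a * b).sum(dim=1)

print(torch.allclose(v_einsum, v_diag), torch.allclose(v_einsum, v_mul))
```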
Now, the second snippet is supposed to perform significantly more computation than the first: it builds the full N x N matrix of dot products and then keeps only the diagonal, so for N = 2000 it does roughly 2000 times more work. However, when I time the two operations, they take roughly the same amount of time to complete, which raises the question: does einsum compute all combinations (like the second snippet) and then pick out the relevant values?
Sample code to test:
import time
import torch

for i in range(100):
    a = torch.rand(50000, 256).cuda()
    b = torch.rand(50000, 256).cuda()
    t1 = time.time()
    val = torch.diag(a @ b.t())
    t2 = time.time()
    val2 = torch.einsum('ij,ij->i', a, b)
    t3 = time.time()
    print(t2 - t1, t3 - t2, torch.allclose(val, val2))