
This is a query regarding the internal workings of torch.einsum on the GPU. I know how to use einsum. Does it perform all possible matrix multiplications and just pick out the relevant ones, or does it perform only the required computation?

For example, consider two tensors a and b, both of shape (N, P), and suppose I wish to find the dot product of each corresponding pair of rows, each of shape (1, P). Using einsum, the code is:

torch.einsum('ij,ij->i', a, b)

Without using einsum, another way to obtain the output is:

torch.diag(a @ b.t())

Now, the second version is supposed to perform significantly more computation than the first (e.g. if N = 2000, it performs roughly 2000 times as many multiplications). However, when I time the two operations, they take roughly the same amount of time to complete, which raises the question: does einsum compute all combinations (like the second version) and then pick out the relevant values?

Sample code to test:

import time
import torch

for i in range(100):
  a = torch.rand(50000, 256).cuda()
  b = torch.rand(50000, 256).cuda()

  t1 = time.time()
  val = torch.diag(a @ b.t())            # full (N, N) product, then extract the diagonal
  t2 = time.time()
  val2 = torch.einsum('ij,ij->i', a, b)  # row-wise dot products computed directly
  t3 = time.time()
  print(t2 - t1, t3 - t2, torch.allclose(val, val2))
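(One caveat about this kind of measurement, offered only as a sketch of the methodology: CUDA kernels are launched asynchronously, so time.time() can return before the kernel has actually finished. A variant of the same test that calls torch.cuda.synchronize() before each timestamp gives wall-clock numbers that reflect the completed kernels.)

import time
import torch

a = torch.rand(50000, 256).cuda()
b = torch.rand(50000, 256).cuda()

torch.cuda.synchronize()               # make sure the setup work has finished
t1 = time.time()
val = torch.diag(a @ b.t())
torch.cuda.synchronize()               # wait for the matmul + diag kernels to complete
t2 = time.time()
val2 = torch.einsum('ij,ij->i', a, b)
torch.cuda.synchronize()               # wait for the einsum kernel to complete
t3 = time.time()
print(t2 - t1, t3 - t2, torch.allclose(val, val2))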
OlorinIstari

2 Answers


It probably has to do with the fact that the GPU can parallelize the computation of a @ b.t(). This means the GPU doesn't have to wait for one row-column product to finish before it starts computing the next one. If you check on CPU, you will see that torch.diag(a @ b.t()) is significantly slower than torch.einsum('ij,ij->i', a, b) for large a and b.
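As a rough illustration of that CPU check (a minimal sketch; the sizes here are arbitrary and only torch.diag, torch.einsum, and the matrix product are used):

import time
import torch

# moderately sized CPU tensors; sizes chosen only for illustration
a = torch.rand(5000, 256)
b = torch.rand(5000, 256)

t0 = time.time()
val = torch.diag(a @ b.t())            # materializes the full (5000, 5000) product first
t1 = time.time()
val2 = torch.einsum('ij,ij->i', a, b)  # computes only the 5000 row-wise dot products
t2 = time.time()

print("diag(a @ b.t()):   ", t1 - t0)
print("einsum('ij,ij->i'):", t2 - t1)
print("results match:     ", torch.allclose(val, val2))

On CPU the operations run synchronously, so plain time.time() timestamps are a fair comparison here.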

Gil Pinsky
  • So the takeaway is that, whenever possible, it is preferable to use `einsum` for computational (and time) efficiency, and I can trust einsum to perform only the necessary computations? (For my case, I typically have to deal with 4D arrays and perform all sorts of such operations.) – OlorinIstari Sep 07 '20 at 14:24
  • Generally speaking `torch.einsum` will not necessarily be the most efficient in memory and time (see https://pytorch.org/docs/stable/generated/torch.einsum.html). There are projects like opt_einsum (see https://optimized-einsum.readthedocs.io/en/stable/) that may give you a more efficient implementation (a minimal usage sketch follows this thread). – Gil Pinsky Sep 07 '20 at 14:35
  • I'll check them out. – OlorinIstari Sep 07 '20 at 15:25
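To make the opt_einsum suggestion above concrete, here is a minimal sketch, assuming the opt_einsum package is installed and using its contract function, which accepts the same subscript notation as einsum:

import torch
import opt_einsum

a = torch.rand(2000, 256)
b = torch.rand(2000, 256)

# contract() takes the same subscript string and chooses a contraction strategy
val = opt_einsum.contract('ij,ij->i', a, b)
print(torch.allclose(val, torch.einsum('ij,ij->i', a, b)))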

I can't speak for torch, but I worked with np.einsum in some detail years ago. Back then it constructed a custom iterator based on the index string, doing only the necessary calculations. Since then it has been reworked in various ways, and it evidently converts the problem to @ where possible, thereby taking advantage of BLAS (and similar) library calls.

In [147]: a = np.arange(12).reshape(3,4)
In [148]: b = a

In [149]: np.einsum('ij,ij->i', a,b)
Out[149]: array([ 14, 126, 366])

I can't say for sure what method is used in this case. With the 'j' summation, it could also be done with:

In [150]: (a*b).sum(axis=1)
Out[150]: array([ 14, 126, 366])

As you note, the simplest dot creates a larger array from which we can pull the diagonal:

In [151]: (a@b.T).shape
Out[151]: (3, 3)

But that's not the right way to use @. @ expands on np.dot by providing efficient 'batch' handling. Here the i dimension is the batch one, and j the dot one.

In [152]: a[:,None,:]@b[:,:,None]
Out[152]: 
array([[[ 14]],

       [[126]],

       [[366]]])
In [156]: (a[:,None,:]@b[:,:,None])[:,0,0]
Out[156]: array([ 14, 126, 366])

In other words, it multiplies a (3,1,4) by a (3,4,1) to produce a (3,1,1), doing the sum of products over the shared size-4 dimension.

Some sample times:

In [162]: timeit np.einsum('ij,ij->i', a,b)
7.07 µs ± 89.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [163]: timeit (a*b).sum(axis=1)
9.89 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [164]: timeit np.diag(a@b.T)
10.6 µs ± 31.4 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
In [165]: timeit (a[:,None,:]@b[:,:,None])[:,0,0]
5.18 µs ± 197 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
hpaulj