
I have a set of tensor contractions that I would like to optimize; I am currently computing them with np.einsum() from NumPy. A minimal reproducible example is below:

import numpy as np
from time import time

d1 = 2
d2 = 3
d3 = 100

# random complex tensors with the shapes used in the real problem
a = np.random.rand(d1, d1, d1, d1, d2, d2, d2, d2, d3, d3) + 1j * np.random.rand(d1, d1, d1, d1, d2, d2, d2, d2, d3, d3)
b = np.random.rand(d1, d1, d2, d2, d3, d3) + 1j * np.random.rand(d1, d1, d2, d2, d3, d3)
c = np.random.rand(d1, d1, d2, d2, d3, d3) + 1j * np.random.rand(d1, d1, d2, d2, d3, d3)

# the two contractions (einsum subscript strings)
path_1 = 'abcdefghij,ckgojs,dlhpjs,klmnopqrst->abmnefqrit'
path_2 = 'abcdefghij,ckgoji,nbrfji,klmnopqrij->almdepqhij'

ts = time()
einsum_pathinfo = np.einsum_path(path_1, a, b, c, a)
term_a          = np.einsum(path_1, a, b, c, a, optimize=einsum_pathinfo[0])
print("took", time() - ts)

ts = time()
einsum_pathinfo = np.einsum_path(path_2, a, b, c, a)
term_a          = np.einsum(path_2, a, b, c, a, optimize=einsum_pathinfo[0])
print("took", time() - ts)

Each of these takes roughly 2 seconds. I have also observed that einsum generally does not run multithreaded here; only a single core is used. Is there a more efficient way to perform such contractions, perhaps with Numba?

Zarathustra
  • Your first example involves about 150 times more FLOPs than the second (1.866e+11 FLOPs vs. 1.244e+09 FLOPs) and is already optimized quite well (it should call multithreaded BLAS routines). Generally, optimizing einsum contractions depends heavily on the exact path and array shapes (maybe a real-world example is needed). This package may also help: https://optimized-einsum.readthedocs.io/en/stable/ – max9111 Jun 03 '22 at 15:30
  • @max9111 - AFAIK, there are no [optimized BLAS routines](https://github.com/numpy/numpy/issues/13229) for `matmul` with complex numbers. – Michael Szczesny Jun 03 '22 at 15:45
  • @MichaelSzczesny Are you sure this hasn't been fixed? I get 400 ms for the example in the link, and 130 ms for the same example but using float64 (both with full CPU usage). There is definitely some BLAS call, otherwise there would be a much larger difference. – max9111 Jun 03 '22 at 15:57
  • @MichaelSzczesny There is a BLAS routine for that called `cgemm` (as opposed to `dgemm`), though IDK if NumPy uses it (apparently not). – Jérôme Richard Jun 03 '22 at 15:59
  • @max9111 - No, I don't know if this was fixed. I only remembered that issue and checked that it is still open. – Michael Szczesny Jun 03 '22 at 16:03
  • It is not clear whether the two examples are supposed to produce the same result, but in practice the results of the two methods are apparently very different (check using `np.array_equiv` or just by printing some values). – Jérôme Richard Jun 03 '22 at 17:25
  • With a casual read it's hard to tell what those `paths` are doing - which dimensions are sum-of-products, and which just pass through. But often when I've tested `einsum` I found that it's faster to break the action into several `einsum` calls. The more variables you have, the larger the problem space, and the longer it takes. The original `einsum` used `nditer` in `cython`, which meant compiled iteration over all variables. Newer versions make more use of BLAS/dot calls, but I suspect that has limits when there are many dimensions. – hpaulj Jun 03 '22 at 18:16
  • 1
    Doing this in Numba is completely insane. You end up with at least several groups 15 nested loops! And if someone succeed to write this code without bugs, then optimizing the resulting code is even more insane. I think it would help a lot to simplify the problem first. – Jérôme Richard Jun 03 '22 at 19:29
  • @JérômeRichard This script should do the job: https://github.com/numba/numba/issues/5503 Maybe it helps for the second example. – max9111 Jun 03 '22 at 19:39
  • @max9111 Interesting work. The last comment is talking about a pull request and I see that this issue is now closed. Did you merge something based on this in Numba? – Jérôme Richard Jun 03 '22 at 19:44
  • 1
    @JérômeRichard https://github.com/numba/numba/pull/5514 I guess it has to be simpliefied and thoroughly tested. Also an open question is the code generation which is usually not the way to add features, although often the simplest. – max9111 Jun 03 '22 at 19:59
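
As an aside, a minimal sketch of what the opt_einsum package suggested by max9111 could look like for this example (assuming `opt_einsum` is installed; the array and subscript names are taken from the question's MRE):

import opt_einsum as oe

# Build a reusable contraction expression once (the path search happens here),
# then evaluate it with the actual arrays.
expr_1 = oe.contract_expression(path_1, a.shape, b.shape, c.shape, a.shape)
term_1 = expr_1(a, b, c, a)

# One-shot form; similar in spirit to np.einsum(..., optimize=True),
# but using opt_einsum's path optimizer and BLAS dispatch.
term_1_oneshot = oe.contract(path_1, a, b, c, a)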

1 Answer


In case it helps with the discussion, here's the first pathinfo:

In [241]: print(einsum_pathinfo[0])
['einsum_path', (1, 2), (0, 2), (0, 1)]

In [242]: print(einsum_pathinfo[1])
  Complete contraction:  abcdefghij,ckgojs,dlhpjs,klmnopqrst->abmnefqrit
         Naive scaling:  20
     Optimized scaling:  15
      Naive FLOP count:  6.718e+14
  Optimized FLOP count:  1.866e+11
   Theoretical speedup:  3599.750
  Largest intermediate:  1.296e+07 elements
--------------------------------------------------------------------------
scaling                  current                                remaining
--------------------------------------------------------------------------
  10    dlhpjs,ckgojs->cdklghopjs abcdefghij,klmnopqrst,cdklghopjs->abmnefqrit
  15    cdklghopjs,abcdefghij->abklefopis        klmnopqrst,abklefopis->abmnefqrit
  15    abklefopis,klmnopqrst->abmnefqrit                   abmnefqrit->abmnefqrit

Just by string editing, I reworked path_1 into

'(abefi)(cg)(dh)j,(cg)(ko)js,(dh)(lp)js,(ko)(lp)(mnqrt)s->(abefi)(mnqrt)'

'abefi' from the first a just passes through; likewise 'mnqrt' for the last a. 'i' and 't' are (100,); the others are small.

'cg', 'dh', 'ko', 'lp' are contractions between pairs (these are all small).

'j' and 's' appear in 3 arrays (both of these are (100,)).

With einsum_call=True (as suggested in the comment below):

einsum_pathinfo = np.einsum_path(path_1, a, b, c, a, einsum_call=True)

In [249]: print(einsum_pathinfo[1])
[((2, 1), set(), 'dlhpjs,ckgojs->cdklghopjs', 
 ['abcdefghij', 'klmnopqrst', 'cdklghopjs'], 
 False), 
 ((2, 0), {'d', 'j', 'g', 'c', 'h'}, 'cdklghopjs,abcdefghij->abklefopis', 
 ['klmnopqrst', 'abklefopis'], 
 True), 
 ((1, 0), {'l', 'o', 'k', 's', 'p'}, 'abklefopis,klmnopqrst->abmnefqrit', 
 ['abmnefqrit'], 
 True)]

(2, 1) is a broadcasted element-wise calculation without any sum-of-products; 'js' is the large (100, 100) pair of dimensions, and the contraction set is empty.

(2, 0) does the contraction over {'d', 'j', 'g', 'c', 'h'}, of which 'j' is the largest.

(1, 0) completes the contraction over {'l', 'o', 'k', 's', 'p'}.
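
For illustration, here is a rough sketch of those three steps written out as separate einsum calls (mirroring the pathinfo above; this only shows what the optimized path does, it is not a faster alternative):

# step (2, 1): pure broadcasted multiply over the shared 'js' axes, nothing is summed
step1 = np.einsum('dlhpjs,ckgojs->cdklghopjs', c, b)

# step (2, 0): contracts {'c', 'd', 'g', 'h', 'j'}; flagged as a BLAS/tensordot call above
step2 = np.einsum('cdklghopjs,abcdefghij->abklefopis', step1, a)

# step (1, 0): contracts {'k', 'l', 'o', 'p', 's'}; also a BLAS/tensordot call
result = np.einsum('abklefopis,klmnopqrst->abmnefqrit', step2, a)

# sanity check against the one-shot contraction from the question:
# np.allclose(result, term_a) should be True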

hpaulj
  • The last two contractions are tensordot calls. You can get this output if you set einsum_call=True in the np.einsum_path call (this option isn't documented). – max9111 Jun 03 '22 at 19:22