I was trying to implement a conjugate gradient algorithm with Dask (for didactic purposes) when I realized that the performance was far worse than that of a simple NumPy implementation. After a few experiments, I was able to reduce the problem to the following snippet:

import numpy as np
import dask.array as da
from time import time


def test_operator(f, test_vector, library=np):
    for n in (10, 20, 30):
        v = test_vector()

        start_time = time()
        for i in range(n):
            v = f(v)
            k = library.linalg.norm(v)
    
            try:
                k = k.compute()
            except AttributeError:
                pass
            print(k)
        end_time = time()

        print('Time for {} iterations: {}'.format(n, end_time - start_time))

print('NUMPY!')
test_operator(
    lambda x: x + x,
    lambda: np.random.rand(4_000, 4_000)
)

print('DASK!')
test_operator(
    lambda x: x + x,
    lambda: da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    da
)

In the code, I simply multiply a vector by 2 (this is what f does) and print its norm. When running with Dask, each iteration is a little slower than the one before. The problem does not occur if I skip computing k, the norm of v.

Unfortunately, in my case, k is the norm of the residual that I use as the stopping criterion of the conjugate gradient algorithm. How can I avoid this problem? And why does it happen?

Thank you!

SteP

1 Answer


I think the code snippet is misusing lazy evaluation in Dask, specifically in the addition operation. Without optimization, the addition lambda x: x + x keeps extending the execution graph, with the depth growing with the loop counter, hence the overhead. More precisely, at iteration i the norm computation evaluates a graph of size O(i), so the total runtime is O(n**2). Optimization is of course possible and desirable, but I stop here since the shared example is synthetic. Below I demonstrate that the graph grows linearly with the counter.

lazy evaluation of operations in dask
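This growth can also be checked directly by measuring the size of the task graph after each step. A minimal sketch (my own illustration, assuming a recent Dask; the array is tiny because only the lazy graph is inspected, never the data):

```python
import numpy as np
import dask.array as da

# Small array: we only inspect the lazy task graph, not the data.
v = da.from_array(np.random.rand(100, 100), chunks=(50, 50))

sizes = []
for i in range(5):
    v = v + v  # each step appends new tasks instead of replacing old ones
    sizes.append(len(v.__dask_graph__()))

print(sizes)  # strictly increasing: the graph grows with every iteration
```

Since the norm at iteration i has to traverse all O(i) accumulated layers, each .compute() call is slower than the previous one, which yields the quadratic total runtime.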

To see the quadratic complexity visually, consider the following cleaned-up version of the snippet in question:

import numpy as np
import dask.array as da
from time import time
import matplotlib.pyplot as plt

ns = (10, 20, 40, 50, 60)

def test_operator(f, v, norm):
  out = []
  for n in ns:
    start_time = time()
    for i in range(n):
      v = f(v)
      norm(v)
    end_time = time()
    out.append(end_time - start_time)
  return out


out = test_operator(
    lambda x: x + x,
    np.random.rand(4_000, 4_000),
    norm=np.linalg.norm
)
plt.scatter(ns, out, label='numpy')

out = test_operator(
    lambda x: x + x,
    da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    norm=lambda v: da.linalg.norm(v).compute()
)
plt.scatter(ns, out, label='dask')

plt.legend()
plt.show()

complexity comparison
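As the comments point out, one way to restore linear behaviour is to persist() the intermediate array after each step, which materializes its chunks and keeps the graph at a constant size. A sketch with small arrays (my own addition, assuming the default threaded scheduler):

```python
import numpy as np
import dask.array as da

x = np.random.rand(1_000, 1_000)
v = da.from_array(x, chunks=(500, 500))

sizes = []
for i in range(10):
    v = (v + v).persist()        # compute now; v's graph shrinks to its chunks
    da.linalg.norm(v).compute()  # the norm now sees a constant-size graph
    sizes.append(len(v.__dask_graph__()))

print(sizes)  # constant across iterations, rather than growing
```

With persist, each norm computation touches only the materialized chunks, so the per-iteration cost no longer grows with the counter.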

Maciej Skorski
  • Thank you! I think you are right: the problem is that Dask does not store the temporary value of v; at every step i it recomputes v from the original array test_vector(), applying the function f i times. – SteP May 03 '23 at 23:01
  • Your answer led me to the following page where they discuss the same problem: https://github.com/dask/dask/issues/4630. Adding the line v = v.persist() right after the computation of v solves the problem and restores the linear behaviour. – SteP May 03 '23 at 23:02