python, numpy, dask

Computing a norm in a loop slows down the computation with Dask


I was trying to implement a conjugate gradient algorithm using Dask (for didactic purposes) when I realized that the performance was far worse than that of a simple numpy implementation. After a few experiments, I was able to reduce the problem to the following snippet:

import numpy as np
import dask.array as da
from time import time


def test_operator(f, test_vector, library=np):
    for n in (10, 20, 30):
        v = test_vector()

        start_time = time()
        for i in range(n):
            v = f(v)
            k = library.linalg.norm(v)
    
            try:
                k = k.compute()
            except AttributeError:
                pass
            print(k)
        end_time = time()

        print('Time for {} iterations: {}'.format(n, end_time - start_time))

print('NUMPY!')
test_operator(
    lambda x: x + x,
    lambda: np.random.rand(4_000, 4_000)
)

print('DASK!')
test_operator(
    lambda x: x + x,
    lambda: da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
    da
)

In the code, I simply multiply a vector by 2 (this is what f does) and print its norm. When running with dask, each iteration is a little slower than the previous one. This problem does not happen if I do not compute k, the norm of v.

Unfortunately, in my case, k is the norm of the residual, which I use as the stopping criterion for the conjugate gradient algorithm. How can I avoid this problem? And why does it happen?

Thank you!


Solution

  • I think the code snippet is misusing lazy evaluation in dask, specifically for the addition operation. Without optimization, each application of lambda x: x + x adds another layer to the execution graph, so its depth grows with the loop counter, and so does the overhead. More precisely, at counter value i, computing the norm means executing a graph of size O(i), so the total runtime over n iterations is O(n**2). Optimization is of course possible and desirable, but I stop here because the example shared is synthetic. Below I demonstrate that the graph grows linearly with the counter.

    [Figure: lazy evaluation of operations in dask]
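The linear growth can also be checked numerically by counting the tasks attached to the array after each lazy addition. This is a minimal sketch, not part of the original answer: it assumes a smaller 400 x 400 array so that it runs quickly, and it counts tasks by materializing `v.__dask_graph__()` as a dict, which is just one way of inspecting the unoptimized graph.

```python
import numpy as np
import dask.array as da

# Assumption: a smaller array than in the question, only to keep the check fast
v = da.from_array(np.random.rand(400, 400), chunks=(200, 200))

graph_sizes = []
for i in range(5):
    v = v + v  # lazy: this only adds tasks to the graph, nothing is computed
    # Materialize the (unoptimized) graph as a plain dict and count its tasks
    graph_sizes.append(len(dict(v.__dask_graph__())))

print(graph_sizes)
```

Each addition contributes one new task per chunk, so the printed sizes increase by a fixed amount per iteration. Every call to `.compute()` has to walk that whole history again.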

    To see the quadratic complexity visually, consider the following cleaned-up version of the snippet in question:

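    import numpy as np
    import dask.array as da
    from time import time
    import matplotlib.pyplot as plt

    ns = (10, 20, 40, 50, 60)  # iteration counts to time


    def test_operator(f, v, norm):
        """Return the wall-clock time taken by each iteration count in ns."""
        out = []
        for n in ns:
            start_time = time()
            for i in range(n):
                v = f(v)  # v carries over between runs, so a dask graph keeps growing
                norm(v)   # forces evaluation when norm calls .compute()
            end_time = time()
            out.append(end_time - start_time)
        return out


    # NumPy baseline: every addition is evaluated eagerly
    out = test_operator(
        lambda x: x + x,
        np.random.rand(4_000, 4_000),
        norm=np.linalg.norm,
    )
    plt.scatter(ns, out, label='numpy')

    # Dask: additions are lazy; each .compute() re-runs the whole accumulated graph
    out = test_operator(
        lambda x: x + x,
        da.from_array(np.random.rand(4_000, 4_000), chunks=(2_000, 2_000)),
        norm=lambda v: da.linalg.norm(v).compute(),
    )
    plt.scatter(ns, out, label='dask')

    plt.legend()
    plt.show()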
    

    [Figure: complexity comparison]
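One possible way to avoid the quadratic cost, which the answer above leaves open, is to call `persist()` on the array after each update, so that the next iteration starts from concrete chunks instead of the whole history of additions. The sketch below is an assumption about what such a fix could look like, not the answerer's method. The helper name `norms_with_persist` is hypothetical, and the array is smaller than in the question so that the example runs quickly.

```python
import numpy as np
import dask.array as da


def norms_with_persist(v, n):
    """Double v n times, computing the norm after every step."""
    norms, graph_sizes = [], []
    for _ in range(n):
        v = v + v
        # Evaluate the chunks now and keep them in memory, so the next
        # iteration starts from concrete data instead of the full history.
        v = v.persist()
        norms.append(float(da.linalg.norm(v).compute()))
        # Number of tasks still attached to v after persisting
        graph_sizes.append(len(dict(v.__dask_graph__())))
    return norms, graph_sizes


x = np.random.rand(400, 400)  # assumption: small input for a quick run
norms, graph_sizes = norms_with_persist(da.from_array(x, chunks=(200, 200)), 10)

print(graph_sizes)  # stays flat instead of growing with the counter
```

The trade-off is memory: `persist()` keeps the evaluated array in memory. That is usually acceptable for the iterates of a conjugate gradient solver, but may not be for arrays that do not fit in RAM.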