Tags: python, networkx, dask, dask-distributed, dask-delayed

Why does dask.delayed take longer than serial code when working with networkx?


I would like to speed up the execution of a function my_func() using parallel computation with dask.delayed.

In a loop over 3 dimensions, my_func() extracts a value from an iris.cube.Cube (essentially a dask.array loaded from a file outside the loop) and, depending on the value, creates a random network using networkx and finds the shortest path from node 0 to node 16. The calculation for each array point is independent.

  1. Why does executing the parallel code take longer (5.43 s) than the serial code (2.94 s)?
  2. Is there a better way to speed this up with dask, multiprocessing, or something else?

Here is a reproducible example:

import random

import dask
import iris
import networkx as nx
from dask import delayed
from dask.distributed import Client
from networkx.generators.random_graphs import gnp_random_graph

# Input
fname = iris.sample_data_path("uk_hires.pp")  # https://anaconda.org/conda-forge/iris-sample-data
temp_ptntl = iris.load_cube(fname, "air_potential_temperature")[-1, ...]  # last time step only
# Dimensions
zs = temp_ptntl.coord("model_level_number").points
lats = temp_ptntl.coord("grid_latitude").points
lons = temp_ptntl.coord("grid_longitude").points

def my_func(iz, iy, ix):
    constraint = iris.Constraint(model_level_number=iz, grid_latitude=iy, grid_longitude=ix)
    temp_virt = temp_ptntl.extract(constraint) * (1 + 0.61 * 0.04)
    distance, path = None, None  # defaults, so the return below cannot raise NameError
    if float(temp_virt.data) > 295:
        G = nx.gnp_random_graph(30, 0.2, seed=random.randint(1, 10), directed=True)
        distance, path = nx.single_source_dijkstra(G, source=0, target=16)
    return temp_virt, distance, path

Serial code:

%%time
results_serial = [] # serial code
for iz in zs:
    for iy in lats[0:5]:
        for ix in lons[0:2]:
            results_serial.append(my_func(iz, iy, ix))
>>> CPU times: user 2.94 s, sys: 44 ms, total: 2.99 s
>>> Wall time: 2.94 s

Using dask:

client = Client(processes=True, n_workers=4, threads_per_worker=36)
results_parallel = [] # parallel code
for iz in zs:
    for iy in lats[0:5]:
        for ix in lons[0:2]:
            results_parallel.append(delayed(my_func)(iz, iy, ix))
%%time
computed = dask.compute(results_parallel)
>>> CPU times: user 3.56 s, sys: 344 ms, total: 3.91 s
>>> Wall time: 5.43 s
# client.close()

Solution

  • Dask adds scheduling and communication overhead per task, so on small workloads it is not unusual for it to underperform serial code. When I increase the number of computations by changing the loop to for iy in lats[0:15]:, the serial calculation takes 10 seconds, while dask completes it in 4 seconds.

    (There is also the cost of serializing the function and shipping it to the workers, but that applies only the first time the function is sent.)
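
A common way to amortize that per-task overhead is to batch several points into one task, so each task does enough work to justify its scheduling cost. Below is a minimal, self-contained sketch of the batching idea using only the standard library; `my_cheap_func` is a hypothetical stand-in for `my_func`, and `ThreadPoolExecutor` is used for brevity (for CPU-bound work like the networkx calls, `ProcessPoolExecutor` or `dask.delayed(run_chunk)` over the same chunks would be the analogous choice):

```python
from concurrent.futures import ThreadPoolExecutor
from itertools import product

def my_cheap_func(point):
    # Hypothetical stand-in for my_func(iz, iy, ix): any independent per-point calculation.
    iz, iy, ix = point
    return iz + iy + ix

def run_chunk(points):
    # One task processes a whole chunk of points, amortizing per-task overhead.
    return [my_cheap_func(p) for p in points]

# Build the full index list up front, like the triple loop in the question (7 x 5 x 2 = 70 points).
points = list(product(range(7), range(5), range(2)))

# Split into chunks of 10 points; each chunk becomes a single task.
chunks = [points[i:i + 10] for i in range(0, len(points), 10)]

with ThreadPoolExecutor(max_workers=4) as pool:
    # pool.map preserves chunk order; flatten the per-chunk lists back into one result list.
    results = [r for chunk_result in pool.map(run_chunk, chunks) for r in chunk_result]

assert len(results) == len(points)
```

The same chunking works with dask: wrap `delayed(run_chunk)(chunk)` for each chunk instead of one `delayed(my_func)` per point, so the scheduler handles 7 tasks rather than 70.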