Tags: python, mapreduce, dask

What happens during dask Client.map() call?


I am attempting to write a grid-search utility using dask. The objective function calls a method of a class which contains a large dataframe. I am attempting to use dask to parallelise the computations across many cores without having to copy the original class/dataframe. I have not found any solutions in the documentation, so I'm posting a toy example here:

import pickle
from dask.distributed import Client, LocalCluster
from multiprocessing import current_process


class TestClass:
    def __init__(self):
        self.param = 0

    def __getstate__(self):
        # called whenever the instance is pickled, so we can watch serialisation happen
        print("I am pickled!")
        return self.__dict__

    def loss(self, ext_param):
        self.param += 1
        print(f"{current_process().pid}: {hex(id(self))}:  {self.param}: {ext_param} ")
        return f"{self.param}_{ext_param}"


def objective_function(param):
    # relies on the module-level test_instance created under __main__ below
    return test_instance.loss(param)

if __name__ == '__main__':

    test_instance = TestClass()
    print(hex(id(test_instance)))
    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)
    futures = client.map(objective_function, range(20))
    result = client.gather(futures)
    print(result)
    
# ---- OUTPUT RESULTS ----
# 0x7fe0a5056d30
# I am pickled!
# I am pickled!
# 11347: 0x7fb9bcfa0588:  1: 0
# 11348: 0x7fb9bd0a2588:  1: 1
# 11347: 0x7fb9bcf94240:  1: 2
# 11348: 0x7fb9bd07b6a0:  1: 3
# 11347: 0x7fb9bcf945f8:  1: 4 
# ['1_0', '1_1', '1_2', '1_3', '1_4']

I have the following questions:

  1. Why is the pickle function (__getstate__) called twice?
  2. I notice that each iteration of the map function uses a fresh copy of test_instance, as you can see from the different object address on each iteration, as well as from the fact that the test_instance.param attribute is reset to 0 each time (this behaviour is different from the standard multiprocessing.Pool implementation I have highlighted here). I am assuming that during each iteration each process will receive a fresh copy of the pickled class - is that correct?
  3. Following from (2), how many copies of test_instance are in memory during the computation? Is it 1 (the original instance in the main process) + 1 (the pickled copy) + 2 (the instances present in each of the worker processes) = 4? Is there any way to get this value down to 1?

I have noticed that some shared-memory solutions are available via the Ray library, as proposed in this github issue.


Solution

  • Why is the pickle function (__getstate__) called twice?

    Normally, python's pickle efficiently bundles the instance variables together with a reference to the class in an imported module. For classes defined in __main__, that reference can be unreliable, so dask falls back to cloudpickle (which also calls pickle internally), which would explain the two "I am pickled!" lines: one from the initial pickle attempt and one from the cloudpickle pass. It looks to me like the check for being in "__main__" in distributed.protocol.pickle.dumps could happen before the first attempt to pickle.
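
    A minimal sketch of the difference (assuming cloudpickle is installed alongside dask, as it normally is): plain pickle records only a module-plus-name reference to the class, while cloudpickle embeds the class definition itself, which is what makes classes defined in __main__ portable to worker processes.

    import pickle
    import cloudpickle

    class InMain:
        # stands in for a class defined in __main__, like TestClass above
        pass

    by_reference = pickle.dumps(InMain())      # stores just a "__main__.InMain" reference
    by_value = cloudpickle.dumps(InMain())     # embeds the class definition itself

    print(b"__main__" in by_reference)         # True: only the lookup path is stored
    print(len(by_reference) < len(by_value))   # True: the by-value payload is larger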

    during each iteration each process will receive a fresh copy of the pickled class

    Yes. Each time dask runs a task, it deserialises the inputs, creating a new copy of the instance. Note that your dask workers are probably created via the forkserver technique, so memory is not simply copied (this is the safe way to do things).

    You could "scatter" the instance to the workers before computing, and they would then reuse their local copy; note, however, that dask tasks are not supposed to work by mutating objects, but by returning results (i.e., functionally).
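
    A hedged sketch of that pattern, reusing the names from the toy example in the question (TestClass, objective_function): scatter one copy of the instance up front and pass the resulting future into client.map, so each task uses the worker-local copy instead of a freshly pickled one.

    from dask.distributed import Client, LocalCluster

    def objective_function(param, instance):
        # uses the worker-local copy of the scattered instance and returns a result
        return instance.loss(param)

    if __name__ == '__main__':
        client = Client(LocalCluster(n_workers=2))
        test_instance = TestClass()                  # TestClass as defined in the question
        [remote_instance] = client.scatter([test_instance], broadcast=True)
        futures = client.map(objective_function, range(20), instance=remote_instance)
        print(client.gather(futures))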

    how many copies of test_instance are in memory

    One in the client, plus one per task currently being executed. Serialised versions may also be around: probably one held in the graph, which lives temporarily on the client and is then held on the scheduler; it will also sit briefly in worker memory while being deserialised. For some types, zero-copy de/serialisation is possible.

    If the tasks are very big because of the size of the object, you should definitely "scatter" it beforehand (client.scatter).

    Is there any way to get this value to 1?

    You can run the scheduler and/or workers in-process to share memory, but, of course, then you lose parallelism to the GIL.
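
    A hedged sketch of that setup (assuming the TestClass from the question): Client(processes=False) starts the scheduler and workers as threads inside the current process, so they share its memory, and parallelism is then limited by the GIL.

    from dask.distributed import Client

    if __name__ == '__main__':
        client = Client(processes=False)    # scheduler and workers run as threads in this process
        test_instance = TestClass()         # TestClass as defined in the question
        futures = client.map(test_instance.loss, range(20))
        print(client.gather(futures))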

    Maybe you can try the Actor interface? The pattern seems to match your workflow.
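
    A hedged sketch of the Actor pattern (again assuming the TestClass from the question): the instance is constructed once on a worker, and every call ships the method invocation to that single stateful copy instead of shipping copies of the object.

    from dask.distributed import Client

    if __name__ == '__main__':
        client = Client(n_workers=2)
        actor_future = client.submit(TestClass, actor=True)   # construct the instance on one worker
        actor = actor_future.result()                         # proxy to that single remote instance
        call_futures = [actor.loss(p) for p in range(20)]     # each call returns an ActorFuture
        print([f.result() for f in call_futures])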