Tags: python, tensorflow, dask, dask-distributed, joblib

Running two Tensorflow trainings in parallel using joblib and dask


I have the following code that runs two TensorFlow trainings in parallel, using Dask workers running in Docker containers.

I need to launch two processes, using the same dask client, where each will train their respective models with N workers.

To that end, I do the following:

  • I use joblib.delayed to spawn the two processes.
  • Within each process I wrap the fit/training logic in with joblib.parallel_backend('dask'): . Each training process uses N dask workers.

The problem is that I don't know whether this whole setup is thread safe. Are there any concurrency issues that I'm missing?

# First, submit the function twice using joblib.delayed
delayed_funcs = [joblib.delayed(train)(sub_task) for sub_task in [123, 456]]
parallel_pool = joblib.Parallel(n_jobs=2)
parallel_pool(delayed_funcs)

# Second, submit each training process
def train(sub_task):

    global client
    if client is None:
        print('connecting')
        client = Client()

    data = some_data_to_train

    # Third, process the training itself with N workers
    with joblib.parallel_backend('dask'):
        X = data[columns] 
        y = data[label]

        niceties = dict(verbose=False)
        model = KerasClassifier(build_fn=build_layers,
                loss=tf.keras.losses.MeanSquaredError(), **niceties)
        model.fit(X, y, epochs=500, verbose = 0)

Solution

  • The question, as given, could easily be marked as "unclear" for SO. A couple of notes:

    • global client : makes the name client refer to a module-level variable instead of a local one. But the function runs in another process, so creating the client there does not affect the parent process or the other training process
    • if client is None : this raises a NameError, because client is never assigned anywhere; your code doesn't actually run as written (see the sketch after these notes)
    • client = Client() : you make a new cluster in each subprocess, each assuming it has the machine's total resources, so you oversubscribe those resources.
    • dask knows whether any client has been created in the current process, but that doesn't help you here
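
    For illustration, a minimal sketch of the NameError point (a hypothetical fix, not code from the question, and it does not address the oversubscription problem described above):

    client = None  # without this module-level assignment, "if client is None" raises NameError

    def train(sub_task):
        global client
        if client is None:      # now a valid check in this process
            client = Client()   # but still creates a separate cluster per process
        ...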

    You must ask yourself: why are you creating processes for the two fits at all? Why not just let Dask figure out its parallelism, which is what it is meant for?
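
    As a rough sketch of that single-process approach (the body of train and the sub-task values are placeholders taken from the question, not a tested implementation):

    from dask.distributed import Client

    # one client / one cluster for the whole program; no extra joblib processes
    client = Client()

    def train(sub_task):
        # build the data and the KerasClassifier and call .fit() here,
        # exactly as in the question, but without creating another Client
        ...

    # let the dask scheduler run the two trainings concurrently on its workers
    futures = client.map(train, [123, 456])
    results = client.gather(futures)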

    --

    -EDIT-

    To answer the form of the question asked in a comment:

    My question is whether using the same client variable in these two parallel processes creates a problem.

    No, the two client variables are unrelated to one another. You may see a warning message about not being able to bind to a default port, which you can safely ignore. However, please don't make the client global, as this is unnecessary and makes what you are doing less clear.

    --

    I think I must also answer the question as phrased in your comment, which I advise you to add to the main question:

    I need to launch two processes, using the same dask client, where each will train their respective models with N workers.

    You have the following options:

    • create a client with a specific known address within your program or beforehand, then connect to it
    • create a default client with Client(), get its address (e.g., client._scheduler_identity['address']), and connect to that
    • write a scheduler information file with client.write_scheduler_file and use that

    You will connect in the function with

    client = Client(address)
    

    or

    client = Client(scheduler_file=the_file_you_wrote)
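
    Putting those options together, here is a rough sketch of the scheduler-file variant (the file name and the body of train are placeholders, not part of the original answer):

    from dask.distributed import Client
    import joblib

    # in the main program: create the cluster once and record how to reach it
    client = Client()
    client.write_scheduler_file('scheduler.json')

    def train(sub_task):
        # in each subprocess: connect to the existing scheduler instead of
        # creating a new cluster, then run the fit under the dask backend
        client = Client(scheduler_file='scheduler.json')
        with joblib.parallel_backend('dask'):
            ...  # build and fit the model as in the question

    delayed_funcs = [joblib.delayed(train)(t) for t in [123, 456]]
    joblib.Parallel(n_jobs=2)(delayed_funcs)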