python, parallel-processing, huggingface

Repeated wandb.init() in parallelized wandb sweeps


I wrote some code to try to parallelize my wandb sweeps, since the model I am working with takes a long time to converge and I have a lot of subprocesses to sweep through. Basically, I don't have the luxury of time right now. Here's a generalized snippet of my code:

import concurrent.futures
import json
import os
from functools import partial

import wandb


def run_pipeline(args):
    # Stuff happens here

    # Wandb init
    group = "within_session" if session_config["within_session"] else "across_session"
    run = wandb.init(name=f"{sessions[i]}_{group}_decoder_run", group=group, config=sweep_config, reinit=True)

    # Model training

    return results


def run_pipeline_wrapper(args):
    # Stuff happens here
    run_pipeline(args)

    return None


if __name__ == "__main__":
    total_runs = 30
    agents = 5
    runs_per_agent = total_runs // agents

    sweep_config = {'method': 'random'}
    parameters_dict = {
        # Lots of parameters to sweep
    }
    sweep_config['parameters'] = parameters_dict

    # Create a JSON file that stores sweep ids
    sweep_id_json_path = 'sweep_id.json'
    if not os.path.exists(sweep_id_json_path):
        with open(sweep_id_json_path, 'w') as f:
            json.dump({}, f)
    with open(sweep_id_json_path, 'r') as f:
        sweep_id_json = json.load(f)

    # sessions_list holds the unique sessions I need to run my sweeps on
    for i in range(len(sessions_list)):

        # Prepare a partial function to pass to the agent
        run_pipeline_with_args = partial(run_pipeline_wrapper, args)

        # I cache existing sweep ids in a JSON file so I can attach to the same sweep if I rerun the code
        if f"{sessions_list[i]}_{is_within}" not in sweep_id_json:
            sweep_id = wandb.sweep(sweep_config, project=f"HPC_model_{sess}_session_{data}_{data_type}")
        else:
            sweep_id = wandb.sweep(sweep_config, project=f"HPC_model_{sess}_session_{data}_{data_type}",
                                   prior_runs=sweep_id_json[f"{sessions_list[i]}_{is_within}"])


        # Parallelization logic: run several agents against the same sweep
        with concurrent.futures.ThreadPoolExecutor(max_workers=agents) as executor:
            futures = [
                executor.submit(wandb.agent, sweep_id, run_pipeline_with_args, count=runs_per_agent)
                for _ in range(agents)
            ]

            concurrent.futures.wait(futures)

When I run this code, it gets stuck on wandb.init(), and the process is eventually terminated due to a timeout. I don't think simply increasing wandb's timeout would fix this. How do I fix it? Could this be a problem with my parallelization logic? If so, how do you devs parallelize your wandb sweeps in-code?

wandb logs


Solution

  • If wandb.init runs, I don't think it's a problem with parallelization. In any case, you should check that the executor passed all the arguments correctly.
    You have to take into account that ThreadPoolExecutor uses a pool of threads to execute calls asynchronously. If the threads do not run independently and one waits on the results of another, deadlocks can occur.
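    A minimal, self-contained sketch of that failure mode (unrelated to wandb; inner() and outer() are hypothetical helpers), where every worker blocks on a task that can never be scheduled:

        import concurrent.futures
        import time

        def inner():
            time.sleep(1)
            return "done"

        def outer(executor):
            # Submits to the same pool it is running in, then blocks on the
            # result. Once every worker is busy running outer(), no worker is
            # free to run inner(), so every outer() waits forever: a deadlock.
            future = executor.submit(inner)
            return future.result()

        if __name__ == "__main__":
            with concurrent.futures.ThreadPoolExecutor(max_workers=2) as pool:
                futures = [pool.submit(outer, pool) for _ in range(2)]
                concurrent.futures.wait(futures)  # never returns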

    Have you tried parallelizing W&B Sweep agents the way it's done in the Jupyter Notebook example? Here's the link.

    EDIT:

    You are using threading, and in Python threads are bound by the global interpreter lock (GIL), so threading won't spread work across CPU cores.
    For CPU-bound work you should use multiprocessing instead of threading. For that, Python has ProcessPoolExecutor. Here is a link to it in the Python 3 docs.
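
    As a minimal sketch of that approach (the project name, the lr parameter, and the train()/run_agent() helpers below are hypothetical stand-ins; in your code the function would be your run_pipeline wrapper, defined at module level so it can be pickled):

        import concurrent.futures

        import wandb

        def train():
            # Each agent process calls wandb.init() itself, so runs stay isolated.
            run = wandb.init()
            # ... model training driven by run.config ...
            run.finish()

        def run_agent(sweep_id, count):
            # Module-level function so ProcessPoolExecutor can pickle it.
            wandb.agent(sweep_id, function=train, count=count)

        if __name__ == "__main__":
            agents = 5
            runs_per_agent = 6

            sweep_config = {
                "method": "random",
                "parameters": {"lr": {"min": 0.0001, "max": 0.1}},  # hypothetical
            }
            sweep_id = wandb.sweep(sweep_config, project="my_project")  # hypothetical

            # One OS process per agent; separate processes are not bound by the GIL.
            with concurrent.futures.ProcessPoolExecutor(max_workers=agents) as executor:
                futures = [
                    executor.submit(run_agent, sweep_id, runs_per_agent)
                    for _ in range(agents)
                ]
                concurrent.futures.wait(futures)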

    If you use Slurm for cluster management on your HPC, you should see these links:

    1. How should I run sweeps on SLURM?
    2. Managing wandb agents on a slurm cluster
    3. How does one do hyper parameter sweeps when using HPCs/clusters?