Tags: multithreading, mpi, openmpi

Thread Safety with MPI_Probe


I am using MPI_Probe to receive messages dynamically, where the receiver does not know in advance how large the incoming message is. My code looks roughly like this:

if (world_rank == 0) {
    int *buffer = ...
    int bufferSize = ...
    MPI_Send(buffer, bufferSize, MPI_INT, 1, 0, MPI_COMM_WORLD);
} else if (world_rank == 1) {
    MPI_Status status;
    MPI_Probe(0, 0, MPI_COMM_WORLD, &status);
    int count = -1;
    MPI_Get_count(&status, MPI_INT, &count);
    int* buffer = (int*)malloc(sizeof(int) * count);
    MPI_Recv(buffer, count, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
}

If I run this code in multiple threads, is there a chance that the scheduler interleaves the threads so that the message matched by MPI_Probe in one thread is actually received by MPI_Recv in another thread? In essence, is the above code thread-safe?


Solution

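  • First of all, MPI does not guarantee thread safety by default. You have to check whether your particular MPI library was built with support for multiple threads, then initialize it with MPI_Init_thread instead of MPI_Init, requesting MPI_THREAD_MULTIPLE and checking that the "provided" level it returns is actually MPI_THREAD_MULTIPLE.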

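    Even if MPI is initialized with full thread support, your code is still not thread-safe because of the race condition you already identified: between this thread's MPI_Probe and its MPI_Recv, another thread can probe and receive the very same message, so the MPI_Recv may pick up a different message (possibly of a different size) or block.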

    Pairing MPI_Probe with MPI_Recv in a multithreaded environment is not thread-safe; this is a known problem in MPI-2: http://htor.inf.ethz.ch/publications/img/gregor-any_size-mpi3.pdf

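    There are at least two possible solutions. You can either use the MPI-3 matched-probe routines MPI_Mprobe and MPI_Mrecv, or hold a lock/mutex around the probe/receive pair. The lock only helps if every thread that could receive the same message takes the same lock. The two approaches look as follows: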

    MPI-2 solution (using a mutex/lock):

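    /* Requires <pthread.h> and <stdlib.h>. probe_recv_lock is an illustrative
       name; it must be shared by every thread that may receive this message. */
    static pthread_mutex_t probe_recv_lock = PTHREAD_MUTEX_INITIALIZER;

    if (world_rank == 0) {
        int *buffer = ...
        int bufferSize = ...
        MPI_Send(buffer, bufferSize, MPI_INT, 1, 0, MPI_COMM_WORLD);
    } else if (world_rank == 1) {
        MPI_Status status;
        int count = -1;
        /* Hold the lock so no other thread can probe or receive between
           MPI_Probe and MPI_Recv. */
        pthread_mutex_lock(&probe_recv_lock);
        MPI_Probe(0, 0, MPI_COMM_WORLD, &status);
        MPI_Get_count(&status, MPI_INT, &count);
        int* buffer = (int*)malloc(sizeof(int) * count);
        MPI_Recv(buffer, count, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        pthread_mutex_unlock(&probe_recv_lock);
        /* process buffer here */
        free(buffer);
    }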
    

    MPI-3 solution:

    if (world_rank == 0) {
        int *buffer = ...
        int bufferSize = ...
        MPI_Send(buffer, bufferSize, MPI_INT, 1, 0, MPI_COMM_WORLD);
    } else if (world_rank == 1) {
        MPI_Status status;
        MPI_Message msg;
        int count = -1;
        /* MPI_Mprobe removes the matched message from the matching queue and
           returns a handle, so no other thread can receive it. */
        MPI_Mprobe(0, 0, MPI_COMM_WORLD, &msg, &status);
        MPI_Get_count(&status, MPI_INT, &count);
        int* buffer = (int*)malloc(sizeof(int) * count);
        /* MPI_Mrecv receives exactly the message identified by msg. */
        MPI_Mrecv(buffer, count, MPI_INT, &msg, MPI_STATUS_IGNORE);
        /* process buffer here */
        free(buffer);
    }