Search code examples
Tags: python, performance, numba

How to parallelize/speed up embarrassingly parallel numba code?


I have the following code:

@nb.njit(cache=True)
def find_two_largest(arr):
    """Return ``(largest, second_largest)`` of a 1-D array in a single pass.

    Duplicates are counted separately: if the maximum occurs more than once,
    both returned values equal it (e.g. ``[5, 5, 3]`` -> ``(5, 5)``).

    Raises ValueError if ``arr`` has fewer than two elements. Without this
    check Numba (which does not bounds-check by default) would silently read
    past the end of the array.
    """
    if arr.shape[0] < 2:
        raise ValueError("find_two_largest needs at least two elements")

    # Initialize the first and second largest elements
    if arr[0] >= arr[1]:
        largest = arr[0]
        second_largest = arr[1]
    else:
        largest = arr[1]
        second_largest = arr[0]

    # Iterate through the array starting from the third element
    for num in arr[2:]:
        if num > largest:
            second_largest = largest
            largest = num
        elif num > second_largest:
            second_largest = num
    return largest, second_largest


@nb.njit(cache=True)
def max_bar_one(arr):
    """For every position, return the maximum of ``arr`` with that one element left out.

    Only a single occurrence is left out, so when the maximum appears more than
    once every position still sees it. The output has the shape and dtype of
    ``arr``.
    """
    top, runner_up = find_two_largest(arr)
    out = np.empty_like(arr)
    for idx in range(arr.shape[0]):
        # Only a position holding a strictly unique maximum falls back to the
        # runner-up; all other positions see the overall maximum.
        if arr[idx] == top and top != runner_up:
            out[idx] = runner_up
        else:
            out[idx] = top
    return out


@nb.njit(cache=True)
def replace_max_row_wise_add_first_delete_last(d):
    """
    Run max_bar_one on each row but the last, prepend an all -inf row.

    Output row 0 is all -inf and output row ``i + 1`` is ``max_bar_one(d[i])``,
    so the last input row is never processed. The result is float64
    (``np.empty`` default) regardless of ``d``'s dtype.

    Raises ValueError when there is at least one row to process (``m > 1``)
    but rows have a single element, because "max of the others" needs two.
    """
    m, n = d.shape
    result = np.empty((m, n))
    if m == 0 or n == 0:
        # Nothing to fill. This also avoids writing result[0] out of bounds
        # when m == 0 (Numba does not bounds-check by default).
        return result
    if m > 1 and n < 2:
        raise ValueError("rows of d need at least two elements")
    result[0] = -np.inf
    for i in range(0, m - 1):
        result[i + 1, :] = max_bar_one(d[i, :])
    return result


@nb.njit(cache=True)
def main_function(d, subcusum, j):
    """Combine each row of ``d`` with the max-bar-one of the row above it, then add ``subcusum[j]``.

    Every output element is ``max(shifted[r, c], d[r, c]) + subcusum[j, c]``,
    where ``shifted`` comes from ``replace_max_row_wise_add_first_delete_last``
    (its row 0 is all -inf, so output row 0 is ``d[0] + subcusum[j]``).
    """
    temp = replace_max_row_wise_add_first_delete_last(d)
    offsets = subcusum[j]  # the one row of subcusum added to every output row
    rows, cols = temp.shape
    for r in range(rows):
        for c in range(cols):
            temp[r, c] = max(temp[r, c], d[r, c]) + offsets[c]
    return temp

I then set up the data with:

n = 5000
A = np.random.randint(-3, 4, (n, n)).astype(float)
cusum_rows = np.cumsum(A, axis=1)
rowseq = np.arange(n)
# d must have as many columns as cusum_rows (main_function indexes
# subcusum[j, col] for every column of d), so size it from n as well.
d = np.random.randint(-3, 4, (n, n))

We can then time it with:

%timeit main_function(d, cusum_rows, 0)
166 ms ± 1.87 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

Is it possible to parallelize the for loop or the code in general to speed this up? I tried using parallel=True in replace_max_row_wise_add_first_delete_last but it didn't speed up the code and only reported:

Instruction hoisting:
loop #1:
  Failed to hoist the following:
    dependency: $value_var.73 = getitem(value=_72call__function_11, index=$parfor__index_72.90, fn=<built-in function getitem>)

This is surprising as all the calls in the for loop are independent.

Can this code be sped up and/or parallelized?


Solution

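  • I get ~70% speedup when I use parallelization in replace_max_row_wise_add_first_delete_last(). Note that `parallel=True` by itself does not parallelize a plain `range` loop; the loop has to be written with `nb.prange`: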

    @nb.njit(parallel=True)
    def replace_max_row_wise_add_first_delete_last_parallel(d):
        """
        Run max_bar_one on each row but the last, prepend an all -inf row.

        Same output as the serial version. Rows are distributed over threads
        with ``nb.prange``. Input validation happens before the parallel loop so
        errors are raised from serial code.
        """
        m, n = d.shape
        result = np.empty((m, n))
        if m == 0 or n == 0:
            # Nothing to fill; also avoids an out-of-bounds write to result[0].
            return result
        if m > 1 and n < 2:
            raise ValueError("rows of d need at least two elements")
        result[0] = -np.inf
        for i in nb.prange(0, m - 1):                 # <-- using prange here
            # Each iteration writes only its own row result[i + 1].
            result[i + 1, :] = max_bar_one(d[i, :])
        return result
    
    
    @nb.njit
    def main_function_parallel(d, subcusum, j):
        """Same computation as ``main_function`` but using the prange-based row helper.

        Only the helper runs in parallel; the element-wise combine below is a
        plain serial loop.
        """
        temp = replace_max_row_wise_add_first_delete_last_parallel(d)  # <-- using parallel version of the function here
        offsets = subcusum[j]
        rows, cols = temp.shape
        for r in range(rows):
            for c in range(cols):
                temp[r, c] = max(temp[r, c], d[r, c]) + offsets[c]
        return temp
    

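    EDIT: A further speedup comes from removing the temporary allocation `missing_maxes = np.empty_like(arr)` (writing directly into the output row) and fusing the `max(...) + subcusum` step into the same pass. With this the code runs about 3x faster than the original: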

    @nb.njit
    def max_bar_one2(arr, result, to_compare, to_add):
        """Fused max-bar-one + combine step, written into ``result`` in place.

        For each index k: ``result[k] = max(bar_one_k, to_compare[k]) + to_add[k]``,
        where ``bar_one_k`` is the maximum of ``arr`` with position k left out.
        Writing into the caller's buffer avoids the temporary array that
        ``max_bar_one`` allocates. Returns ``result``.
        """
        top, runner_up = find_two_largest(arr)
        for k in range(arr.shape[0]):
            if arr[k] == top and top != runner_up:
                result[k] = runner_up
            else:
                result[k] = top
            # Read back the stored value so it passes through result's dtype first
            result[k] = max(result[k], to_compare[k]) + to_add[k]
        return result
    
    
    @nb.njit(parallel=True)
    def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
        """
        Fused, parallel equivalent of main_function's whole computation.

        Output row 0 is ``d[0] + to_add``. This equals ``max(-inf, d[0]) + to_add``,
        so no -inf row is materialised. Output row ``i + 1`` is
        ``max(max_bar_one(d[i]), d[i + 1]) + to_add``. Rows after the first are
        computed in parallel with ``nb.prange``.
        """
        m, n = d.shape
        result = np.empty((m, n))
        if m == 0 or n == 0:
            # Nothing to fill; also avoids an out-of-bounds write to result[0].
            return result
        if m > 1 and n < 2:
            raise ValueError("rows of d need at least two elements")
        result[0] = d[0] + to_add
        for i in nb.prange(0, m - 1):
            # Each iteration writes only its own row result[i + 1].
            max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
        return result
    
    
    @nb.njit
    def main_function_parallel2(d, subcusum, j):
        """Fused equivalent of ``main_function``: one parallel pass using row ``subcusum[j]``."""
        offsets = subcusum[j]
        return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)
    

    Benchmark:

    from timeit import timeit
    
    import numba as nb
    import numpy as np
    
    
    @nb.njit(cache=True)
    def find_two_largest(arr):
        """Return ``(largest, second_largest)`` of a 1-D array in a single pass.

        Duplicates are counted separately: if the maximum occurs more than
        once, both returned values equal it (e.g. ``[5, 5, 3]`` -> ``(5, 5)``).

        Raises ValueError if ``arr`` has fewer than two elements. Without this
        check Numba (which does not bounds-check by default) would silently
        read past the end of the array.
        """
        if arr.shape[0] < 2:
            raise ValueError("find_two_largest needs at least two elements")

        # Initialize the first and second largest elements
        if arr[0] >= arr[1]:
            largest = arr[0]
            second_largest = arr[1]
        else:
            largest = arr[1]
            second_largest = arr[0]

        # Iterate through the array starting from the third element
        for num in arr[2:]:
            if num > largest:
                second_largest = largest
                largest = num
            elif num > second_largest:
                second_largest = num
        return largest, second_largest
    
    
    @nb.njit(cache=True)
    def max_bar_one(arr):
        """For every position, return the maximum of ``arr`` with that one element left out.

        Only a single occurrence is left out, so when the maximum appears more
        than once every position still sees it. The output has the shape and
        dtype of ``arr``.
        """
        top, runner_up = find_two_largest(arr)
        out = np.empty_like(arr)
        for idx in range(arr.shape[0]):
            # Only a position holding a strictly unique maximum falls back to
            # the runner-up; all other positions see the overall maximum.
            if arr[idx] == top and top != runner_up:
                out[idx] = runner_up
            else:
                out[idx] = top
        return out
    
    
    @nb.njit(cache=True)
    def replace_max_row_wise_add_first_delete_last(d):
        """
        Run max_bar_one on each row but the last, prepend an all -inf row.

        Output row 0 is all -inf and output row ``i + 1`` is
        ``max_bar_one(d[i])``, so the last input row is never processed. The
        result is float64 (``np.empty`` default) regardless of ``d``'s dtype.

        Raises ValueError when there is at least one row to process (``m > 1``)
        but rows have a single element, because "max of the others" needs two.
        """
        m, n = d.shape
        result = np.empty((m, n))
        if m == 0 or n == 0:
            # Nothing to fill. This also avoids writing result[0] out of bounds
            # when m == 0 (Numba does not bounds-check by default).
            return result
        if m > 1 and n < 2:
            raise ValueError("rows of d need at least two elements")
        result[0] = -np.inf
        for i in range(0, m - 1):
            result[i + 1, :] = max_bar_one(d[i])
        return result
    
    
    @nb.njit(cache=True)
    def main_function(d, subcusum, j):
        """Combine each row of ``d`` with the max-bar-one of the row above it, then add ``subcusum[j]``.

        Every output element is ``max(shifted[r, c], d[r, c]) + subcusum[j, c]``,
        where ``shifted`` comes from ``replace_max_row_wise_add_first_delete_last``
        (its row 0 is all -inf, so output row 0 is ``d[0] + subcusum[j]``).
        """
        temp = replace_max_row_wise_add_first_delete_last(d)
        offsets = subcusum[j]  # the one row of subcusum added to every output row
        rows, cols = temp.shape
        for r in range(rows):
            for c in range(cols):
                temp[r, c] = max(temp[r, c], d[r, c]) + offsets[c]
        return temp
    
    
    @nb.njit(parallel=True)
    def replace_max_row_wise_add_first_delete_last_parallel(d):
        """
        Run max_bar_one on each row but the last, prepend an all -inf row.

        Same output as the serial version. Rows are distributed over threads
        with ``nb.prange``. Input validation happens before the parallel loop so
        errors are raised from serial code.
        """
        m, n = d.shape
        result = np.empty((m, n))
        if m == 0 or n == 0:
            # Nothing to fill; also avoids an out-of-bounds write to result[0].
            return result
        if m > 1 and n < 2:
            raise ValueError("rows of d need at least two elements")
        result[0] = -np.inf
        for i in nb.prange(0, m - 1):
            # Each iteration writes only its own row result[i + 1].
            result[i + 1, :] = max_bar_one(d[i, :])
        return result
    
    
    @nb.njit
    def main_function_parallel(d, subcusum, j):
        """Same computation as ``main_function`` but using the prange-based row helper.

        Only the helper runs in parallel; the element-wise combine below is a
        plain serial loop.
        """
        temp = replace_max_row_wise_add_first_delete_last_parallel(d)
        offsets = subcusum[j]
        rows, cols = temp.shape
        for r in range(rows):
            for c in range(cols):
                temp[r, c] = max(temp[r, c], d[r, c]) + offsets[c]
        return temp
    
    
    @nb.njit
    def max_bar_one2(arr, result, to_compare, to_add):
        """Fused max-bar-one + combine step, written into ``result`` in place.

        For each index k: ``result[k] = max(bar_one_k, to_compare[k]) + to_add[k]``,
        where ``bar_one_k`` is the maximum of ``arr`` with position k left out.
        Writing into the caller's buffer avoids the temporary array that
        ``max_bar_one`` allocates. Returns ``result``.
        """
        top, runner_up = find_two_largest(arr)
        for k in range(arr.shape[0]):
            if arr[k] == top and top != runner_up:
                result[k] = runner_up
            else:
                result[k] = top
            # Read back the stored value so it passes through result's dtype first
            result[k] = max(result[k], to_compare[k]) + to_add[k]
        return result
    
    
    @nb.njit(parallel=True)
    def replace_max_row_wise_add_first_delete_last_parallel2(d, to_add):
        """
        Fused, parallel equivalent of main_function's whole computation.

        Output row 0 is ``d[0] + to_add``. This equals ``max(-inf, d[0]) + to_add``,
        so no -inf row is materialised. Output row ``i + 1`` is
        ``max(max_bar_one(d[i]), d[i + 1]) + to_add``. Rows after the first are
        computed in parallel with ``nb.prange``.
        """
        m, n = d.shape
        result = np.empty((m, n))
        if m == 0 or n == 0:
            # Nothing to fill; also avoids an out-of-bounds write to result[0].
            return result
        if m > 1 and n < 2:
            raise ValueError("rows of d need at least two elements")
        result[0] = d[0] + to_add
        for i in nb.prange(0, m - 1):
            # Each iteration writes only its own row result[i + 1].
            max_bar_one2(d[i], result[i + 1], d[i + 1], to_add)
        return result
    
    
    @nb.njit
    def main_function_parallel2(d, subcusum, j):
        """Fused equivalent of ``main_function``: one parallel pass using row ``subcusum[j]``."""
        offsets = subcusum[j]
        return replace_max_row_wise_add_first_delete_last_parallel2(d, offsets)
    
    
    def get_d_cumsum_rows(n, low=-300, high=400):
        """Build random ``n x n`` benchmark inputs.

        Returns ``(d, cusum_rows)``:
        - ``d`` is an integer matrix drawn from ``[low, high)``.
        - ``cusum_rows`` is the row-wise cumulative sum of a second, independent
          float matrix drawn from the same range.

        Both draws use NumPy's global RNG (the float matrix first, then ``d``),
        so call ``np.random.seed`` beforehand for reproducible output. The
        defaults match the original benchmark; pass ``low=-3, high=4`` to
        reproduce the question's setup.
        """
        A = np.random.randint(low, high, (n, n)).astype(float)
        cusum_rows = np.cumsum(A, axis=1)
        d = np.random.randint(low, high, (n, n))
        return d, cusum_rows
    
    
    n = 10
    names = ("main_function", "main_function_parallel", "main_function_parallel2")
    candidates = (main_function, main_function_parallel, main_function_parallel2)

    results = []
    for fn in candidates:
        # Reseed so every implementation sees exactly the same random inputs
        np.random.seed(42)
        results.append(fn(*get_d_cumsum_rows(n), 0))
    out1, out2, out3 = results

    # The serial version is the reference; both parallel variants must match it
    assert np.allclose(out1, out2)
    assert np.allclose(out1, out3)

    setup_code = "n=5000;a,b=get_d_cumsum_rows(n)"
    t1, t2, t3 = [
        timeit(
            f"{name}(a, b, 0)",
            setup=setup_code,
            globals=globals(),
            number=100,
        )
        for name in names
    ]

    for elapsed in (t1, t2, t3):
        print(elapsed)
    

    Prints on my computer (AMD 5700X); each number is the total time in seconds for 100 calls:

    7.003944834927097
    4.12014868715778
    2.2788363839499652