python, optimization, pytorch, parallel-processing, linear-algebra

PyTorch: parallelize linalg.solve() loop


Goal: I have a function calling torch.linalg.solve() that I want to run as fast as possible.

Setup: I have an input_array (size 50x500x500) and a host_array (size 100x500x500). My function solver takes input_array[:,i,j] as input and outputs a vector of size 100, which is stored in host_array[:,i,j]. I run a nested loop over all rows and columns of input_array to populate host_array.

Problem: This is slow to run, especially considering my real case, where each call of the function takes a second. Since I am currently using a nested loop, would it be faster to parallelize my function?

Example Code:

import torch
from tqdm import tqdm

# Create an empty host_array and an input_array with random data
host_array = torch.zeros(100, 500, 500)
input_array = torch.randn(50, 500, 500)

# Create a dummy coefficient matrix A (50x100)
A = torch.randn(50, 100)

# Define the function that solves for input_array[:, i, j]; the result is stored in host_array[:, i, j]
def solver(input_vector, A):

    # Solve the normal equations (A.T @ A) x = A.T @ input_vector
    solution = torch.linalg.solve(A.T @ A, A.T @ input_vector)

    return solution

# Set up a progress bar over all (i, j) positions
total_iterations = int(host_array.shape[1] * host_array.shape[2])
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)

# Iterate through the input_array
for i in range(host_array.shape[1]):
    for j in range(host_array.shape[2]):
        host_array[:, i, j] = solver(input_array[:, i, j], A)
        progress_bar.update(1)
progress_bar.close()

Solution

  • You can make use of the broadcasting capabilities of torch.linalg.solve() to gain a significant speedup; see the section # Proposed solution and, in particular, the function solver_batch() in my code below. I annotated the shapes that result from the necessary reshaping operations (squeezing, unsqueezing, and permuting) of the inputs.

    from datetime import datetime
    import torch
    
    torch.manual_seed(42)  # Make result reproducible
    
    B, C = 500, 500
    M, N = 100, 50
    
    input_array = torch.randn(N, B, C)
    A = torch.randn(N, M)
    
    # Proposed solution
    
    t_start = datetime.now()
    def solver_batch(inpt, a):
        at_a = a.T @ a  # MxM
        inpt = inpt.permute(1, 2, 0).unsqueeze(-1)  # BxCxNx1
        at_input = a.T @ inpt  # BxCxMx1
        
        result = torch.linalg.solve(at_a, at_input)  # BxCxMx1
        return result.squeeze(-1).permute(2, 0, 1)  # MxBxC
    
    proposed_result = solver_batch(input_array, A)
    t_stop = datetime.now()
    print(f"Proposed solution took {(t_stop - t_start).total_seconds():.2f}s")
    
    # Previous solution
    
    t_start = datetime.now()
    host_array = torch.zeros(M, B, C)  # Will hold the result
    
    def solver(input_vector, A):
        return torch.linalg.solve(A.T@A, A.T@input_vector)
    
    for i in range(host_array.shape[1]):
        for j in range(host_array.shape[2]):
            host_array[:,i,j] = solver(input_array[:,i,j], A)
    t_stop = datetime.now()
    print(f"Previous solution took {(t_stop - t_start).total_seconds():.2f}s")
    
    # Check results
    
    left = (A.T @ A) @ host_array.permute(1, 2, 0).unsqueeze(-1)
    right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
    print("(A.T @ A) @ host_array      == A.T @ input_array?",
          torch.allclose(left, right, atol=1e-3))
    
    left = (A.T @ A) @ proposed_result.permute(1, 2, 0).unsqueeze(-1)
    right = A.T @ input_array.permute(1, 2, 0).unsqueeze(-1)
    print("(A.T @ A) @ proposed_result == A.T @ input_array?",
          torch.allclose(left, right, atol=1e-3))
    

    On my machine, I got:

    >>> Proposed solution took 0.62s
    >>> Previous solution took 21.74s
    >>> (A.T @ A) @ host_array      == A.T @ input_array? True
    >>> (A.T @ A) @ proposed_result == A.T @ input_array? True
    

    Note that, while both host_array and proposed_result hold valid solutions, they do not necessarily hold the same solution (in fact, for the given random seed, they are not the same). This is because the result of torch.linalg.solve() is unique if and only if its first argument (A.T @ A in our case) is invertible, and here it is not: since A has shape 50x100, A.T @ A is a 100x100 matrix of rank at most 50, hence singular in exact arithmetic (floating-point round-off is what keeps the solve from failing outright). If you want to check that both your approach and mine indeed produce the same solution (up to numerical error) for an invertible matrix, you could construct an MxM invertible matrix following this approach and substitute it for A.T @ A for testing purposes; one possible construction is sketched below.
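    For illustration, here is a minimal sketch of one such construction (my own example; the approach linked above may differ): Q @ Q.T is positive semi-definite, and adding the identity makes it positive definite and therefore invertible.

    import torch
    
    torch.manual_seed(42)
    M = 100
    
    # Hypothetical helper for testing only: Q @ Q.T is positive semi-definite,
    # so Q @ Q.T + I is positive definite and hence invertible.
    def make_invertible(m: int) -> torch.Tensor:
        q = torch.randn(m, m)
        return q @ q.T + torch.eye(m)
    
    lhs = make_invertible(M)  # Substitute this for A.T @ A when testing
    rhs = torch.randn(M, 1)   # Arbitrary right-hand side
    x = torch.linalg.solve(lhs, rhs)
    print(torch.allclose(lhs @ x, rhs, atol=1e-3))  # Unique solution, small residual

    With this matrix on the left-hand side, the looped and the batched solver should return the same solution up to numerical error.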

    Also note that, when comparing the results with torch.allclose(), I had to be quite generous (atol=1e-3), as a fair amount of numerical error seems to accumulate.
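    As a side note (my suggestion, not something benchmarked above): running the solve in double precision typically reduces the accumulated round-off, although it cannot fix the underlying rank deficiency of A.T @ A. A minimal sketch, reusing solver_batch() and the arrays defined above:

    # Repeat the batched solve in float64 and inspect the worst-case residual
    # directly instead of hard-coding a tolerance.
    A64 = A.to(torch.float64)
    input64 = input_array.to(torch.float64)
    result64 = solver_batch(input64, A64)
    
    left = (A64.T @ A64) @ result64.permute(1, 2, 0).unsqueeze(-1)
    right = A64.T @ input64.permute(1, 2, 0).unsqueeze(-1)
    # With the rank-deficient A.T @ A this may still be sizeable; with a
    # well-conditioned matrix (e.g. make_invertible above) it drops markedly.
    print((left - right).abs().max())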