Goal: I have a function, called in a loop, that takes a 1D tensor and a 2D tensor as inputs and uses torch.linalg.solve() internally. I want to parallelize the loop to reduce the runtime.
Setup: I have 3 main tensors:
input_tensor: size 50x100x100
host_tensor: size 100x100x100
A: size 50x100 (design matrix)
input_tensor holds 100x100 input_vectors, all of length 50. Each of them has a different number of NaNs that I mask out, hence the masked input_vector having a length less than or equal to 50. Note that the design matrix A will also be masked and have size (mask x 100).
Because each input_vector and A have different masked lengths, the function needs to be run point by point.
Problem: Is there a way to make the following code faster? How could I deal with the design matrix A and input_vector having different sizes at each iteration?
Important: The NaNs cannot be replaced by 0, as this would defeat the purpose of the linear solve. As background, I asked a question about a similar process here.
Code:
import torch
from tqdm import tqdm
import numpy as np
from datetime import datetime
# Create "device" so we can migrate the tensors to GPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Set the seed for reproducibility
torch.manual_seed(42)
# Set shapes to generate tensors
B, C = 500, 500
M, N = 100, 50
# Generate tensors
input_tensor = torch.randn(N, B, C)
host_tensor = torch.randn(M, B, C)
A = torch.randn(N, M)
# --- Here we input random NaNs in the input_tensor to simulate missing data --- #
# Define the probability of inserting NaN at each element
probability = 0.2 # You can adjust this as needed
# Generate random indices based on the probability
shape = input_tensor.shape
random_indices = torch.rand(shape) < probability
# Replace the selected indices with NaN values
input_tensor[random_indices] = float('nan')
# --- Migrate matrices to GPU --- #
A = A.to(device)
input_tensor = input_tensor.to(device)
host_tensor = host_tensor.to(device)
t_start = datetime.now()
# --- Function that creates a vector of size M from input_vector (size N) and A --- #
def solver(input_vector, A):
# We create a mask to reduce the row size of A: rows where input_vector is NaN are not considered in the solver
mask = ~torch.isnan(input_vector)
# Mask the vector
input_vector_masked = input_vector[mask]
# Mask the array
A_masked = A[mask]
A_trans = A_masked.T
    # Solve the normal equations: (A.T @ A) x = A.T @ input_vector_masked
return torch.linalg.solve(A_trans@A_masked, A_trans@input_vector_masked)
# --- Iterate through each vector of the input_tensor --- #
# Define the total number of iterations
total_iterations = B*C
# Create a tqdm progress bar
progress_bar = tqdm(total=total_iterations, dynamic_ncols=False, mininterval=1.0)
# Iterate through every cell of input_array
for i in range(host_tensor.shape[1]):
for j in range(host_tensor.shape[2]):
host_tensor[:,i,j] = solver(input_tensor[:,i,j], A)
progress_bar.update(1) # Update the progress bar
t_stop = datetime.now()
print(f"Inversion took {(t_stop - t_start).total_seconds():.2f}s")
I got a bit of an unsatisfactory answer here. But let's go step by step.
Zeroing nans == dropping nans
First of all, you can replace the nans with zeros. Take the following example: assume you have a vector v and a matrix A, given as
v = [v1 v2 v3] # N elements
A = [[a11 a12 a13] # NxM elements
[a21 a22 a23]
[a31 a32 a33]]
Now, assume v2 = nan and thus needs to be suppressed. What you are currently doing in solver() is getting the non-nan elements of v as m and the corresponding rows of A as M, and then calculating A_for_solving = M.T @ M and B_for_solving = M.T @ m, namely:
m = [v1 v3] # Masked v (n < N elements)
M = [[a11 a12 a13] # Masked A (nxM elements)
[a31 a32 a33]]
A_for_solving = M.T @ M # MxM elements
B_for_solving = M.T @ m # M elements
result = linalg.solve(A_for_solving, B_for_solving)
You should notice two things here:
1. The shapes of A_for_solving and B_for_solving always remain the same, no matter how many elements of v (and thus rows of A) are dropped: A_for_solving is always an M×M matrix and B_for_solving is always an M-element vector. This hints at the possibility that we can actually still parallelize our calculation.
2. What's more, if you replace the nans in v and the corresponding rows of A with zeros, you get exactly the same values in A_for_solving and B_for_solving, because the zeroed rows contribute only zero terms to the sums that make up these matrix products.
In other words, you could do the following:
z = [v1 0 v3] # Zeroed v
Z = [[a11 a12 a13] # Zeroed A
[ 0 0 0]
[a31 a32 a33]]
A_for_solving = Z.T @ Z
B_for_solving = Z.T @ z
result = linalg.solve(A_for_solving, B_for_solving)
… and you would get exactly the same inputs to linalg.solve() as before!
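If you want to convince yourself independently of the code above, here is a minimal standalone sketch of the same equivalence; the sizes n = 6, m = 4 are arbitrary choices for illustration, not values from the question:
import torch
torch.manual_seed(0)
n, m = 6, 4
v = torch.randn(n)
v[1] = float('nan')                      # suppress one element
A = torch.randn(n, m)
mask = ~torch.isnan(v)
# Dropping the nan entry and the corresponding row of A
M_dropped = A[mask]
lhs_dropped = M_dropped.T @ M_dropped    # m x m
rhs_dropped = M_dropped.T @ v[mask]      # m
# Zeroing them instead
A_zeroed = A.clone(); A_zeroed[~mask] = 0
v_zeroed = torch.nan_to_num(v, nan=0.0)
assert torch.allclose(lhs_dropped, A_zeroed.T @ A_zeroed, atol=1e-6)
assert torch.allclose(rhs_dropped, A_zeroed.T @ v_zeroed, atol=1e-6)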
You can easily check this with your current code by extending it for testing purposes as follows:
def solver(input_vector, A):
mask = ~torch.isnan(input_vector)
input_vector_masked = input_vector[mask]
A_masked = A[mask]
A_trans = A_masked.T
# Start sanity check: nan-zeroing is the same as nan-dropping
A_zeroed = A.clone(); A_zeroed[~mask] = 0
input_vector_zeroed = input_vector.clone(); input_vector_zeroed[~mask] = 0
assert torch.allclose(A_masked.T @ A_masked,
A_zeroed.T @ A_zeroed, atol=1e-5)
assert torch.allclose(A_masked.T @ input_vector_masked,
A_zeroed.T @ input_vector_zeroed, atol=1e-5)
# End sanity check
return torch.linalg.solve(A_trans@A_masked, A_trans@input_vector_masked)
If we use the zeroing approach, we can parallelize our code again, because the inputs now have the same size for every mask. The corresponding function could look as follows:
def solver_batch(inpt, a):
inpt = inpt.permute(1, 2, 0).unsqueeze(-1) # BxCxNx1
mask = torch.isnan(inpt) # CAUTION: True for NaNs, unlike `mask` in the question!
a_zeroed = a.repeat(*inpt.shape[:2], 1, 1) # BxCxNxM
a_zeroed[mask.expand(-1, -1, -1, a.shape[-1])] = 0
at_a = a_zeroed.transpose(-2, -1) @ a_zeroed # BxCxMxM
inpt_zeroed = inpt.clone()
inpt_zeroed[mask] = 0
at_input = a_zeroed.transpose(-2, -1) @ inpt_zeroed # BxCxMx1
result = torch.linalg.solve(at_a, at_input)
return result.squeeze(-1).permute(2, 0, 1) # MxBxC
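In principle, a single call then replaces the whole double loop; this one-liner assumes the three tensors from your code are already on device:
host_tensor = solver_batch(input_tensor, A)  # MxBxC in one shot; may exhaust GPU memory, see below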
The batched solution is quite similar to the answer that I posted to your previous question. There are two caveats though:
1. As we now need a different matrix A, and thus a different A.T @ A, for each input vector, we end up with a tensor at_a of size 500×500×100×100 in your given example. This is huge: 2.5 billion elements, or roughly 10 GB in float32. In my case, it doesn't fit on the GPU, so what I had to do is process the input tensor in chunks:
chunk_size = 50 # TODO: adjust chunk size for your hardware
for lo in range(0, input_tensor.shape[-1], chunk_size):
chunk_result = solver_batch(input_tensor[..., lo:lo+chunk_size], A)
host_tensor[..., lo:lo+chunk_size] = chunk_result
This is still much faster than processing the input element-wise though.
2. I tried to sanity-check the results with the following for-loop, similar to the sanity check in my previous answer:
for i in range(host_tensor.shape[1]):
for j in range(host_tensor.shape[2]):
input_vec = input_tensor[..., i, j]
res_vec = host_tensor[..., i, j]
mask = ~torch.isnan(input_vec)
M = A[mask]
assert torch.allclose((M.T @ M) @ res_vec, M.T @ input_vec[mask], atol=1e-3)
What we check here is that A @ X = B should hold for X = solve(A, B) by definition. This, however, seems not to be the case with the given data, neither for my approach nor for yours. I don't know whether this is a problem of numerical instability (my maths skills are lacking there) or whether I made some stupid mistake.
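One way to narrow this down (a sketch for investigation, reusing the tensors from the code above, not a definitive diagnosis) is to inspect the masked normal-equation matrix that goes into linalg.solve(). Since each masked A has at most N = 50 rows but M = 100 columns, M.T @ M can have rank at most 50, so the 100×100 system is singular; that may well be why the residual check fails, but I have not verified that this is the whole story:
# Pick an arbitrary example cell and look at rank and conditioning
i, j = 0, 0
mask = ~torch.isnan(input_tensor[:, i, j])
M_masked = A[mask]                      # (<=50) x 100
ata = M_masked.T @ M_masked             # 100 x 100, rank <= number of kept rows
print(torch.linalg.matrix_rank(ata))    # expected to be well below 100
print(torch.linalg.cond(ata))           # huge / effectively infinite for a singular matrix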