
Parallelize & accelerate loops of tensor additions


Background:

I am working on a program that first shifts the channels of a tensor along the "column" (width) dimension by different distances, and then sums along the "channel" dimension to merge the channels into one. Specifically, given a tensor x of size (B,C,H,W) and step size S, where B, C, H, W denote the batch size, number of channels, height, and width, respectively, the i-th channel of x (counting from 1) is shifted by distance (i-1)*S, and then the C channels are summed into one.
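If the channels are indexed from 0 instead (so channel i is shifted by i*S), the shift-and-sum operation can be written as the formula below. This is a restatement of the description above, treating x as zero wherever the shifted width index falls outside the original image:

$$
y_{b,0,h,w} = \sum_{i=0}^{C-1} x_{b,\,i,\,h,\,w - iS}, \qquad 0 \le w < W + (C-1)S,
\qquad x_{b,i,h,w'} = 0 \text{ for } w' < 0 \text{ or } w' \ge W .
$$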

Here is a 1D toy example. Assume that I have a 3-channel tensor x:

x = torch.tensor(
[[1,1,1],
[2,2,2],
[3,3,3]]
)

Now I set the step size to 1 and shift the tensor:

x_shifted = torch.tensor(
[[1,1,1,0,0],
[0,2,2,2,0],
[0,0,3,3,3]]
)

Here, the first channel is shifted by distance 0, the second by distance 1, and the third by distance 2. Finally, all three channels are summed and merged into one:

y = torch.tensor(
[[1,3,6,5,3]]
)
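For reference, the toy example can be reproduced with a few lines of PyTorch. This is a minimal sketch for the 1D case only: each channel is padded with zeros on the left (its shift) and on the right (so all rows share the output width), and the rows are then summed.

```python
import torch
import torch.nn.functional as F

# Toy input from the example above: 3 channels, width 3, step size S = 1
x = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3]])
S = 1
C, W = x.shape
out_W = W + (C - 1) * S

# Shift channel i right by i*S: pad i*S zeros on the left and
# (C-1-i)*S zeros on the right, so every row has width out_W
x_shifted = torch.stack([
    F.pad(x[i], (i * S, (C - 1 - i) * S)) for i in range(C)
])

# Sum over the channel dimension, keeping it as a size-1 dimension
y = x_shifted.sum(dim=0, keepdim=True)

print(x_shifted)
print(y)  # tensor([[1, 3, 6, 5, 3]])
```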

Question:

I have implemented the process for 2D image tensors in the following code:

import torch
import torch.nn.functional as F
from time import time

#############################################
# Parameters
#############################################

B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')

seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

#############################################
# Method 1
#############################################

alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1

def A(x, mask):
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y

def A_pinv(y, mask):
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x

#############################################
# Method 2
#############################################

kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1

def A_fast(x, mask):
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))

def A_pinv_fast(y, mask):
    return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask

#############################################
# Test 1
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
    x_init = A_pinv(y, mask)
    y_init = A(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)

#############################################
# Test 2
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
    x_init = A_pinv_fast(y, mask)
    y_init = A_fast(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)

Here, Method 1 implements the process with a for loop, while I believe that Method 2 implements the process equivalently by using a 2D convolution operation.
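The equivalence of the two methods can be checked on small CPU tensors. In the sketch below the sizes are arbitrary and chosen only so the check runs quickly, and float64 is used so rounding differences are negligible. The key point is that F.conv2d computes a cross-correlation, so placing a 1 for channel C-1-i at tap i*S, combined with (C-1)*S zeros of padding on the left, shifts channel c to the right by c*S.

```python
import torch
import torch.nn.functional as F

# Small, hypothetical sizes so the check runs quickly on CPU
B, C, H, W, S = 2, 4, 3, 5, 2
out_W = W + (C - 1) * S
torch.manual_seed(0)
z = torch.rand(B, C, H, W, dtype=torch.float64)

# Reference: shift-and-add loop, as in Method 1 (mask already applied)
y_loop = torch.zeros(B, 1, H, out_W, dtype=torch.float64)
for i in range(C):
    y_loop[..., i * S:i * S + W] += z[:, i:i + 1]

# Method 2 kernel: channel C-1-i gets a 1 at tap i*S
kernel = torch.zeros(1, C, 1, (C - 1) * S + 1, dtype=torch.float64)
for i in range(C):
    kernel[:, C - i - 1, :, i * S] = 1
y_conv = F.conv2d(z, kernel, padding=(0, (C - 1) * S))

print(y_conv.shape)                    # torch.Size([2, 1, 3, 11])
print(torch.allclose(y_loop, y_conv))  # True
```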

To be more specific, functions A and A_pinv implement the forward compression process and its "pseudo-inverse", respectively. Their "fast" counterparts in Method 2 are expected to be faster because they are parallelized.
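It may help to spell out in what sense A_pinv is a "pseudo-inverse" here: dividing y by alpha spreads each output column evenly over the channels that overlap it, so A(A_pinv(y)) reproduces y exactly (a right inverse), while A_pinv(A(x)) does not recover x in general. The sketch below re-creates A and A_pinv on CPU with small, hypothetical sizes and a mask kept away from zero; it is meant only to illustrate that property.

```python
import torch

# Hypothetical small sizes; same shift-and-add operator as Method 1
B, C, H, W, S = 2, 4, 3, 5, 2
out_W = W + (C - 1) * S
torch.manual_seed(0)

# alpha[w] counts how many shifted channels overlap output column w
alpha = torch.zeros(1, 1, 1, out_W, dtype=torch.float64)
for i in range(C):
    alpha[..., i * S:i * S + W] += 1

def A(x, mask):
    z = x * mask
    y = torch.zeros(x.shape[0], 1, H, out_W, dtype=x.dtype)
    for i in range(C):
        y[..., i * S:i * S + W] += z[:, i:i + 1]
    return y

def A_pinv(y, mask):
    # Spread each column evenly over its contributing channels,
    # undo the shifts by slicing, then undo the mask by division
    z = y / alpha
    return torch.cat([z[..., i * S:i * S + W] for i in range(C)], dim=1) / mask

x = torch.rand(B, C, H, W, dtype=torch.float64)
mask = torch.rand(1, 1, H, W, dtype=torch.float64) + 0.1  # keep away from 0
y = A(x, mask)
x_init = A_pinv(y, mask)

# A(A_pinv(y)) reproduces y, so the MAE in the tests should be ~0
print(torch.allclose(A(x_init, mask), y))  # True
# ...but A_pinv(A(x)) does not recover x in general
print(torch.allclose(x_init, x))           # False
```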

However, when I run the code, I find that Method 1 is still much faster than Method 2. My question is:

Can Method 1 be accelerated effectively? More specifically, can the for loop be parallelized to make the "shift + summation" process faster?


Solution

  • Large-kernel convolutions are not necessarily efficient. torch.Tensor.scatter_add_ can sum over the shifted elements directly.

    I didn't write the pseudo-inverse (I assume it was there to check correctness; instead, I compared this new method with your Methods 1 and 2).

    out_W = W + (C-1)*S
    i_list = torch.arange(C, dtype=torch.long, device=device)
    y_list = torch.arange(H, dtype=torch.long, device=device)
    x_list = torch.arange(W, dtype=torch.long, device=device)
    indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
    indices = indices.view(1, C*H*W).expand(B, C*H*W)
    """
    functionally equivalent to:
    for i in range(C):
        for y in range(H):
            for x in range(W):
                indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
    """
    
    def A_faster(x, mask):
        y = torch.zeros(B, H*out_W, device=x.device)
        y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
        return y.view(B, 1, H, out_W)
    

    Surprisingly, your Method 1 holds up well even for larger C (or, put differently, scatter does not scale well).

    For C=28:

    ---   
    Test 1                          
    Running Time: 1.4626126289367676  
    ---                                                                                                      
    Test 2                   
    Running Time: 2.808514356613159                                                                          
    ---                                                                                                      
    Test 3                                                                                                   
    Running Time: 1.3663663864135742                
    ---                                                                                                      
    |Test1 - Test2|:  tensor(9.2172e-07, device='cuda:0')
    ---   
    |Test1 - Test3|:  tensor(7.5425e-09, device='cuda:0')
    ---                               
    |Test2 - Test3|:  tensor(9.2173e-07, device='cuda:0')
    

    For C=512 (Method 2 skipped because it is too slow):

    ---
    Test 1
    Running Time: 27.37247085571289
    ---
    Test 3
    Running Time: 24.335933446884155
    ---
    |Test1 - Test3|:  tensor(3.9411e-08, device='cuda:0')
    

    Full testing code:

    import torch
    import torch.nn.functional as F
    from time import time
    
    #############################################
    # Parameters
    #############################################
    
    B = 16
    C = 28
    H = 256
    W = 256
    S = 2
    T = 1000
    device = torch.device('cuda')
    
    seed = 2023
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    
    #############################################
    # Method 1
    #############################################
    
    alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
    for i in range(C):
        alpha[..., (i*S):(i*S+W)] += 1
    
    def A(x, mask):
        z = x * mask
        y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
        for i in range(C):
            y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
        return y
    
    def A_pinv(y, mask):
        z = y / alpha.to(y.device)
        x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
        return x
    
    #############################################
    # Method 2
    #############################################
    
    kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
    for i in range(C):
        kernel[:, C-i-1, :, i*S] = 1
    
    def A_fast(x, mask):
        return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))
    
    def A_pinv_fast(y, mask):
        return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
    
    
    #############################################
    # Method 3
    #############################################
    out_W = W + (C-1)*S
    i_list = torch.arange(C, dtype=torch.long, device=device)
    y_list = torch.arange(H, dtype=torch.long, device=device)
    x_list = torch.arange(W, dtype=torch.long, device=device)
    indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
    indices = indices.view(1, C*H*W).expand(B, C*H*W)
    """
    functionally equivalent to:
    for i in range(C):
        for y in range(H):
            for x in range(W):
                indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
    """
    
    def A_faster(x, mask):
        y = torch.zeros(B, H*out_W, device=x.device)
        y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
        return y.view(B, 1, H, out_W)
    
    
    #############################################
    # Test 1
    #############################################
    torch.cuda.synchronize()
    start_time = time()
    for i in range(T):
        x = torch.rand(B, C, H, W, device=device)
        mask = torch.rand(1, 1, H, W, device=device)
        mask[mask == 0] = 1e-12
        y = A(x, mask)
    torch.cuda.synchronize()
    end_time = time()
    print('---')
    print('Test 1')
    print('Running Time:', end_time - start_time)
    
    #############################################
    # Test 2
    #############################################
    torch.cuda.synchronize()
    start_time = time()
    for i in range(T):
        x = torch.rand(B, C, H, W, device=device)
        mask = torch.rand(1, 1, H, W, device=device)
        mask[mask == 0] = 1e-12
        y = A_fast(x, mask)
    torch.cuda.synchronize()
    end_time = time()
    print('---')
    print('Test 2')
    print('Running Time:', end_time - start_time)
    
    #############################################
    # Test 3
    #############################################
    torch.cuda.synchronize()
    start_time = time()
    for i in range(T):
        x = torch.rand(B, C, H, W, device=device)
        mask = torch.rand(1, 1, H, W, device=device)
        mask[mask == 0] = 1e-12
        y = A_faster(x, mask)
    torch.cuda.synchronize()
    end_time = time()
    print('---')
    print('Test 3')
    print('Running Time:', end_time - start_time)
    
    
    error = 0
    for _ in range(T):
        error += (A(x, mask) - A_fast(x, mask)).abs().mean()
    error /= T
    print('---')
    print('|Test1 - Test2|: ', error)
    
    error = 0
    for _ in range(T):
        error += (A(x, mask) - A_faster(x, mask)).abs().mean()
    error /= T
    print('---')
    print('|Test1 - Test3|: ', error)
    
    error = 0
    for _ in range(T):
        error += (A_fast(x, mask) - A_faster(x, mask)).abs().mean()
    error /= T
    print('---')
    print('|Test2 - Test3|: ', error)
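The scatter-based approach can also be packaged as a self-contained function and checked on the 1D toy example from the question (viewed as B=1, C=3, H=1, W=3). This is a sketch: the name shift_sum_scatter is hypothetical, and unlike A_faster it derives B, C, H, W from the input and builds the index tensor on every call instead of relying on globals. Per the PyTorch documentation, scatter_add_ may behave nondeterministically on CUDA devices, so the last bits of floating-point results can vary between runs.

```python
import torch

def shift_sum_scatter(x, S):
    """Shift channel i of x (B, C, H, W) right by i*S and sum over channels
    with a single scatter_add_ call; returns (B, 1, H, W + (C-1)*S)."""
    B, C, H, W = x.shape
    out_W = W + (C - 1) * S
    i = torch.arange(C, device=x.device).view(C, 1, 1)
    h = torch.arange(H, device=x.device).view(1, H, 1)
    w = torch.arange(W, device=x.device).view(1, 1, W)
    # Flat destination index of element (i, h, w) in each (H, out_W) image
    indices = (w + i * S + h * out_W).reshape(1, C * H * W).expand(B, -1)
    y = torch.zeros(B, H * out_W, dtype=x.dtype, device=x.device)
    y.scatter_add_(1, indices, x.reshape(B, C * H * W))
    return y.view(B, 1, H, out_W)

# Toy example from the question, viewed as B=1, C=3, H=1, W=3
x = torch.tensor([[1., 1., 1.],
                  [2., 2., 2.],
                  [3., 3., 3.]]).view(1, 3, 1, 3)
print(shift_sum_scatter(x, S=1))  # tensor([[[[1., 3., 6., 5., 3.]]]])
```

In a benchmark loop it is preferable to build the index tensor once and reuse it, as A_faster does, since its shape depends only on B, C, H, W and S.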