Background:
I am working on a program that first shifts the different channels of a tensor along the "column" dimension by different distances, and then sums along the "channel" dimension to merge the channels into one. Specifically, given a tensor x of size (B, C, H, W) and a step size S, where B, C, H, W denote the batch size, number of channels, height, and width, respectively, the i-th channel of x is shifted by a distance of (i-1)*S, and then the C channels are summed into one.
Here is a 1D toy example. Assume that I have a 3-channel tensor x:
x = torch.tensor(
[[1,1,1],
[2,2,2],
[3,3,3]]
)
Now I set the step size S to 1 and perform the shift on the tensor, which gives
x_shifted = torch.tensor(
[[1,1,1,0,0],
[0,2,2,2,0],
[0,0,3,3,3]]
)
Here, the first channel is shifted by distance 0, the second channel by distance 1, and the third channel by distance 2. Finally, all three channels are summed and merged into one:
y = torch.tensor(
[[1,3,6,5,3]]
)
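In code, this toy example can be reproduced by zero-padding each shifted channel to the common length W + (C-1)*S and summing over the channels (a minimal sketch of the toy case only, not the benchmark below):
import torch
import torch.nn.functional as F

x = torch.tensor([[1, 1, 1],
                  [2, 2, 2],
                  [3, 3, 3]])
S = 1
C, W = x.shape
# Channel i gets i*S zeros on the left and (C-1-i)*S zeros on the right.
x_shifted = torch.stack([F.pad(x[i], (i * S, (C - 1 - i) * S)) for i in range(C)])
y = x_shifted.sum(dim=0, keepdim=True)
print(y)  # tensor([[1, 3, 6, 5, 3]])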
Question:
I have implemented this process for 2D image tensors in the following code:
import torch
import torch.nn.functional as F
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1

def A(x, mask):
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y

def A_pinv(y, mask):
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x
#############################################
# Method 2
#############################################
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1

def A_fast(x, mask):
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))

def A_pinv_fast(y, mask):
    return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
#############################################
# Test 1
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
    x_init = A_pinv(y, mask)
    y_init = A(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
#############################################
# Test 2
#############################################
start_time = time()
MAE = 0
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
    x_init = A_pinv_fast(y, mask)
    y_init = A_fast(x_init, mask)
    MAE += (y_init - y).abs().mean().item()
MAE /= T
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
print('MAE:', MAE)
Here, Method 1 implements the process with a for loop, while I believe that Method 2 implements it equivalently using a 2D convolution. To be more specific, the functions A and A_pinv realize the forward compression process and its "pseudo-inverse", respectively. Their "fast" counterparts in Method 2 are expected to be faster thanks to a parallelized implementation.
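For reference, the agreement between the two operators can be quickly checked on a random input (a minimal sanity check reusing the definitions and parameters above):
# Quick sanity check: the loop-based and convolution-based operators should
# agree up to floating-point error.
x = torch.rand(B, C, H, W, device=device)
mask = torch.rand(1, 1, H, W, device=device)
mask[mask == 0] = 1e-12
print(torch.allclose(A(x, mask), A_fast(x, mask), atol=1e-5))  # expected to print True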
However, when I run the code, I find that Method 1 is still much faster than Method 2 by a large margin. My question is: can Method 1 be accelerated effectively? More specifically, can the for loops be parallelized to make the "Shift+Summation" process faster?
Large-kernel convolutions are not necessarily efficient. torch.scatter_add_ can sum the shifted elements directly.
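To illustrate the idea on the 1D toy example from the question (a standalone minimal sketch, independent of the benchmark parameters):
# Each element of channel i is scattered to its position shifted by i*s,
# and overlapping entries are accumulated by scatter_add_.
x_toy = torch.tensor([[1., 1., 1.],
                      [2., 2., 2.],
                      [3., 3., 3.]])                     # (c, w) = (3, 3)
c, w, s = 3, 3, 1
idx = torch.arange(w) + torch.arange(c).view(c, 1) * s   # target position of each element
y_toy = torch.zeros(w + (c - 1) * s)
y_toy.scatter_add_(0, idx.view(-1), x_toy.view(-1))
print(y_toy)  # tensor([1., 3., 6., 5., 3.])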
I didn't write the pseudo-inverse (I assume it was only there to check correctness); I compared this new method against your Method 1 / Method 2. A possible gather-based counterpart is sketched after the code below.
out_W = W + (C-1)*S
i_list = torch.arange(C, dtype=torch.long, device=device)
y_list = torch.arange(H, dtype=torch.long, device=device)
x_list = torch.arange(W, dtype=torch.long, device=device)
indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
indices = indices.view(1, C*H*W).expand(B, C*H*W)
"""
functionally equivalent to:
for i in range(C):
for y in range(H):
for x in range(W):
indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
"""
def A_faster(x, mask):
y = torch.zeros(B, H*out_W, device=x.device)
y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
return y.view(B, 1, H, out_W)
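For completeness, the corresponding pseudo-inverse could presumably be written with torch.gather over the same indices. This is an untested sketch (A_pinv_faster is a name introduced here for illustration), assuming the alpha from Method 1 and the indices and out_W defined above:
def A_pinv_faster(y, mask):
    # Normalize by the overlap count alpha, gather each shifted window back
    # into its own channel, then undo the mask (mirrors A_pinv in Method 1).
    z = (y / alpha.to(y.device)).view(B, H * out_W)
    x = torch.gather(z, 1, indices).view(B, C, H, W)
    return x / mask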
Surprisingly, your Method 1 holds up well even for larger C (or perhaps scatter simply does not scale well). For C = 28:
---
Test 1
Running Time: 1.4626126289367676
---
Test 2
Running Time: 2.808514356613159
---
Test 3
Running Time: 1.3663663864135742
---
|Test1 - Test2|: tensor(9.2172e-07, device='cuda:0')
---
|Test1 - Test3|: tensor(7.5425e-09, device='cuda:0')
---
|Test2 - Test3|: tensor(9.2173e-07, device='cuda:0')
For C = 512 (Method 2 skipped as it is too slow):
---
Test 1
Running Time: 27.37247085571289
---
Test 3
Running Time: 24.335933446884155
---
|Test1 - Test3|: tensor(3.9411e-08, device='cuda:0')
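As a side note on timing: CUDA kernels launch asynchronously, so the full script below wraps each test in torch.cuda.synchronize(). Equivalently, torch.cuda.Event timers can be used; a minimal sketch, assuming x, mask, T, and the operators above are already defined:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
for _ in range(T):
    y = A_faster(x, mask)
end.record()
torch.cuda.synchronize()  # wait for all queued kernels to finish
print('Elapsed (ms):', start.elapsed_time(end))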
Full testing code:
import torch
import torch.nn.functional as F
from time import time
#############################################
# Parameters
#############################################
B = 16
C = 28
H = 256
W = 256
S = 2
T = 1000
device = torch.device('cuda')
seed = 2023
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
#############################################
# Method 1
#############################################
alpha = torch.zeros(B, 1, 1, W+(C-1)*S, device=device)
for i in range(C):
    alpha[..., (i*S):(i*S+W)] += 1

def A(x, mask):
    z = x * mask
    y = torch.zeros(B, 1, H, W+(C-1)*S, device=x.device)
    for i in range(C):
        y[..., (i*S):(i*S+W)] += z[:, (i):(i+1)]
    return y

def A_pinv(y, mask):
    z = y / alpha.to(y.device)
    x = torch.cat([z[..., (i*S):(i*S+W)] for i in range(C)], dim=1) / mask
    return x
#############################################
# Method 2
#############################################
kernel = torch.zeros(1, C, 1, (C-1)*S+1, device=device)
for i in range(C):
    kernel[:, C-i-1, :, i*S] = 1

def A_fast(x, mask):
    return F.conv2d(x * mask, kernel.to(x.device), padding=(0, (C-1)*S))

def A_pinv_fast(y, mask):
    return F.conv_transpose2d(y / alpha.to(y.device), kernel, padding=(0, (C-1)*S)) / mask
#############################################
# Method 3
#############################################
out_W = W + (C-1)*S
i_list = torch.arange(C, dtype=torch.long, device=device)
y_list = torch.arange(H, dtype=torch.long, device=device)
x_list = torch.arange(W, dtype=torch.long, device=device)
indices = x_list + i_list.view(C, 1, 1)*S + y_list.view(1, H, 1)*(out_W)
indices = indices.view(1, C*H*W).expand(B, C*H*W)
"""
functionally equivalent to:
for i in range(C):
for y in range(H):
for x in range(W):
indices[i*H*W+y*W+x] = x + i*S + y*(out_W)
"""
def A_faster(x, mask):
y = torch.zeros(B, H*out_W, device=x.device)
y.scatter_add_(1, indices, (x*mask).view(B, C*H*W))
return y.view(B, 1, H, out_W)
#############################################
# Test 1
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 1')
print('Running Time:', end_time - start_time)
#############################################
# Test 2
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_fast(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 2')
print('Running Time:', end_time - start_time)
#############################################
# Test 3
#############################################
torch.cuda.synchronize()
start_time = time()
for i in range(T):
    x = torch.rand(B, C, H, W, device=device)
    mask = torch.rand(1, 1, H, W, device=device)
    mask[mask == 0] = 1e-12
    y = A_faster(x, mask)
torch.cuda.synchronize()
end_time = time()
print('---')
print('Test 3')
print('Running Time:', end_time - start_time)
error = 0
for _ in range(T):
    error += (A(x, mask) - A_fast(x, mask)).abs().mean()
error /= T
print('---')
print('|Test1 - Test2|: ', error)
error = 0
for _ in range(T):
    error += (A(x, mask) - A_faster(x, mask)).abs().mean()
error /= T
print('---')
print('|Test1 - Test3|: ', error)
error = 0
for _ in range(T):
    error += (A_fast(x, mask) - A_faster(x, mask)).abs().mean()
error /= T
print('---')
print('|Test2 - Test3|: ', error)
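A closely related variant (an untested sketch, not part of the benchmarks above; A_faster_index_add and indices_1d are names introduced here for illustration): since the scatter index is identical for every batch element, torch.Tensor.index_add_ along the flattened H*out_W dimension avoids materializing the (B, C*H*W) index tensor.
indices_1d = indices[0]  # (C*H*W,) -- the same index row applies to every batch element

def A_faster_index_add(x, mask):
    # index_add_ accumulates columns of the source into the target columns
    # selected by a 1-D index, which matches the shift+summation pattern.
    y = torch.zeros(B, H * out_W, device=x.device)
    y.index_add_(1, indices_1d, (x * mask).view(B, C * H * W))
    return y.view(B, 1, H, out_W)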