I have created code that blends tensors together, by first blending them together into rows and then blending the rows together into the final output. It works well for 4 tensors in a 2x2 pattern, but it fails to do a 2x3 (6 tensors), 3x3 (9 tensors), 4x4 (16 tensors) pattern.
The tensors are in the form of (B x C x H x W), where B is batch size, C is channels, H is height, and W is width.
For both the tiles to rows (tile_overlay()
), and rows to final image (row_overlay()
), I create a base tensor that I add the tiles/rows to. I suspect the issue with my code lies either with how I get the base tensor's dimensions, how I track where to put the rows/tiles on the base tensor, or the issue with both of those things.
import torch
from PIL import Image
import torchvision.transforms as transforms
def preprocess(image_name, image_size):
image = Image.open(image_name).convert('RGB')
if type(image_size) is not tuple:
image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
tensor = (Loader(image) * 256).unsqueeze(0)
return tensor
def deprocess(output_tensor):
output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
return image
def prepare_tile(tile, overlap, side='both'):
lin_mask_left = torch.linspace(0,1,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
lin_mask_right = torch.linspace(1,0,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'right':
tile[:,:,:,overlap:] = tile[:,:,:,overlap:] * lin_mask_right
if side == 'both' or side == 'left':
tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left
return tile
def overlay_tiles(tile_list, rows, overlap):
c = 1
f_tiles = []
base_length = 0
for i, tile in enumerate(tile_list):
if c == 1:
f_tile = prepare_tile(tile.clone(), overlap, side='right')
if i + 1<= rows[1]:
base_length += tile.clone().size(3) - overlap
elif c == rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='left')
if i + 1<= rows[1]:
base_length += tile.size(3) - overlap
elif c > 0 and c < rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='both')
if i + 1<= rows[1]:
base_length += tile.size(3) - (overlap*2)
f_tiles.append(f_tile)
if c == rows[1]:
c = 0
c+=1
base_length += overlap
base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
row_list = []
for row in range(rows[1]):
row_list.append(base_tensor.clone())
row_val, num_tiles = 0, 0
l_max = tile_list[0].size(3)
for y in range(rows[0]):
for x in range(rows[1]):
if num_tiles % rows[1] != 0:
l_max += (f_tiles[num_tiles].size(3)-overlap)*x
l_min = l_max - f_tiles[num_tiles].size(3)
row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
else:
row_list[row_val][:, :, :, :f_tiles[num_tiles].size(3)] = f_tiles[num_tiles]
l_max = tile_list[0].size(3)
num_tiles+=1
row_val+=1
return row_list
def prepare_row(row_tensor, overlap, side='both'):
lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'top':
row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:] * lin_mask_top
if side == 'both' or side == 'bottom':
row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom
return row_tensor
def overlay_rows(row_list, rows, overlap):
c = 1
f_rows = []
base_height = 0
for i, row_tensor in enumerate(row_list):
if c == 1:
f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
if i + 1<= rows[0]:
base_height += row_tensor.size(2) - overlap
elif c == rows[1]:
f_row = prepare_row(row_tensor.clone(), overlap, side='top')
if i + 1<= rows[0]:
base_height += row_tensor.size(2) - overlap
elif c > 0 and c < rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='both')
if i + 1<= rows[0]:
base_height += tile.size(2) - (overlap*2)
f_rows.append(f_row)
if c == rows[0]:
c = 0
c+=1
base_height += overlap
base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)
num_rows = 0
l_max = row_list[0].size(3)
for y in range(rows[0]):
if num_rows > 0:
l_max += (f_rows[num_rows].size(2)-overlap)*y
l_min = l_max - f_rows[num_rows].size(2)
base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
else:
base_tensor[:, :, :f_rows[num_rows].size(2), :] = f_rows[num_rows]
l_max = row_list[0].size(2)
num_rows+=1
return base_tensor
def rebuild_image(tensor_list, rows, overlap_hw):
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
return full_tensor
test_tensor_1 = preprocess('brad_pitt.jpg', (1080,1080))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,1080))
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')
Running the above code will result in this error message when trying to create the 3x3 output:
Traceback (most recent call last):
File "t0.py", line 148, in <module>
complete_tensor = rebuild_image(tensor_list, rows, overlap)
File "t0.py", line 126, in rebuild_image
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
File "t0.py", line 68, in overlay_tiles
row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
RuntimeError: The size of tensor a (0) must match the size of tensor b (1080) at non-singleton dimension 3
This is an example of what the 2x2 output looks like:
And this is a visual diagram with two examples of what I am doing:
The code now works with some changes:
import torch
from PIL import Image
import torchvision.transforms as transforms
def preprocess(image_name, image_size):
image = Image.open(image_name).convert('RGB')
if type(image_size) is not tuple:
image_size = tuple([int((float(image_size) / max(image.size))*x) for x in (image.height, image.width)])
Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
tensor = (Loader(image) * 256).unsqueeze(0)
return tensor
def deprocess(output_tensor):
output_tensor = output_tensor.clone().squeeze(0).cpu() / 256
output_tensor.clamp_(0, 1)
Image2PIL = transforms.ToPILImage()
image = Image2PIL(output_tensor.cpu())
return image
def prepare_tile(tile, overlap, side='both'):
h, w = tile.size(2), tile.size(3)
lin_mask_left = torch.linspace(0,1,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
lin_mask_right = torch.linspace(1,0,overlap).repeat(h,1).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'right':
tile[:,:,:,w-overlap:] = tile[:,:,:,w-overlap:] * lin_mask_right
if side == 'both' or side == 'left':
tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left
return tile
def calc_length(w, overlap, rows):
count = 0
l_max = w
for y in range(rows[0]):
for x in range(rows[1]):
if count % rows[1] != 0:
l_max += w-overlap
l_min = l_max - w
else:
l_max = w
count+=1
return l_max
def overlay_tiles(tile_list, rows, overlap):
c = 1
f_tiles = []
base_length = 0
for i, tile in enumerate(tile_list):
if c == 1:
f_tile = prepare_tile(tile.clone(), overlap, side='right')
elif c == rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='left')
elif c > 0 and c < rows[1]:
f_tile = prepare_tile(tile.clone(), overlap, side='both')
f_tiles.append(f_tile)
if c == rows[1]:
c = 0
c+=1
w = tile_list[0].size(3)
base_length = calc_length(w, overlap, rows)
base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
row_list = []
for row in range(rows[0]):
row_list.append(base_tensor.clone())
row_num, num_tiles = 0, 0
l_max = w
for y in range(rows[0]):
for x in range(rows[1]):
if num_tiles % rows[1] != 0:
l_max += w-overlap
l_min = l_max - w
print(num_tiles, l_max, l_min)
row_list[row_num][:, :, :, l_min:l_max] = row_list[row_num][:, :, :, l_min:l_max] + f_tiles[num_tiles]
else:
row_list[row_num][:, :, :, :w] = f_tiles[num_tiles]
l_max = w
num_tiles+=1
row_num+=1
return row_list
def prepare_row(row_tensor, overlap, side='both'):
lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
if side == 'both' or side == 'top':
row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:] * lin_mask_top
if side == 'both' or side == 'bottom':
row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom
return row_tensor
def calc_height(h, overlap, rows):
num_rows = 0
l_max = h
for y in range(rows[0]):
if num_rows > 0:
l_max += (h-overlap)
l_min = l_max - h
else:
l_max = h
num_rows+=1
return l_max
def overlay_rows(row_list, rows, overlap):
c = 1
f_rows = []
base_height = 0
for i, row_tensor in enumerate(row_list):
if c == 1:
f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
elif c == rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='top')
elif c > 0 and c < rows[0]:
f_row = prepare_row(row_tensor.clone(), overlap, side='both')
f_rows.append(f_row)
if c == rows[0]:
c = 0
c+=1
h = row_list[0].size(2)
base_height = calc_height(h, overlap, rows)
base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)
num_rows = 0
l_max = row_list[0].size(3)
for y in range(rows[0]):
if num_rows > 0:
l_max += (h-overlap)
l_min = l_max - h
base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
else:
base_tensor[:, :, :h, :] = f_rows[num_rows]
l_max = h
num_rows+=1
return base_tensor
def rebuild_image(tensor_list, rows, overlap_hw):
row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
full_tensor = overlay_rows(row_tensors, rows, overlap_hw[0])
return full_tensor
test_tensor_1 = preprocess('brad_pitt.jpg', (1080,720))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,720))
print("2x2 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')
print("3x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')
print("3x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [3, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x4.png')
print("4x3 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone()]
rows = [4, 3]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x3.png')
print("4x4 Test")
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 260]
complete_tensor = rebuild_image(tensor_list, rows, overlap)
ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')