Search code examples
python image-processing pytorch mask tensor

How do I blend together tensors in patterns other than 2x2?


I have created code that blends tensors together by first blending them into rows and then blending the rows into the final output. It works well for 4 tensors in a 2x2 pattern, but it fails for other patterns such as 2x3 (6 tensors), 3x3 (9 tensors), or 4x4 (16 tensors).

The tensors are in the form of (B x C x H x W), where B is batch size, C is channels, H is height, and W is width.

For both steps, tiles to rows (overlay_tiles()) and rows to final image (overlay_rows()), I create a base tensor that I add the tiles/rows to. I suspect the issue with my code lies in how I get the base tensor's dimensions, in how I track where to put the rows/tiles on the base tensor, or in both.

import torch
from PIL import Image
import torchvision.transforms as transforms


def preprocess(image_name, image_size):
    """Load an image file as a float tensor scaled by 256, with a batch axis.

    Args:
        image_name: Path (or file object) handed to ``Image.open``.
        image_size: Either an explicit ``(height, width)`` tuple/list, or a
            single number giving the target length of the longest side, in
            which case the aspect ratio is preserved.

    Returns:
        The ``ToTensor`` output multiplied by 256, with a leading dimension
        of size 1 added.
    """
    # Open the file in a ``with`` block so it is closed as soon as the pixel
    # data has been converted; ``convert`` returns an independent image.
    with Image.open(image_name) as source:
        image = source.convert('RGB')
    # isinstance() rather than an exact type() comparison, so lists and tuple
    # subclasses are also accepted as an explicit size.
    if not isinstance(image_size, (tuple, list)):
        scale = float(image_size) / max(image.size)
        image_size = tuple(int(scale * x) for x in (image.height, image.width))
    loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    return (loader(image) * 256).unsqueeze(0)

def deprocess(output_tensor):
    """Turn a tensor scaled by 256 back into a PIL image.

    The input is copied, its leading dimension is squeezed if it has size 1,
    and the copy is moved to the CPU, divided by 256 and clamped to [0, 1]
    before the conversion, so the caller's tensor is left untouched.
    """
    scaled = output_tensor.clone().squeeze(0).cpu() / 256
    scaled.clamp_(0, 1)
    to_pil = transforms.ToPILImage()
    return to_pil(scaled.cpu())




def prepare_tile(tile, overlap, side='both'):
    """Multiply the left and/or right part of `tile` by a linear ramp, in place.

    `side` selects which ramps are applied: 'left', 'right' or 'both'; any
    other value leaves the tile unchanged. The (modified) tile is returned.
    """
    # Each mask is a ramp of `overlap` values (0..1 or 1..0), stacked
    # tile.size(3) times, copied to 3 channels and given a leading axis,
    # i.e. shape (1, 3, tile.size(3), overlap).
    # NOTE(review): the ramp is stacked tile.size(3) (width) times although
    # that axis is multiplied against the tile's height axis; the two only
    # agree for square tiles -- confirm whether tile.size(2) was intended.
    lin_mask_left = torch.linspace(0,1,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
    lin_mask_right = torch.linspace(1,0,overlap).repeat(tile.size(3),1).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'right':
        # NOTE(review): `overlap:` selects every column from index `overlap`
        # onward, which is `overlap` columns wide only when the tile width is
        # 2 * overlap -- presumably the last `overlap` columns were meant.
        tile[:,:,:,overlap:] = tile[:,:,:,overlap:] * lin_mask_right
    if side == 'both' or side == 'left':
        # Fade in the first `overlap` columns.
        tile[:,:,:,:overlap] = tile[:,:,:,:overlap] * lin_mask_left 
    return tile

def overlay_tiles(tile_list, rows, overlap):        
    """Fade the tiles and add them side by side into one tensor per row.

    `rows` is read as rows[0] (outer placement loop count) and rows[1]
    (tiles per row); `overlap` is the column count passed to prepare_tile().
    Returns the list of row tensors.
    """
    # c is the 1-based position of the current tile within its row; it wraps
    # back to 1 after reaching rows[1].
    c = 1
    f_tiles = []
    base_length = 0
    for i, tile in enumerate(tile_list):
        # The width bookkeeping in each branch only runs while
        # i + 1 <= rows[1], i.e. for the tiles of the first row.
        if c == 1:    
             # First tile of a row: fade its right side only.
             f_tile = prepare_tile(tile.clone(), overlap, side='right')
             if i + 1<= rows[1]:
                 base_length += tile.clone().size(3) - overlap
        elif c == rows[1]:
             # Last tile of a row: fade its left side only.
             f_tile = prepare_tile(tile.clone(), overlap, side='left')
             if i + 1<= rows[1]:
                 base_length += tile.size(3) - overlap
        elif c > 0 and c < rows[1]:
             # Interior tile: fade both sides.
             f_tile = prepare_tile(tile.clone(), overlap, side='both')
             if i + 1<= rows[1]:
                 base_length += tile.size(3) - (overlap*2)
        f_tiles.append(f_tile)  
        if c == rows[1]:
             c = 0
        c+=1

    base_length += overlap           
    # Zero canvas of shape (1, 3, tile height, base_length), cloned per row.
    base_tensor = torch.zeros(3, tile_list[0].size(2), base_length).unsqueeze(0)
    row_list = []
    # NOTE(review): rows[1] canvases are created here, but the placement loop
    # below advances row_val once per rows[0] iteration -- the two counts
    # differ for non-square grids.
    for row in range(rows[1]):
        row_list.append(base_tensor.clone())

    row_val, num_tiles = 0, 0
    l_max = tile_list[0].size(3)
    for y in range(rows[0]):       
        for x in range(rows[1]):        
            if num_tiles % rows[1] != 0:
                # NOTE(review): the step is multiplied by the column index x
                # and added to an l_max that already accumulated the earlier
                # steps, so from the third column on l_min:l_max runs past
                # base_length and the slice comes back empty (this is the
                # statement shown in the reported traceback).
                l_max += (f_tiles[num_tiles].size(3)-overlap)*x
                l_min = l_max - f_tiles[num_tiles].size(3)
                row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
            else:
                # First tile of a row is copied to the left edge and the
                # running right edge is reset to one tile width.
                row_list[row_val][:, :, :, :f_tiles[num_tiles].size(3)] = f_tiles[num_tiles]  
                l_max = tile_list[0].size(3)
            num_tiles+=1 
        row_val+=1  
    return row_list


def prepare_row(row_tensor, overlap, side='both'):
    """Multiply the top and/or bottom part of `row_tensor` by a linear ramp, in place.

    `side` selects which ramps are applied: 'top', 'bottom' or 'both'; any
    other value leaves the tensor unchanged. The (modified) tensor is returned.
    """
    # The ramp is stacked row_tensor.size(3) times and rotated so that it
    # runs along the height axis, then copied to 3 channels and given a
    # leading axis: shape (1, 3, overlap, row_tensor.size(3)).
    lin_mask_top = torch.linspace(0,1,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    lin_mask_bottom = torch.linspace(1,0,overlap).repeat(row_tensor.size(3),1).rot90(3).repeat(3,1,1).unsqueeze(0)
    if side == 'both' or side == 'top':
        # Fade in the first `overlap` rows.
        row_tensor[:,:,:overlap,:] = row_tensor[:,:,:overlap,:]  * lin_mask_top
    if side == 'both' or side == 'bottom':
        # NOTE(review): `overlap:` selects every row from index `overlap`
        # onward, which is `overlap` rows tall only when the height is
        # 2 * overlap -- presumably the last `overlap` rows were meant.
        row_tensor[:,:,overlap:,:] = row_tensor[:,:,overlap:,:] * lin_mask_bottom   
    return row_tensor

def overlay_rows(row_list, rows, overlap):
    """Fade the row tensors and add them top to bottom into one tensor.

    `rows[0]` is used as the number of rows to place; `overlap` is the row
    count passed to prepare_row(). Returns the combined tensor.
    """
    # c is the 1-based position of the current row; it wraps back to 1 after
    # reaching rows[0].
    c = 1
    f_rows = []
    base_height = 0
    for i, row_tensor in enumerate(row_list):
        if c == 1:    
             # First row: fade its bottom only.
             f_row = prepare_row(row_tensor.clone(), overlap, side='bottom')
             if i + 1<= rows[0]:
                 base_height += row_tensor.size(2) - overlap
        # NOTE(review): this branch compares against rows[1] while every
        # other test in this function uses rows[0] -- confirm which is meant.
        elif c == rows[1]:
             f_row = prepare_row(row_tensor.clone(), overlap, side='top')
             if i + 1<= rows[0]:
                 base_height += row_tensor.size(2) - overlap
        elif c > 0 and c < rows[0]:
             # Interior row: fade both top and bottom.
             f_row = prepare_row(row_tensor.clone(), overlap, side='both')
             if i + 1<= rows[0]:
                 # NOTE(review): `tile` is not assigned anywhere in this
                 # function; presumably row_tensor was intended.
                 base_height += tile.size(2) - (overlap*2)
        f_rows.append(f_row)    
        if c == rows[0]:
             c = 0
        c+=1

    base_height += overlap           
    # Zero canvas of shape (1, 3, base_height, row width).
    base_tensor = torch.zeros(3, base_height, row_list[0].size(3)).unsqueeze(0)

    num_rows = 0
    # NOTE(review): initialised from size(3) (width); the value is replaced
    # by size(2) in the first loop iteration before it is read.
    l_max = row_list[0].size(3)
    for y in range(rows[0]):        
            if num_rows > 0:
                # NOTE(review): as in overlay_tiles(), the step is multiplied
                # by the loop index y and added to an already accumulated
                # l_max.
                l_max += (f_rows[num_rows].size(2)-overlap)*y
                l_min = l_max - f_rows[num_rows].size(2)
                base_tensor[:, :, l_min:l_max, :] = base_tensor[:, :, l_min:l_max, :] + f_rows[num_rows]
            else:
                # First row is copied to the top edge and the running bottom
                # edge is set to one row height.
                base_tensor[:, :, :f_rows[num_rows].size(2), :] = f_rows[num_rows]  
                l_max = row_list[0].size(2)
            num_rows+=1   
    return base_tensor


def rebuild_image(tensor_list, rows, overlap_hw):
    """Blend a list of tiles into a single tensor.

    The tiles are first merged into rows with overlay_tiles() using the
    second entry of `overlap_hw`; the rows are then merged with
    overlay_rows() using the first entry.
    """
    overlap_h, overlap_w = overlap_hw[0], overlap_hw[1]
    blended_rows = overlay_tiles(tensor_list, rows, overlap_w)
    return overlay_rows(blended_rows, rows, overlap_h)

# Load the two source images at an explicit 1080 x 1080 size.
test_tensor_1 = preprocess('brad_pitt.jpg', (1080,1080))
test_tensor_2 = preprocess('starry_night_google.jpg', (1080,1080))


# 2x2 grid: four tiles in row-major order. `overlap` is read by
# rebuild_image() as [overlap for overlay_rows, overlap for overlay_tiles].
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [2, 2]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_2x2.png')


# 3x3 grid: nine tiles. This is the run that raises the RuntimeError quoted
# after the listing.
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(),test_tensor_1.clone(),]
rows = [3, 3]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_3x3.png')



# 4x4 grid: sixteen tiles alternating between the two images.
tensor_list = [test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone(), \
test_tensor_1.clone(), test_tensor_2.clone(), test_tensor_1.clone(), test_tensor_2.clone()]
rows = [4, 4]
overlap = [540, 540]
complete_tensor = rebuild_image(tensor_list, rows, overlap)

ft = deprocess(complete_tensor.clone())
ft.save('complete_tensor_4x4.png')

Running the above code will result in this error message when trying to create the 3x3 output:

Traceback (most recent call last):
  File "t0.py", line 148, in <module>
    complete_tensor = rebuild_image(tensor_list, rows, overlap)
  File "t0.py", line 126, in rebuild_image
    row_tensors = overlay_tiles(tensor_list, rows, overlap_hw[1])
  File "t0.py", line 68, in overlay_tiles
    row_list[row_val][:, :, :, l_min:l_max] = row_list[row_val][:, :, :, l_min:l_max] + f_tiles[num_tiles]
RuntimeError: The size of tensor a (0) must match the size of tensor b (1080) at non-singleton dimension 3

This is an example of what the 2x2 output looks like:

enter image description here

And this is a visual diagram with two examples of what I am doing:


Solution

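  • The code now works with the following changes: prepare_tile() builds its masks from the tile height and fades the last `overlap` columns (`w-overlap:`) instead of everything after column `overlap`; the output width and height are computed by the new helpers calc_length() and calc_height(); overlay_tiles() creates one row tensor per grid row (rows[0]) rather than per column (rows[1]); each tile/row is offset by a fixed step (w-overlap or h-overlap) instead of a step multiplied by the loop index; and overlay_rows() compares against rows[0] throughout and no longer refers to the undefined name `tile`: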

    import torch
    from PIL import Image
    import torchvision.transforms as transforms
    
    
    def preprocess(image_name, image_size):
        """Load an image file as a float tensor scaled by 256, with a batch axis.

        Args:
            image_name: Path (or file object) handed to ``Image.open``.
            image_size: Either an explicit ``(height, width)`` tuple/list, or a
                single number giving the target length of the longest side, in
                which case the aspect ratio is preserved.

        Returns:
            The ``ToTensor`` output multiplied by 256, with a leading dimension
            of size 1 added.
        """
        # Open the file in a ``with`` block so it is closed as soon as the
        # pixel data has been converted; ``convert`` returns an independent
        # image.
        with Image.open(image_name) as source:
            image = source.convert('RGB')
        # isinstance() rather than an exact type() comparison, so lists and
        # tuple subclasses are also accepted as an explicit size.
        if not isinstance(image_size, (tuple, list)):
            scale = float(image_size) / max(image.size)
            image_size = tuple(int(scale * x) for x in (image.height, image.width))
        loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
        return (loader(image) * 256).unsqueeze(0)
    
    def deprocess(output_tensor):
        """Turn a tensor scaled by 256 back into a PIL image.

        The input is copied, its leading dimension is squeezed if it has size
        1, and the copy is moved to the CPU, divided by 256 and clamped to
        [0, 1] before the conversion, so the caller's tensor is left untouched.
        """
        scaled = output_tensor.clone().squeeze(0).cpu() / 256
        scaled.clamp_(0, 1)
        to_pil = transforms.ToPILImage()
        return to_pil(scaled.cpu())
    
    
    
    
    def prepare_tile(tile, overlap, side='both'):
        """Fade the left and/or right `overlap` columns of a tile, in place.

        Args:
            tile: Tensor indexed as (batch, channel, height, width). It is
                modified in place and also returned.
            overlap: Number of columns in each faded strip.
            side: 'left', 'right' or 'both'. Any other value leaves the tile
                untouched.

        Returns:
            The same `tile` object that was passed in.
        """
        w = tile.size(3)
        # One weight per column. A 1-D ramp broadcasts over the batch, channel
        # and height axes, so any channel count works, and it is created on
        # the tile's own device.
        fade_in = torch.linspace(0, 1, overlap, device=tile.device)
        fade_out = torch.linspace(1, 0, overlap, device=tile.device)
        if side == 'both' or side == 'right':
            # w - overlap (rather than -overlap) so that overlap == 0 selects
            # an empty strip instead of the whole tile.
            tile[:, :, :, w - overlap:] = tile[:, :, :, w - overlap:] * fade_out
        if side == 'both' or side == 'left':
            tile[:, :, :, :overlap] = tile[:, :, :, :overlap] * fade_in
        return tile
    
    
    def calc_length(w, overlap, rows):
        """Return the width of one blended row of tiles.

        Args:
            w: Width of a single tile.
            overlap: Number of columns shared by horizontally adjacent tiles.
            rows: Grid shape, read as (number of rows, number of columns).

        Returns:
            `w` for the first column plus `w - overlap` for every further
            column. A grid with zero rows or zero columns yields `w`.
        """
        n_rows, n_cols = rows[0], rows[1]
        if n_rows < 1 or n_cols < 1:
            return w
        # Every row has the same width, so no per-tile walk is needed.
        return w + (n_cols - 1) * (w - overlap)
    
    def overlay_tiles(tile_list, rows, overlap):
        """Blend a row-major list of tiles into one tensor per grid row.

        Args:
            tile_list: Tiles in row-major order, each indexed as
                (batch, channel, height, width). All tiles are taken to have
                the size of the first one.
            rows: Grid shape, read as (number of rows, number of columns).
            overlap: Number of columns shared by horizontally adjacent tiles.

        Returns:
            A list with one tensor per grid row. Each has the batch size,
            channel count, height, dtype and device of the first tile and
            the width given by calc_length().

        Raises:
            ValueError: If either grid dimension is smaller than 1.
        """
        n_rows, n_cols = rows[0], rows[1]
        if n_rows < 1 or n_cols < 1:
            raise ValueError("rows must describe at least one row and one column")

        # Fade only the edges that have a neighbour in the same row. With a
        # single column there is nothing to blend, so tiles stay unfaded.
        f_tiles = []
        for i, tile in enumerate(tile_list):
            col = i % n_cols
            faded = tile.clone()
            if n_cols > 1:
                if col == 0:
                    side = 'right'
                elif col == n_cols - 1:
                    side = 'left'
                else:
                    side = 'both'
                faded = prepare_tile(faded, overlap, side=side)
            f_tiles.append(faded)

        first = tile_list[0]
        w = first.size(3)
        base_length = calc_length(w, overlap, rows)
        # Each tile starts one (w - overlap) step to the right of the previous.
        step = w - overlap

        row_list = []
        for r in range(n_rows):
            # Canvas sized from the first tile so batch size and channel count
            # are not fixed to 1 and 3.
            row = first.new_zeros((first.size(0), first.size(1), first.size(2), base_length))
            for col in range(n_cols):
                l_min = col * step
                row[:, :, :, l_min:l_min + w] += f_tiles[r * n_cols + col]
            row_list.append(row)
        return row_list
    
    
    def prepare_row(row_tensor, overlap, side='both'):
        """Fade the top and/or bottom `overlap` rows of a tensor, in place.

        Args:
            row_tensor: Tensor indexed as (batch, channel, height, width). It
                is modified in place and also returned.
            overlap: Number of rows in each faded strip.
            side: 'top', 'bottom' or 'both'. Any other value leaves the
                tensor untouched.

        Returns:
            The same `row_tensor` object that was passed in.
        """
        h = row_tensor.size(2)
        # One weight per row, shaped (overlap, 1) so it broadcasts over the
        # batch, channel and width axes; created on the tensor's own device.
        fade_in = torch.linspace(0, 1, overlap, device=row_tensor.device).view(-1, 1)
        fade_out = torch.linspace(1, 0, overlap, device=row_tensor.device).view(-1, 1)
        if side == 'both' or side == 'top':
            row_tensor[:, :, :overlap, :] = row_tensor[:, :, :overlap, :] * fade_in
        if side == 'both' or side == 'bottom':
            # The strip is the LAST `overlap` rows. h - overlap (rather than
            # -overlap) keeps overlap == 0 selecting an empty strip.
            row_tensor[:, :, h - overlap:, :] = row_tensor[:, :, h - overlap:, :] * fade_out
        return row_tensor
    
    
    def calc_height(h, overlap, rows):
        """Return the height of the image built from all blended rows.

        Args:
            h: Height of a single row.
            overlap: Number of pixel rows shared by vertically adjacent rows.
            rows: Grid shape; only rows[0] (the number of rows) is read.

        Returns:
            `h` for the first row plus `h - overlap` for every further row.
            A grid with zero rows yields `h`.
        """
        n_rows = rows[0]
        if n_rows < 1:
            return h
        return h + (n_rows - 1) * (h - overlap)
    
    def overlay_rows(row_list, rows, overlap):
        """Blend the row tensors top to bottom into the final tensor.

        Args:
            row_list: Row tensors from top to bottom, each indexed as
                (batch, channel, height, width). All rows are taken to have
                the size of the first one.
            rows: Grid shape; only rows[0] (the number of rows) is read.
            overlap: Number of pixel rows shared by vertically adjacent rows.

        Returns:
            A tensor with the batch size, channel count, width, dtype and
            device of the first row and the height given by calc_height().

        Raises:
            ValueError: If rows[0] is smaller than 1.
        """
        n_rows = rows[0]
        if n_rows < 1:
            raise ValueError("rows must describe at least one row")

        # Fade only the edges that have a vertical neighbour. With a single
        # row there is nothing to blend, so the row stays unfaded.
        f_rows = []
        for i, row_tensor in enumerate(row_list):
            pos = i % n_rows
            faded = row_tensor.clone()
            if n_rows > 1:
                if pos == 0:
                    side = 'bottom'
                elif pos == n_rows - 1:
                    side = 'top'
                else:
                    side = 'both'
                faded = prepare_row(faded, overlap, side=side)
            f_rows.append(faded)

        first = row_list[0]
        h = first.size(2)
        base_height = calc_height(h, overlap, rows)
        # Canvas sized from the first row so batch size and channel count are
        # not fixed to 1 and 3.
        base_tensor = first.new_zeros((first.size(0), first.size(1), base_height, first.size(3)))

        # Each row starts one (h - overlap) step below the previous one.
        step = h - overlap
        for y in range(n_rows):
            top = y * step
            base_tensor[:, :, top:top + h, :] += f_rows[y]
        return base_tensor
    
    
    def rebuild_image(tensor_list, rows, overlap_hw):
        """Blend a list of tiles into a single tensor.

        The tiles are first merged into rows with overlay_tiles() using the
        second entry of `overlap_hw`; the rows are then merged with
        overlay_rows() using the first entry.
        """
        overlap_h, overlap_w = overlap_hw[0], overlap_hw[1]
        blended_rows = overlay_tiles(tensor_list, rows, overlap_w)
        return overlay_rows(blended_rows, rows, overlap_h)
    
    # Load the two source images at an explicit (1080, 720) size.
    test_tensor_1 = preprocess('brad_pitt.jpg', (1080,720))
    test_tensor_2 = preprocess('starry_night_google.jpg', (1080,720))

    # One entry per demo run:
    # (label to print, grid shape, overlap pair for rebuild_image, output file)
    demo_cases = [
        ("2x2 Test", [2, 2], [540, 260], 'complete_tensor_2x2.png'),
        ("3x3 Test", [3, 3], [540, 540], 'complete_tensor_3x3.png'),
        ("3x4 Test", [3, 4], [540, 260], 'complete_tensor_3x4.png'),
        ("4x3 Test", [4, 3], [540, 260], 'complete_tensor_4x3.png'),
        ("4x4 Test", [4, 4], [540, 260], 'complete_tensor_4x4.png'),
    ]

    sources = (test_tensor_1, test_tensor_2)
    for label, rows, overlap, out_name in demo_cases:
        print(label)
        # Every grid row alternates between the two images, starting with the
        # first one; each tile is a fresh clone.
        tensor_list = [
            sources[col % 2].clone()
            for _ in range(rows[0])
            for col in range(rows[1])
        ]
        complete_tensor = rebuild_image(tensor_list, rows, overlap)

        ft = deprocess(complete_tensor.clone())
        ft.save(out_name)