Tags: python, image-processing, graph, overlap

How to code up an image stitching software for these 'simple' images?


TLDR: Need help trying to calculate overlap region between 2 graphs.

So I'm trying to stitch these 2 images:
First image

Second image

Since I know that the images I will be stitching definitely come from the same image, I feel that I should be able to code this up myself. Using libraries like OpenCV feels a little like overkill for me for this task.

My current idea is that I can simplify this task by doing the following steps for each image:

  1. Load image using PIL
  2. Convert image to grayscale (PIL image mode “L”)
  3. [Optional: crop images to the overlapping region by eye]
  4. Create vector row_sum, which is a sum of each row
  5. [Optional: log row_sum, to reduce the size of values we're working with]
  6. Plot row_sum.

This reduces the problem from three RGB channels per pixel of a 2D image to a single grayscale value per pixel. Summing across each row then reduces it to a 1D problem.

I used the following code to implement the above:

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

class Stitcher():
    def combine_2(self, img1, img2):
        # thr1, thr2 = self.get_cropped_bw(img1, 115, img2, 80)
        thr1, thr2 = self.get_cropped_bw(img1, 0, img2, 0)
        
        row_sum1 = np.log(thr1.sum(1))
        row_sum2 = np.log(thr2.sum(1))
        
        self.plot_4x4(thr1, thr2, row_sum1, row_sum2)
    
    def get_cropped_bw(self, img1, img1_keep_from, img2, img2_keep_till):    
        im1 = Image.open(img1).convert("L")
        im2 = Image.open(img2).convert("L")
        
        data1 = (np.array(im1)[img1_keep_from:] 
                if img1_keep_from != 0 else np.array(im1))
        data2 = (np.array(im2)[:img2_keep_till] 
                if img2_keep_till != 0 else np.array(im2))
        
        return data1, data2
    
    def plot_4x4(self, thr1, thr2, row_sum1, row_sum2):
        fig, ax = plt.subplots(2, 2, sharey="row", constrained_layout=True)
        
        ax[0, 0].imshow(thr1, cmap="Greys")
        ax[0, 1].imshow(thr2, cmap="Greys")
        
        ax[1, 0].plot(row_sum1, "k.")
        ax[1, 1].plot(row_sum2, "r.")
        
        ax[1, 0].set(
            xlabel="Index Value",
            ylabel="Row Sum",
        )
        
        plt.show()


imgs = (r"combine\imgs\test_image_part_1.jpg",
        r"combine\imgs\test_image_part_2.jpg")

s = Stitcher()
s.combine_2(*imgs)

This gave me this graph:
Resulting Graph

(I've added in those yellow boxes, to indicate the overlap regions.)

This is the bit I'm stuck at. I want to find exactly:

  1. the index value of the left-side of the yellow box for the 1st image and
  2. the index value of the right-side of the yellow box for the 2nd image.

I define the overlap region as the longest range for which the end of the 1st graph 'matches' the start of the 2nd graph. For the method to find the overlap region, what should I do if the row sum values aren't exactly the same (what if one is the other scaled by some factor)?

I feel like this could be a problem that could use dot products to find the similarity between the 2 graphs? But I can't think of how to implement this.
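A minimal sketch of how the dot-product idea could work, assuming the overlap is always the end of the 1st profile against the start of the 2nd. For each candidate overlap length, both segments are mean-centred and the dot product is divided by the two norms (a normalised correlation), so a constant offset or scale factor between the row sums does not change the score. The names `find_overlap` and `min_overlap` are hypothetical, and this picks the highest-scoring length rather than strictly the longest one.

```python
import numpy as np

def find_overlap(profile1, profile2, min_overlap=10):
    """Return (overlap_length, score) for the best match between the
    end of profile1 and the start of profile2."""
    p1 = np.asarray(profile1, dtype=float)
    p2 = np.asarray(profile2, dtype=float)
    best_len, best_score = 0, -np.inf
    # try every overlap length from min_overlap up to the shorter profile
    for n in range(min_overlap, min(len(p1), len(p2)) + 1):
        a = p1[-n:] - p1[-n:].mean()   # tail of the 1st profile, mean-centred
        b = p2[:n] - p2[:n].mean()     # head of the 2nd profile, mean-centred
        denom = np.linalg.norm(a) * np.linalg.norm(b)
        if denom == 0:
            continue  # a flat segment has no defined correlation
        score = np.dot(a, b) / denom   # 1.0 means identical up to scale/offset
        if score > best_score:
            best_len, best_score = n, score
    return best_len, best_score
```

With an overlap length `n`, the left edge of the yellow box in the 1st graph would be `len(row_sum1) - n` and the right edge in the 2nd graph would be `n`. Very short overlaps can correlate well by chance, which is why a minimum length is assumed.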


Solution

  • I had a lot more fun with this than I expected. I wrote this using OpenCV, but only to load and show the images. Everything else is done with NumPy, so swapping this to PIL shouldn't be too difficult.

    I'm using a brute-force matcher. I also wrote a random-start hillclimber that runs in much less time, but I can't guarantee it'll find the correct answer since the score landscape isn't smooth. I won't include it in my code since it's long and janky, but if you really need the time efficiency I can add it back in later.

    I added a random crop and some salt and pepper noise to the images to test for robustness.

    The brute-force matcher operates on the idea that we don't know which sections of the two images overlap, so we need to slide the smaller image over the larger image from left to right, top to bottom, scoring every placement. This means our search space is:

    horizontal = small_width + big_width
    vertical = small_height + big_height
    area = horizontal * vertical
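
For a sense of scale, here is a small sketch of that count. The helper name `search_space` is hypothetical, and shapes are assumed to be `(height, width)` as NumPy reports them for a grayscale image.

```python
def search_space(small_shape, big_shape):
    """Number of candidate placements of the small image around the big one."""
    small_h, small_w = small_shape
    big_h, big_w = big_shape
    horizontal = small_w + big_w   # x offsets to try
    vertical = small_h + big_h     # y offsets to try
    return horizontal * vertical

# a 100x200 image against a 300x400 image
placements = search_space((100, 200), (300, 400))
```

Doubling both images' dimensions quadruples the number of placements, and each placement also has more overlapping pixels to compare.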
    

    This grows very quickly with image size. The score rewards a larger overlap (one point per overlapping pixel) and subtracts two points for every pixel in the overlap whose values differ, so a mismatch costs more than the overlap gains.
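
As a stripped-down illustration of that scoring rule, assuming the two overlapping patches have already been cut out and have the same shape (the name `overlap_score` and the `penalty` parameter are hypothetical, not part of the full program):

```python
import numpy as np

def overlap_score(patch_one, patch_two, penalty=2):
    """Score two equally sized grayscale patches taken from the overlap region."""
    area = patch_one.shape[0] * patch_one.shape[1]          # points for overlap size
    mismatches = np.count_nonzero(patch_one != patch_two)   # pixels that differ
    return area - penalty * mismatches
```

A perfect 3x4 overlap scores 12. Changing one pixel drops it to 10, so a placement only wins if most of its overlapping pixels agree.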

    Here are some pictures from an execution of this program

    [Four screenshots from a run of the program]

    import cv2
    import numpy as np
    import random
    
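    # randomly snips edges
    def randCrop(image, maxMargin):
        h,w = image.shape[:2];
        c = [random.randint(0,maxMargin) for a in range(4)];
        # use explicit end indices: a margin of 0 would make image[a:-0] an empty slice
        return image[c[0]:h-c[1], c[2]:w-c[3]];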
    
    # adds noise to image
    def saltPepper(image, minNoise, maxNoise):
        h,w = image.shape;
        randNum = random.randint(minNoise, maxNoise);
        for a in range(randNum):
            x = random.randint(0, w-1);
            y = random.randint(0, h-1);
            image[y,x] = random.randint(0, 255);
        return image;
    
    # evaluate layout
    def getScore(one, two):
        # count the pixels whose values differ between the two patches
        return np.count_nonzero(one != two);
    
    # return 2d random position within range
    def randPos(img, big_shape):
        th,tw = big_shape;
        h,w = img.shape;
        x = random.randint(0, tw - w);
        y = random.randint(0, th - h);
        return [x,y];
    
    # overlays small image onto big image
    def overlay(small, big, pos):
        # unpack
        h,w = small.shape;
        x,y = pos;
    
        # copy and place
        copy = big.copy();
        copy[y:y+h, x:x+w] = small;
        return copy;
    
    # calculates overlap region
    def overlap(one, two, pos_one, pos_two):
        # unpack
        h1,w1 = one.shape;
        h2,w2 = two.shape;
        x1,y1 = pos_one;
        x2,y2 = pos_two;
    
        # set edges
        l1 = x1;
        l2 = x2;
        r1 = x1 + w1;
        r2 = x2 + w2;
        t1 = y1;
        t2 = y2;
        b1 = y1 + h1;
        b2 = y2 + h2;
    
        # go
        left = max(l1, l2);
        right = min(r1, r2);
        top = max(t1, t2);
        bottom = min(b1, b2);
        return [left, right, top, bottom];
    
    # wrapper for overlay + getScore
    def fullScore(one, two, pos_one, pos_two, big_empty):
        # check positions
        x,y = pos_two;
        h,w = two.shape;
        th,tw = big_empty.shape;
        if y+h > th or x+w > tw or x < 0 or y < 0:
            return -99999999;
    
        # overlay
        temp_one = overlay(one, big_empty, pos_one);
        temp_two = overlay(two, big_empty, pos_two);
    
        # get overlap
        l,r,t,b = overlap(one, two, pos_one, pos_two);
        temp_one = temp_one[t:b, l:r];
        temp_two = temp_two[t:b, l:r];
    
        # score
        diff = getScore(temp_one, temp_two);
        score = (r-l) * (b-t);
        score -= diff*2;
        return score;
    
    # do brute force
    def bruteForce(one, two):
        # calculate search space
        # unpack size
        h,w = one.shape;
        one_size = h*w;
        h,w = two.shape;
        two_size = h*w;
    
        # small and big
        if one_size < two_size:
            small = one;
            big = two;
        else:
            small = two;
            big = one;
    
        # unpack size
        sh, sw = small.shape;
        bh, bw = big.shape;
        total_width = bw + sw * 2;
        total_height = bh + sh * 2;
    
        # set up empty images
        empty = np.zeros((total_height, total_width), np.uint8);
        
        # set global best
        best_score = -999999;
        best_pos = None;
    
        # start scrolling
        ybound = total_height - sh;
        xbound = total_width - sw;
        for y in range(ybound):
            print("y: " + str(y) + " || " + str(empty.shape));
            for x in range(xbound):
                # get score
                score = fullScore(big, small, [sw,sh], [x,y], empty);
    
                # show
                # prog = overlay(big, empty, [sw,sh]);
                # prog = overlay(small, prog, [x,y]);
                # cv2.imshow("prog", prog);
                # cv2.waitKey(1);
    
                # compare
                if score > best_score:
                    best_score = score;
                    best_pos = [x,y];
                    print("best_score: " + str(best_score));
        return best_pos, [sw,sh], small, big, empty;
    
    # do a step of hill climber
    def hillStep(one, two, best_pos, big_empty, step):
        # make a step
        new_pos = best_pos[1][:];
        new_pos[0] += step[0];
        new_pos[1] += step[1];
    
        # get score
        return fullScore(one, two, best_pos[0], new_pos, big_empty), new_pos;
    
    # hunt around for good position
    # let's do a random-start hillclimber
    def randHill(one, two, shape):
        # set up empty images
        big_empty = np.zeros(shape, np.uint8);
    
        # set global best
        g_best_score = -999999;
        g_best_pos = None;
    
        # lets do 200 iterations
        iters = 200;
        for a in range(iters):
            # progress check
            print(str(a) + " of " + str(iters));
    
            # start with random position
            h,w = two.shape[:2];
            pos_one = [w,h];
            pos_two = randPos(two, shape);
    
            # get score
            best_score = fullScore(one, two, pos_one, pos_two, big_empty);
            best_pos = [pos_one, pos_two];
    
            # hill climb (only on second image)
            while True:
                # end condition: no step improves score
                end_flag = True;
    
                # 8-way
                for y in range(-1, 1+1):
                    for x in range(-1, 1+1):
                        if x != 0 or y != 0:
                            # get score and update
                            score, new_pos = hillStep(one, two, best_pos, big_empty, [x,y]);
                            if score > best_score:
                                best_score = score;
                                best_pos[1] = new_pos[:];
                                end_flag = False;
    
                # end
                if end_flag:
                    break;
                else:
                    # show
                    # prog = overlay(one, big_empty, best_pos[0]);
                    # prog = overlay(two, prog, best_pos[1]);
                    # cv2.imshow("prog", prog);
                    # cv2.waitKey(1);
                    pass;
    
            # check for new global best
            if best_score > g_best_score:
                g_best_score = best_score;
                g_best_pos = best_pos[:];
                print("top score: " + str(g_best_score));
        return g_best_score, g_best_pos;
    
    # load both images
    top = cv2.imread("top.jpg");
    bottom = cv2.imread("bottom.jpg");
    top = cv2.cvtColor(top, cv2.COLOR_BGR2GRAY);
    bottom = cv2.cvtColor(bottom, cv2.COLOR_BGR2GRAY);
    
    # randomly crop
    top = randCrop(top, 20);
    bottom = randCrop(bottom, 20);
    
    # randomly add noise
    saltPepper(top, 200, 1000);
    saltPepper(bottom, 200, 1000);
    
    # set up max image (assume no overlap whatsoever)
    tw = 0;
    th = 0;
    h, w = top.shape;
    tw += w;
    th += h;
    h, w = bottom.shape;
    tw += w*2;
    th += h*2;
    
    # do random-start hill climb
    _, best_pos = randHill(top, bottom, (th, tw));
    
    # show
    empty = np.zeros((th, tw), np.uint8);
    pos1, pos2 = best_pos;
    image = overlay(top, empty, pos1);
    image = overlay(bottom, image, pos2);
    
    # do brute force
    # small_pos, big_pos, small, big, empty = bruteForce(top, bottom);
    # image = overlay(big, empty, big_pos);
    # image = overlay(small, image, small_pos);
    
    # recolor overlap
    h,w = empty.shape;
    color = np.zeros((h,w,3), np.uint8);
    l,r,t,b = overlap(top, bottom, pos1, pos2);
    color[:,:,0] = image;
    color[:,:,1] = image;
    color[:,:,2] = image;
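    # brighten channel 0 (blue in OpenCV's BGR order) in the overlap; clip to avoid uint8 wraparound
    tint = color[t:b, l:r, 0].astype(np.int16) + 100;
    color[t:b, l:r, 0] = np.clip(tint, 0, 255).astype(np.uint8);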
    
    # show images
    cv2.imshow("top", top);
    cv2.imshow("bottom", bottom);
    cv2.imshow("overlayed", image);
    cv2.imshow("Color", color);
    cv2.waitKey(0);
    

    Edit: I added in the random-start hillclimber