Search code examples
Tags: python, numpy, opencv, machine-learning, pytorch

Center an object in an image and add a white background on export


I want to automate all of this:

  1. Select an object in an image
  2. Crop the image to this object
  3. Crop to a 1:1 aspect ratio, leaving a slight gap around the object
  4. Export the image as an 800×800 px JPG, with the object centered on a white background.

I'm on Windows 11 (64-bit).

What I did:

  1. Installed Python and created an environment
  2. Installed opencv-python-headless, pillow, numpy, and PyTorch (CUDA 11.8 build)
  3. Cloned the segment-anything repository and installed it with pip
  4. Downloaded the sam_vit_b_01ec64.pth checkpoint

Then I wrote the following Python script:

import os
import cv2
import numpy as np
from PIL import Image
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

def load_image(image_path):
    """Read *image_path* with OpenCV's default flags (3-channel BGR, uint8).

    Raises:
        OSError: if OpenCV cannot read or decode the file. cv2.imread itself
            returns None in that case, which would otherwise only fail later
            with an unrelated-looking error.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise OSError(f"Could not read image: {image_path}")
    return image

def save_image(image, path):
    """Write *image* to disk as a JPEG via OpenCV.

    ``.jpg`` is appended to *path* unless it already ends in ``.jpg`` or
    ``.jpeg`` (case-insensitive). This keeps callers that pass a full file
    name from producing ``name.jpg.jpg``.

    Raises:
        OSError: if cv2.imwrite reports that the file could not be written.
    """
    if not path.lower().endswith(('.jpg', '.jpeg')):
        path += '.jpg'
    # cv2.imwrite returns False instead of raising (e.g. missing directory),
    # so turn that into an exception the batch loop can report.
    if not cv2.imwrite(path, image):
        raise OSError(f"Could not write image to {path}")

# Cache of SamAutomaticMaskGenerator instances keyed by checkpoint path, so
# the model weights are loaded once per run instead of once per image.
_MASK_GENERATORS = {}

def _get_mask_generator(checkpoint):
    """Return the mask generator for *checkpoint*, building it on first use."""
    generator = _MASK_GENERATORS.get(checkpoint)
    if generator is None:
        sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
        generator = SamAutomaticMaskGenerator(sam)
        _MASK_GENERATORS[checkpoint] = generator
    return generator

def select_object(image, checkpoint="sam_vit_b_01ec64.pth"):
    """Return the segmentation of the largest region SAM finds in *image*.

    Args:
        image: 3-channel BGR uint8 array, as returned by cv2.imread.
        checkpoint: path to the SAM ViT-B checkpoint file.

    Returns:
        The 'segmentation' entry of the largest mask (by 'area').

    Raises:
        ValueError: if SAM returns no mask at all.

    NOTE(review): the largest region is often the background rather than the
    subject; callers may need to invert the result.
    """
    # SAM's generator expects an RGB image; OpenCV loads BGR.
    rgb_image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = _get_mask_generator(checkpoint).generate(rgb_image)
    if not masks:
        raise ValueError("SAM did not return any mask for this image")
    largest_mask = max(masks, key=lambda m: m['area'])
    return largest_mask['segmentation']

def crop_to_object(image, mask, padding=5):
    """Crop *image* to the bounding box of the non-zero pixels of *mask*, plus a margin.

    Args:
        image: array whose first two axes match the shape of *mask*.
        mask: 2-D array; every non-zero / True pixel counts as object.
        padding: margin in pixels added on each side, clipped at the image edge.

    Returns:
        A view of *image* covering the padded bounding box.

    Raises:
        ValueError: if *mask* contains no non-zero pixel.
    """
    rows = np.flatnonzero(np.any(mask, axis=1))
    cols = np.flatnonzero(np.any(mask, axis=0))
    if rows.size == 0:
        raise ValueError("mask is empty: no object to crop to")
    # Clamp each edge independently, so a margin clipped on one side is not
    # added to the opposite side instead.
    top = max(0, rows[0] - padding)
    left = max(0, cols[0] - padding)
    bottom = min(image.shape[0], rows[-1] + 1 + padding)
    right = min(image.shape[1], cols[-1] + 1 + padding)
    return image[top:bottom, left:right]

def resize_to_square(image, size=800):
    """Scale *image* to fit a size x size square and centre it on white.

    The longer side becomes *size* pixels. The shorter side is scaled by the
    same factor, and the uncovered area is filled with 255.

    Args:
        image: 2-D grayscale or H x W x C uint8 array.
        size: edge length of the square output, in pixels.

    Returns:
        uint8 array of shape ``(size, size) + image.shape[2:]``.

    Raises:
        ValueError: if *image* has a zero-length side.
    """
    h, w = image.shape[:2]
    if h == 0 or w == 0:
        raise ValueError("cannot resize an empty image")
    scale = size / max(h, w)
    # round() rather than int(): h * (size / h) can come out as 799.999...,
    # which int() truncates to 799. max(1, ...) keeps very thin crops from
    # becoming zero-sized, which cv2.resize rejects.
    new_h = min(size, max(1, round(h * scale)))
    new_w = min(size, max(1, round(w * scale)))
    resized_image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LANCZOS4)

    # Match the input's channel layout instead of assuming 3 channels.
    new_image = np.full((size, size) + image.shape[2:], 255, dtype=np.uint8)

    top = (size - new_h) // 2
    left = (size - new_w) // 2
    new_image[top:top + new_h, left:left + new_w] = resized_image

    return new_image

def process_image(image_path, output_path):
    """Segment, crop, square-pad to 800 px and save one image as JPEG.

    *output_path* may still carry the input's extension (process_folder passes
    e.g. ``out/photo.png``). That extension is stripped, and save_image adds
    ``.jpg``, so the result is ``out/photo.jpg`` with no stacked extensions.

    Raises:
        ValueError: if the image cannot be read.
    """
    image = load_image(image_path)
    if image is None:
        raise ValueError(f"Could not read image: {image_path}")
    mask = select_object(image)
    cropped_image = crop_to_object(image, mask)
    final_image = resize_to_square(cropped_image, 800)
    save_image(final_image, os.path.splitext(output_path)[0])

def process_folder(input_folder, output_folder):
    """Run process_image on every supported image under *input_folder*.

    The directory tree is mirrored under *output_folder*. A failure is printed
    per file and does not stop the batch. If *output_folder* lies inside
    *input_folder*, it is skipped during the walk, so freshly written results
    are never fed back into the pipeline.
    """
    extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(output_folder, exist_ok=True)
    output_key = os.path.normcase(os.path.abspath(output_folder))

    for root, dirs, files in os.walk(input_folder):
        # os.walk is top-down by default, so pruning dirs in place stops it
        # from descending into the output folder.
        dirs[:] = [
            d for d in dirs
            if os.path.normcase(os.path.abspath(os.path.join(root, d))) != output_key
        ]
        for filename in files:
            if not filename.lower().endswith(extensions):
                continue
            input_path = os.path.join(root, filename)
            relative_path = os.path.relpath(input_path, input_folder)
            output_path = os.path.join(output_folder, relative_path)
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            try:
                process_image(input_path, output_path)
                print(f"Processed {input_path}")
            except Exception as e:  # batch boundary: report and continue
                print(f"Failed to process {input_path}: {e}")

if __name__ == "__main__":
    import sys

    # Fill these in, or pass them on the command line:
    #   python script.py INPUT_FOLDER OUTPUT_FOLDER
    input_folder = ""
    output_folder = ""
    if len(sys.argv) == 3:
        input_folder, output_folder = sys.argv[1], sys.argv[2]
    if not input_folder or not output_folder:
        # An empty output folder would make os.makedirs("") raise FileNotFoundError.
        sys.exit("usage: python script.py INPUT_FOLDER OUTPUT_FOLDER")
    process_folder(input_folder, output_folder)

What happens:

I load a base image and expect the "Expected result" image, but I get the "Result" image instead.

Here are some other base → result pairs I got:

Can anyone help me understand what I missed?

Thanks in advance,

Cyril


Solution

  • I solved the problem. First, I installed matplotlib and added two functions to my code (they use `plt`, so the script also needs `import matplotlib.pyplot as plt`):

    def visualize_mask(image, mask):
        """Show a BGR image and its segmentation mask side by side (debug aid).

        The left panel is the image converted to RGB; the right panel is the
        mask drawn in grayscale. Plot titles are in French
        ("original image" / "segmentation mask").

        NOTE(review): relies on ``plt`` (matplotlib.pyplot) being imported at
        module level; that import is not shown in this snippet.
        """
        panels = [
            (cv2.cvtColor(image, cv2.COLOR_BGR2RGB), {}, "Image Originale"),
            (mask, {'cmap': 'gray'}, "Masque de Segmentation"),
        ]
        plt.figure(figsize=(10, 10))
        for position, (data, options, caption) in enumerate(panels, start=1):
            plt.subplot(1, 2, position)
            plt.imshow(data, **options)
            plt.title(caption)
        plt.show()
    
    def visualize_cropped_image(image, cropped_image):
        """Display the original BGR image next to its crop for a visual check.

        Both images are converted from OpenCV's BGR order to RGB for
        matplotlib. Titles are in French ("original image" / "cropped image").

        NOTE(review): relies on ``plt`` (matplotlib.pyplot) being imported at
        module level; that import is not shown in this snippet.
        """
        def show_panel(position, bgr, caption):
            plt.subplot(1, 2, position)
            plt.imshow(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            plt.title(caption)

        plt.figure(figsize=(10, 5))
        show_panel(1, image, "Image Originale")
        show_panel(2, cropped_image, "Image Recadrée")
        plt.show()
    

    With these I can check whether the mask and the crop are correct. They showed that the mask was inverted: the background was selected instead of the subject, because the code keeps the largest mask, and here that was the background.

    I fixed that by inverting the mask before cropping the image. However, when the base image was a PNG with a transparent background, the transparent area turned black after conversion and cropping. That is because cv2.imread loads images as 3-channel BGR by default and drops the alpha (transparency) channel.

    To handle transparency, load the image with its alpha channel (cv2.IMREAD_UNCHANGED gives BGRA, where A is the transparency). Then fill the transparent areas with white before saving, since JPEG cannot store transparency.

    def load_image(image_path):
        """Read *image_path* with every channel preserved (alpha included).

        cv2.IMREAD_UNCHANGED keeps a PNG's alpha channel (BGRA) instead of
        dropping it.

        Raises:
            OSError: if OpenCV cannot read or decode the file. cv2.imread
                returns None in that case rather than raising.

        NOTE(review): on Windows, cv2.imread is known to fail on non-ASCII
        paths; confirm with the file names actually used.
        """
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        return image
    
    def composite_on_white(image):
        """Return a 3-channel uint8 BGR version of *image*, ready for JPEG.

        - A 2-D grayscale array is replicated into three channels.
        - A 4-channel BGRA array is alpha-blended onto white:
          ``out = bgr * a + 255 * (1 - a)`` with ``a = alpha / 255``, so
          semi-transparent (anti-aliased) edges keep their colour instead of
          being forced to pure white.
        - Any other array is returned unchanged.
        """
        if image.ndim == 2:
            return np.dstack([image, image, image])
        if image.shape[2] == 4:
            bgr = image[:, :, :3].astype(np.float32)
            alpha = image[:, :, 3:].astype(np.float32) / 255.0
            blended = bgr * alpha + 255.0 * (1.0 - alpha)
            return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        return image

    def save_image(image, path):
        """Save a grayscale, BGR or BGRA *image* to *path* as a quality-95 JPEG.

        JPEG has no alpha channel, so transparency is composited onto white
        first. Pillow writes the file after OpenCV's BGR order is converted
        to RGB.
        """
        bgr = composite_on_white(image)
        image_pil = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        image_pil.save(path, 'JPEG', quality=95)
    

    And the final code:

    import os
    import cv2
    import numpy as np
    from PIL import Image
    from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
    
    def load_image(image_path):
        """Read *image_path* with every channel preserved (alpha included).

        cv2.IMREAD_UNCHANGED keeps a PNG's alpha channel (BGRA) instead of
        dropping it.

        Raises:
            OSError: if OpenCV cannot read or decode the file. cv2.imread
                returns None in that case rather than raising.

        NOTE(review): on Windows, cv2.imread is known to fail on non-ASCII
        paths; confirm with the file names actually used.
        """
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if image is None:
            raise OSError(f"Could not read image: {image_path}")
        return image
    
    def composite_on_white(image):
        """Return a 3-channel uint8 BGR version of *image*, ready for JPEG.

        - A 2-D grayscale array is replicated into three channels.
        - A 4-channel BGRA array is alpha-blended onto white:
          ``out = bgr * a + 255 * (1 - a)`` with ``a = alpha / 255``, so
          semi-transparent (anti-aliased) edges keep their colour instead of
          being forced to pure white.
        - Any other array is returned unchanged.
        """
        if image.ndim == 2:
            return np.dstack([image, image, image])
        if image.shape[2] == 4:
            bgr = image[:, :, :3].astype(np.float32)
            alpha = image[:, :, 3:].astype(np.float32) / 255.0
            blended = bgr * alpha + 255.0 * (1.0 - alpha)
            return np.clip(np.rint(blended), 0, 255).astype(np.uint8)
        return image

    def save_image(image, path):
        """Save a grayscale, BGR or BGRA *image* to *path* as a quality-95 JPEG.

        JPEG has no alpha channel, so transparency is composited onto white
        first. Pillow writes the file after OpenCV's BGR order is converted
        to RGB.
        """
        bgr = composite_on_white(image)
        image_pil = Image.fromarray(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
        image_pil.save(path, 'JPEG', quality=95)
    
    # Cache of SamAutomaticMaskGenerator instances keyed by checkpoint path, so
    # the model weights are loaded once per run instead of once per image.
    _MASK_GENERATORS = {}

    def _get_mask_generator(checkpoint):
        """Return the mask generator for *checkpoint*, building it on first use."""
        generator = _MASK_GENERATORS.get(checkpoint)
        if generator is None:
            sam = sam_model_registry["vit_b"](checkpoint=checkpoint)
            generator = SamAutomaticMaskGenerator(sam)
            _MASK_GENERATORS[checkpoint] = generator
        return generator

    def select_object(image, checkpoint=os.path.join("env", "Checkpoint", "sam_vit_b_01ec64.pth")):
        """Return a uint8 0/1 mask of the subject, at the full size of *image*.

        The image is downscaled so its long side is at most 1024 px for SAM.
        The largest SAM mask is taken as the background and inverted, and the
        result is scaled back to the original size.

        Args:
            image: 2-D grayscale, BGR or BGRA array (as from cv2.imread).
            checkpoint: path to the SAM ViT-B checkpoint. The default is the
                original relative location, built with os.path.join so it is
                not tied to Windows separators.

        Raises:
            ValueError: if SAM returns no mask at all.

        NOTE(review): "largest mask == background" is a heuristic. It holds
        for subjects on plain backgrounds; if the subject fills most of the
        frame, this selects the background instead.
        """
        long_side = 1024
        height, width = image.shape[:2]
        if max(height, width) > long_side:
            scale = long_side / max(height, width)
            new_size = (max(1, int(width * scale)), max(1, int(height * scale)))
            resized_image = cv2.resize(image, new_size, interpolation=cv2.INTER_AREA)
        else:
            resized_image = image

        # SAM's generator expects a 3-channel RGB uint8 image; OpenCV gives
        # BGR, BGRA or (for grayscale files with IMREAD_UNCHANGED) a 2-D array.
        if resized_image.ndim == 2:
            rgb_image = cv2.cvtColor(resized_image, cv2.COLOR_GRAY2RGB)
        elif resized_image.shape[2] == 4:
            rgb_image = cv2.cvtColor(resized_image, cv2.COLOR_BGRA2RGB)
        else:
            rgb_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2RGB)

        masks = _get_mask_generator(checkpoint).generate(rgb_image)
        if not masks:
            raise ValueError("SAM did not return any mask for this image")
        largest_mask = max(masks, key=lambda m: m['area'])

        # Invert: the largest region is assumed to be the background.
        mask = np.logical_not(largest_mask['segmentation']).astype(np.uint8)

        # INTER_NEAREST keeps the mask strictly 0/1 when scaling back up.
        return cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
    
    def crop_to_object(image, mask, padding=0):
        """Crop *image* to the bounding box of all non-zero pixels of *mask*.

        The box covers every object pixel, even when the mask consists of
        several disconnected blobs. Taking ``contours[0]`` could pick an
        arbitrary (possibly tiny) blob.

        Args:
            image: array whose first two axes match the shape of *mask*.
            mask: 2-D array; every non-zero / True pixel counts as object.
            padding: optional margin in pixels added on each side, clipped at
                the image edge (default 0: no margin).

        Returns:
            A view of *image* covering the (padded) bounding box.

        Raises:
            ValueError: if *mask* contains no non-zero pixel.
        """
        rows = np.flatnonzero(np.any(mask, axis=1))
        cols = np.flatnonzero(np.any(mask, axis=0))
        if rows.size == 0:
            raise ValueError("mask is empty: no object to crop to")
        top = max(0, rows[0] - padding)
        left = max(0, cols[0] - padding)
        bottom = min(image.shape[0], rows[-1] + 1 + padding)
        right = min(image.shape[1], cols[-1] + 1 + padding)
        return image[top:bottom, left:right]
    
    def resize_with_padding(image, size):
        """Fit *image* inside a size x size square and pad the rest with white.

        The longer side is scaled to *size*. The remaining space is split
        evenly between opposite borders, with any odd pixel going to the
        bottom/right.

        Args:
            image: 2-D grayscale, BGR or BGRA array.
            size: edge length of the square output, in pixels.

        Returns:
            A size x size array with the same channel count as *image*.

        Raises:
            ValueError: if *image* has a zero-length side.
        """
        height, width = image.shape[:2]
        if height == 0 or width == 0:
            raise ValueError("cannot resize an empty image")
        scale = size / max(height, width)
        # max(1, ...) keeps very thin crops from collapsing to a zero-sized
        # dimension (cv2.resize rejects those); min(size, ...) guards rounding.
        new_size = (min(size, max(1, round(width * scale))),
                    min(size, max(1, round(height * scale))))
        # INTER_AREA is meant for shrinking; when enlarging it behaves like
        # nearest-neighbour, so small crops blown up to 800 px use INTER_CUBIC.
        interpolation = cv2.INTER_AREA if scale < 1 else cv2.INTER_CUBIC
        resized_image = cv2.resize(image, new_size, interpolation=interpolation)

        delta_w = size - new_size[0]
        delta_h = size - new_size[1]
        top, bottom = delta_h // 2, delta_h - (delta_h // 2)
        left, right = delta_w // 2, delta_w - (delta_w // 2)

        # One 255 per channel: opaque white for BGRA (a 3-value list would
        # leave alpha at 0), plain white for BGR and grayscale.
        channels = 1 if resized_image.ndim == 2 else resized_image.shape[2]
        color = [255] * channels
        return cv2.copyMakeBorder(resized_image, top, bottom, left, right,
                                  cv2.BORDER_CONSTANT, value=color)
    
    def process_image(image_path, output_path):
        """Run the full pipeline on one file and write the JPEG to *output_path*.

        Steps: load, segment the subject, crop to it, fit into an 800 x 800
        white square, save.
        """
        source = load_image(image_path)
        subject_mask = select_object(source)
        square = resize_with_padding(crop_to_object(source, subject_mask), 800)
        save_image(square, output_path)
    
    def process_folder(input_folder, output_folder):
        """Convert every supported image under *input_folder* to a JPEG in *output_folder*.

        The directory tree is mirrored, and each output keeps the input's base
        name with a ``.jpg`` extension. A failure is printed per file and does
        not stop the batch. If *output_folder* lies inside *input_folder*, it
        is skipped during the walk, so freshly written results are never
        processed again.
        """
        extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff')
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(output_folder, exist_ok=True)
        output_key = os.path.normcase(os.path.abspath(output_folder))

        for root, dirs, files in os.walk(input_folder):
            # os.walk is top-down by default, so pruning dirs in place stops
            # it from descending into the output folder.
            dirs[:] = [
                d for d in dirs
                if os.path.normcase(os.path.abspath(os.path.join(root, d))) != output_key
            ]
            for filename in files:
                if not filename.lower().endswith(extensions):
                    continue
                input_path = os.path.join(root, filename)
                relative_path = os.path.relpath(input_path, input_folder)
                output_path = os.path.splitext(os.path.join(output_folder, relative_path))[0] + '.jpg'
                os.makedirs(os.path.dirname(output_path), exist_ok=True)
                try:
                    process_image(input_path, output_path)
                    print(f"Processed {input_path}")
                except Exception as e:  # batch boundary: report and continue
                    print(f"Failed to process {input_path}: {e}")
    
    if __name__ == "__main__":
        import sys

        # Fill these in, or pass them on the command line:
        #   python script.py INPUT_FOLDER OUTPUT_FOLDER
        input_folder = ""
        output_folder = ""
        if len(sys.argv) == 3:
            input_folder, output_folder = sys.argv[1], sys.argv[2]
        if not input_folder or not output_folder:
            # An empty output folder would make os.makedirs("") raise FileNotFoundError.
            sys.exit("usage: python script.py INPUT_FOLDER OUTPUT_FOLDER")
        process_folder(input_folder, output_folder)