Search code examples
pythontensorflowmachine-learningimage-processingdeep-learning

image reconstruction from predicted array - padding same shows grid tiles in reconstructed image


I have two images, E1 and E3, and I am training a CNN model.

In order to train the model, I use E1 as train and E3 as y_train.

I extract tiles from these images in order to train the model on tiles.

The model's output layer has no activation, so the output can take any value.

So the predictions (for example, preds) have values in roughly the range preds.min() = -1.77 to preds.max() = 2.35.

My problem is that I can't reconstruct the image at the end using preds and I think the problem is the scaling-unscaling of the preds values.

If I just do np.uint8(preds), the result is almost full of zeros, since preds has small values.

The reconstructed image should look as close as possible to the E2 image.

import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
    Input, Add
from tensorflow.keras.models import Model
from PIL import Image

# Tile geometry: the network is fed single-channel 32x32 tiles.
CHANNELS = 1
HEIGHT = 32
WIDTH = 32
# Size handed to PIL's Image.resize for the final picture. The doubled
# parentheses are redundant; the value is the plain tuple (1429, 1416).
INIT_SIZE = ((1429, 1416))

def NormalizeData(data):
    """Min-max rescale ``data``.

    The minimum maps to 0; a 1e-6 term added to the denominator keeps the
    division defined when every element of ``data`` is equal.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest + 1e-6
    return (data - lowest) / value_range

def extract_image_tiles(size, im):
    """Cut ``im`` into non-overlapping ``size`` x ``size`` tiles.

    Only the first ``CHANNELS`` channels of ``im`` are used, so ``im`` must be
    indexable along three axes. Tiles are produced row by row, left to right.
    A tile that would reach past the bottom or right edge of the image is
    zero-padded up to the full tile size.

    Returns
    -------
    idxs : ndarray
        One ``(i_start, i_end, j_start, j_end)`` row per tile. For edge tiles
        the end indices may exceed the image bounds.
    tiles : ndarray
        The tiles, stacked along the first axis in the same order as ``idxs``.
    """
    im = im[:, :, :CHANNELS]
    w = h = size
    idxs = [(i, i + h, j, j + w)
            for i in range(0, im.shape[0], h)
            for j in range(0, im.shape[1], w)]

    tiles_asarrays = []
    for i_start, i_end, j_start, j_end in idxs:
        tile = im[i_start:i_end, j_start:j_end, ...]
        if tile.shape[:2] != (h, w):
            # Edge tile: copy the partial data into the top-left corner of a
            # zero tile that keeps any trailing (channel) axes and the dtype.
            padded = np.zeros((h, w) + tile.shape[2:], dtype=tile.dtype)
            padded[:tile.shape[0], :tile.shape[1], ...] = tile
            tile = padded
        tiles_asarrays.append(tile)

    return np.array(idxs), np.array(tiles_asarrays)


def build_model(height, width, channels):
    """Assemble the convolutional Keras model that is applied to each tile.

    Every convolution uses ``padding='same'``. The network ends in a 1x1
    convolution with a single filter and no activation.
    """

    def conv_bn_relu(tensor, filters):
        # 3x3 'same' convolution -> batch normalisation -> ReLU
        tensor = Conv2D(filters, 3, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    inputs = Input((height, width, channels))

    f1 = conv_bn_relu(inputs, 32)
    f2 = conv_bn_relu(f1, 16)
    f3 = conv_bn_relu(f2, 16)

    # Skip connection: element-wise sum of the second and third stage outputs.
    addition = Add()([f2, f3])

    # This convolution is not followed by normalisation or an activation.
    f4 = Conv2D(32, 3, padding='same')(addition)

    f5 = conv_bn_relu(f4, 16)
    f6 = conv_bn_relu(f5, 16)

    output = Conv2D(1, 1, padding='same')(f6)

    return Model(inputs, output)

# Load data
# E1 is the network input. It is resized to 1408x1408 (a multiple of the
# 32-pixel tile size), converted from BGR to RGB and stored as uint8.
img = cv2.imread('E1.tif', cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, (1408, 1408), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img, np.uint8)
#plt.imshow(img)
# E3 is the training target and goes through the same preprocessing.
img3 = cv2.imread('E3.tif', cv2.IMREAD_UNCHANGED)
img3 = cv2.resize(img3, (1408, 1408), interpolation=cv2.INTER_AREA)
img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
img3 = np.array(img3, np.uint8)

# extract tiles from images
# (extract_image_tiles keeps only the first CHANNELS channel(s) of each image)
idxs, tiles = extract_image_tiles(WIDTH, img)
idxs2, tiles3 = extract_image_tiles(WIDTH, img3)

# split to train and test data
# The first 90% of the tiles, in extraction order, are used for training and
# the remainder for validation.
split_idx = int(tiles.shape[0] * 0.9)

train = tiles[:split_idx]
val = tiles[split_idx:]

y_train = tiles3[:split_idx]
y_val = tiles3[split_idx:]

# build model
model = build_model(HEIGHT, WIDTH, CHANNELS)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss = tf.keras.losses.Huber(),
              metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])

# scale data before training
# Inputs and targets are both divided by 255.
train  = train / 255.
val = val / 255.

y_train = y_train / 255.
y_val = y_val / 255.

# train
history = model.fit(train, 
                    y_train, 
                    validation_data=(val, y_val),
                    epochs=50)

# predict on E2
# E2 gets the same preprocessing as the training images.
img2 = cv2.imread('E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)

# extract tiles from images
idxs, tiles2 = extract_image_tiles(WIDTH, img2)

#scale data
tiles2 = tiles2 / 255.

preds = model.predict(tiles2)
#preds = NormalizeData(preds)
#preds = np.uint8(preds)
# reconstruct predictions
# NOTE: the canvas is uint8 while `preds` holds the raw float network output
# (about -1.77..2.35 according to the post), so the assignment in the loop
# below casts those values to uint8 without rescaling them first.
reconstructed = np.zeros((img.shape[0],
                          img.shape[1]),
                          dtype=np.uint8)

# reconstruction process
# NOTE: per the UPDATE further down in the post, `preds` has shape
# (1936, 32, 32, 1), so `preds[:, :, -1]` yields (32, 1) slices rather than
# 32x32 tiles; the author's correction is `preds[:, :, :, -1]`.
for tile, (y_start, y_end, x_start, x_end) in zip(preds[:, :, -1], idxs):
    # Clip the tile window to the image bounds before writing it.
    y_end = min(y_end, img.shape[0])
    x_end = min(x_end, img.shape[1])
    reconstructed[y_start:y_end, x_start:x_end] = tile[:(y_end - y_start), :(x_end - x_start)]


# Show the stitched result at INIT_SIZE.
im = Image.fromarray(reconstructed)
im = im.resize(INIT_SIZE)
im.show()

You can find the data here

If I use :

def normalize_arr_to_uint8(arr):
    """Linearly rescale ``arr`` to 0..255 and return it as ``uint8``.

    The minimum of ``arr`` maps to 0 and the maximum to 255; fractional
    results are truncated by the integer cast. A constant array has a zero
    value range, which would otherwise produce 0/0 = NaN before the cast, so
    it is returned as all zeros instead.
    """
    the_min = arr.min()
    value_range = arr.max() - the_min
    if value_range == 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    arr = ((arr - the_min) / value_range) * 255.
    return arr.astype(np.uint8)


# Rescale the raw predictions to 0..255 (uint8) before rebuilding the image.
preds = model.predict(tiles2)
preds = normalize_arr_to_uint8(preds)

then, I receive an image which seems right, but with lines all over.

Here is the image I get:

enter image description here

This is the image I should get (as close as possible to E2). Note that I only use a small CNN for this example, so I can't recover much detail in the image. But when I try a better model, I still get horizontal and/or vertical lines:

enter image description here

UPDATE

I found this.

In the code above, I use at:

# reconstruction process
for tile, (y_start, y_end, x_start, x_end) in zip(preds[:, :, -1], idxs):

Using preds[:, :, -1] here is wrong.

I must use preds[:, :, :, -1], because the shape of preds is (1936, 32, 32, 1).

So, if I use preds[:, :, -1], I get the image I posted above.

If I use preds[:, :, :, -1], which is correct, I get a new image that has vertical lines in addition to the horizontal ones!

the new image

UPDATE 2

I am just adding new code where I use another patches and reconstruction functions, but produce the same results (a little better picture).

import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
    Input, Add
from tensorflow.keras.models import Model
from PIL import Image

# gpu setup
# Turn on memory growth for every GPU that TensorFlow lists.
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
    
CHANNELS = 1
# Size the full images are resized to before patches are extracted.
HEIGHT = 1408
WIDTH = 1408
# Square patch edge length, and the step between patch origins (half a patch,
# so neighbouring patches overlap).
PATCH_SIZE = 32
STRIDE = PATCH_SIZE//2
# Size handed to PIL's Image.resize for the final picture; the doubled
# parentheses are redundant.
INIT_SIZE = ((1429, 1416))

def normalize_arr_to_uint8(arr):
    """Linearly rescale ``arr`` to 0..255 and return it as ``uint8``.

    The minimum of ``arr`` maps to 0 and the maximum to 255; fractional
    results are truncated by the integer cast. A constant array has a zero
    value range, which would otherwise produce 0/0 = NaN before the cast, so
    it is returned as all zeros instead.
    """
    the_min = arr.min()
    value_range = arr.max() - the_min
    if value_range == 0:
        return np.zeros(arr.shape, dtype=np.uint8)
    arr = ((arr - the_min) / value_range) * 255.
    return arr.astype(np.uint8)


def NormalizeData(data):
    """Min-max rescale ``data``.

    The minimum maps to 0; a 1e-6 term added to the denominator keeps the
    division defined when every element of ``data`` is equal.
    """
    lowest = np.min(data)
    value_range = np.max(data) - lowest + 1e-6
    return (data - lowest) / value_range


def recon_im(patches: np.ndarray, im_h: int, im_w: int, n_channels: int, stride: int):
    """Reconstruct the image from all patches.
        Patches are assumed to be square and overlapping depending on the stride. The image is constructed
         by filling in the patches from left to right, top to bottom, averaging the overlapping parts.

    Parameters
    -----------
    patches: ndarray with shape (patch_number,patch_height,patch_width,channels)
             or (patch_number,patch_height,patch_width) for single-channel data
        Array containing extracted patches. If the patches contain colour information,
        channels are indexed along the last dimension: RGB patches would
        have `n_channels=3`. With `n_channels=1` a trailing channel axis of
        length one is accepted and dropped.
    im_h: int
        original height of image to be reconstructed
    im_w: int
        original width of image to be reconstructed
    n_channels: int
        number of channels the image has. For  RGB image, n_channels = 3
    stride: int
           desired patch stride

    Returns
    -----------
    reconstructedim: ndarray with shape (height, width, channels)
                      or ndarray with shape (height, width) if output image only has one channel
                    Reconstructed image from the given patches. Height and width are the
                    part of the image that whole patches can cover at the given stride.

    Raises
    -----------
    ValueError
        If `patches` contains no patches.
    """
    patches = np.asarray(patches)
    if patches.shape[0] == 0:
        raise ValueError("recon_im needs at least one patch")

    # The single-channel canvas is 2-D, so a trailing channel axis of length
    # one has to be removed before patches can be added to it.
    if n_channels == 1 and patches.ndim == 4 and patches.shape[-1] == 1:
        patches = patches[..., 0]

    patch_size = patches.shape[1]  # patches assumed to be square

    # Assign output image shape based on patch sizes
    rows = ((im_h - patch_size) // stride) * stride + patch_size
    cols = ((im_w - patch_size) // stride) * stride + patch_size

    canvas_shape = (rows, cols) if n_channels == 1 else (rows, cols, n_channels)
    reconim = np.zeros(canvas_shape)  # running sum of patch values
    divim = np.zeros(canvas_shape)    # how many patches touched each pixel

    # Number of patches needed to fill out a row, kept as an integer so the
    # grid position of every patch follows from a single divmod.
    patches_per_row = (cols - patch_size) // stride + 1

    for patch_num, patch in enumerate(patches):
        grid_row, grid_col = divmod(patch_num, patches_per_row)
        top = grid_row * stride
        left = grid_col * stride
        reconim[top:top + patch_size, left:left + patch_size] += patch
        divim[top:top + patch_size, left:left + patch_size] += 1

    # Average out pixel values
    return reconim / divim


def get_patches(GT, stride, patch_size):
    """Collect every full square patch of an image on a regular grid.

    Parameters
    -----------
    GT : ndarray
        image to cut up; the first two axes are rows and columns
    stride : int
        step between the top-left corners of neighbouring patches
    patch_size : int
        edge length of each square patch

    Returns
    -----------
    patches: ndarray
        all patches, ordered top to bottom and, within a row, left to right
    im_h: int
        height of the input image
    im_w: int
        width of the input image
    n_channels: int
        1 for a two-dimensional input, otherwise the length of its third axis
    """
    im_h, im_w = GT.shape[0], GT.shape[1]

    # Only start positions that leave room for a whole patch are used.
    row_starts = range(0, im_h - patch_size + 1, stride)
    col_starts = range(0, im_w - patch_size + 1, stride)
    hr_patches = [
        GT[top:top + patch_size, left:left + patch_size]
        for top in row_starts
        for left in col_starts
    ]

    n_channels = 1 if len(GT.shape) == 2 else GT.shape[2]

    return np.asarray(hr_patches), im_h, im_w, n_channels


def build_model(height, width, channels):
    """Assemble the convolutional Keras model that is applied to each patch.

    Every convolution uses ``padding='same'``. The network ends in a 1x1
    convolution with a single filter and no activation.
    """

    def conv_bn_relu(tensor, filters):
        # 3x3 'same' convolution -> batch normalisation -> ReLU
        tensor = Conv2D(filters, 3, padding='same')(tensor)
        tensor = BatchNormalization()(tensor)
        return Activation('relu')(tensor)

    inputs = Input((height, width, channels))

    f1 = conv_bn_relu(inputs, 32)
    f2 = conv_bn_relu(f1, 16)
    f3 = conv_bn_relu(f2, 16)

    # Skip connection: element-wise sum of the second and third stage outputs.
    addition = Add()([f2, f3])

    # This convolution is not followed by normalisation or an activation.
    f4 = Conv2D(32, 3, padding='same')(addition)

    f5 = conv_bn_relu(f4, 16)
    f6 = conv_bn_relu(f5, 16)

    output = Conv2D(1, 1, padding='same')(f6)

    return Model(inputs, output)

# Load data
# E1 is the network input and E3 the training target; both are resized to
# HEIGHT x WIDTH, converted from BGR to RGB and stored as uint8.
# NOTE(review): cv2.resize expects its size as (width, height); passing
# (HEIGHT, WIDTH) only works here because both are 1408.
img = cv2.imread('E1.tif', cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img, np.uint8)

img3 = cv2.imread('E3.tif', cv2.IMREAD_UNCHANGED)
img3 = cv2.resize(img3, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
img3 = np.array(img3, np.uint8)

# extract tiles from images
# Only the first CHANNELS channel(s) are used; patches overlap because STRIDE
# is half of PATCH_SIZE.
tiles, H, W, C = get_patches(img[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)
tiles3, H, W, C = get_patches(img3[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)


# split to train and test data
# The first 90% of the patches, in extraction order, are used for training
# and the remainder for validation.
split_idx = int(tiles.shape[0] * 0.9)

train = tiles[:split_idx]
val = tiles[split_idx:]

y_train = tiles3[:split_idx]
y_val = tiles3[split_idx:]

# build model
model = build_model(PATCH_SIZE, PATCH_SIZE, CHANNELS)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
              loss = tf.keras.losses.Huber(),
              metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])

# scale data before training
# Inputs and targets are both divided by 255.
train  = train / 255.
val = val / 255.

y_train = y_train / 255.
y_val = y_val / 255.

# train
history = model.fit(train, 
                    y_train, 
                    validation_data=(val, y_val),
                    batch_size=16,
                    epochs=20)

# predict on E2
img2 = cv2.imread('E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)

# extract tiles from images
# NOTE: this rebinds the module-level CHANNELS to the channel count that
# get_patches reports for the sliced image.
tiles2, H, W, CHANNELS = get_patches(img2[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)

#scale data
tiles2 = tiles2 / 255.


preds = model.predict(tiles2)
# The predictions are rescaled to uint8 over the whole array, i.e. before the
# overlapping patches are averaged in recon_im.
preds = normalize_arr_to_uint8(preds)

# Drop the channel axis and stitch the overlapping patches back together.
reconstructed = recon_im(preds[:, :, :, -1], HEIGHT, WIDTH, CHANNELS, stride=STRIDE)

# Show the stitched result at INIT_SIZE.
im = Image.fromarray(reconstructed)
im = im.resize(INIT_SIZE)
im.show()

and the image produced:

stride16

UPDATE 3

After @Lescurel's comment, I tried this architecture:

def build_model(height, width, channels):
    # Variant of the model that mixes 'valid' and 'same' padding: a strided
    # convolution halves the resolution and two stride-2 transposed
    # convolutions upsample again; the output layer uses a sigmoid.
    # NOTE: Conv2DTranspose is not among the layers imported in the listings
    # above, so it has to be imported from tensorflow.keras.layers as well.
    # NOTE(review): for a 32x32 input the spatial size presumably goes
    # 32 -> 30 -> 14 -> 14 -> 12 -> 10 -> 8 -> 16 -> 32; confirm with
    # model.summary().
    inputs = Input((height, width, channels))

    f1 = Conv2D(32, 3, padding='valid')(inputs)
    f1 = BatchNormalization()(f1)
    f1 = Activation('relu')(f1)
 
    # Strided convolution: downsamples by a factor of two.
    f2 = Conv2D(16, 3, strides=2,padding='valid')(f1)
    f2 = BatchNormalization()(f2)
    f2 = Activation('relu')(f2)
    
    # 'same' padding keeps f3 the same size as f2 so the two can be added.
    f3 = Conv2D(16, 3, padding='same')(f2)
    f3 = BatchNormalization()(f3)
    f3 = Activation('relu')(f3)

    addition = Add()([f2, f3])
    
    f4 = Conv2D(32, 3, padding='valid')(addition)
    
    f5 = Conv2D(16, 3, padding='valid')(f4)
    f5 = BatchNormalization()(f5)
    f5 = Activation('relu')(f5)
   
    f6 = Conv2D(16, 3, padding='valid')(f5)
    f6 = BatchNormalization()(f6)
    f6 = Activation('relu')(f6)
   
    # Two stride-2 transposed convolutions upsample by a factor of four.
    f7 = Conv2DTranspose(16, 3, strides=2,padding='same')(f6)
    f7 = BatchNormalization()(f7)
    f7 = Activation('relu')(f7)
    
    f8 = Conv2DTranspose(16, 3, strides=2,padding='same')(f7)
    f8 = BatchNormalization()(f8)
    f8 = Activation('relu')(f8)
    
    # Sigmoid limits the output to the 0..1 range.
    output = Conv2D(1,1, padding='same', activation='sigmoid')(f8)

    model = Model(inputs, output)

    return model

which uses both valid and same padding. The image I get is shown below.

So the square tiles changed dimensions and shape.

So the problem is: how can I keep my original architecture and get rid of these tiles?

padding


Solution

  • Now that I understood this:

    My problem is not the small black box. This is some bad pixel. My problem is the square tile border lines.

    This is a stitching problem, in my opinion. The border lines come from your approach of cropping the full image into tiles and then stitching them back together later. Because of the padding in the convolutions of your model architecture, border artifacts are to be expected.

    There are two things you can do now:

    1. Train and process on larger tiles and then crop the center tile. This does away with the padding issue by "cutting away the problematic parts"
    def extract_image_tiles_with_overlap(size, stride, im, center_tile_width, overlap_percentage):
        """Cut ``im`` into ``size`` x ``size`` tiles whose origins are ``stride`` apart.

        Only the first ``CHANNELS`` channels of ``im`` are used. Tile origins
        stop at ``shape - size``, so every tile lies fully inside the image
        and no padding is ever needed.

        Returns
        -------
        idxs : ndarray
            One ``(i_start, i_end, j_start, j_end)`` row per tile.
        tiles : ndarray
            The tiles, stacked in the same order as ``idxs``.
        overlap : int
            ``center_tile_width * overlap_percentage / 100`` truncated to an
            integer; it is only computed and returned, not used for tiling.
        """
        im = im[:, :, :CHANNELS]
        w = h = size
        overlap = int(center_tile_width * overlap_percentage / 100)

        idxs = [(i, i + h, j, j + w)
                for i in range(0, im.shape[0] - h + 1, stride)
                for j in range(0, im.shape[1] - w + 1, stride)]
        tiles_asarrays = [im[i_start:i_end, j_start:j_end, ...]
                          for i_start, i_end, j_start, j_end in idxs]

        return np.array(idxs), np.array(tiles_asarrays), overlap
    

    The result looks like this: enter image description here

    Of course, this does not solve the stitching problem entirely. As a next step, you can experiment with aggregating the overlapping segments of the tiles to make the result smoother.

    2. This smoothing through aggregation can also be done as a post-processing step, e.g. via Gaussian filtering or morphological operations.
    # Dilate `iterations` times, then erode the same number of times.
    # NOTE(review): scipy's binary_* functions treat the input as a mask
    # (non-zero = foreground), so grey levels are presumably not preserved in
    # the result — confirm this is what is wanted for `reconstructed`.
    iterations = 10  # Adjust the number of iterations as needed
    reconstructed_dilated = binary_dilation(reconstructed, iterations=iterations)
    reconstructed_smoothed = binary_erosion(reconstructed_dilated, iterations=iterations)
    

    Or you use some OpenCV denoising (make sure to further optimize the hyperparameters).

    # Non-local means denoising of the stitched image; h, templateWindowSize
    # and searchWindowSize are the hyperparameters to tune.
    denoised_reconstructed = cv2.fastNlMeansDenoising(reconstructed, h=10, templateWindowSize=7, searchWindowSize=21)
    

    Result now is this:

    enter image description here

    Full code to reproduce I changed a few other things (e.g. model saving, etc.)

    import cv2
    import numpy as np
    import tensorflow as tf
    import matplotlib.pyplot as plt
    from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
        Input, Add
    from tensorflow.keras.models import Model
    from PIL import Image
    import os
    from scipy.ndimage import gaussian_filter
    from scipy.ndimage import binary_dilation, binary_erosion
    
    
    
    # Tile geometry: the network is fed single-channel 32x32 tiles.
    CHANNELS = 1
    HEIGHT = 32
    WIDTH = 32
    # Size handed to PIL's Image.resize for the final picture; the doubled
    # parentheses are redundant.
    INIT_SIZE = ((1429, 1416))
    # The trained model is saved to this path, and loaded from it when it exists.
    MODEL_SAVE_PATH = 'my_croptile_model.h5'  # Change this path to your desired location
    
    
    
    def NormalizeData(data):
        """Min-max rescale ``data``.

        The minimum maps to 0; a 1e-6 term added to the denominator keeps the
        division defined when every element of ``data`` is equal.
        """
        lowest = np.min(data)
        value_range = np.max(data) - lowest + 1e-6
        return (data - lowest) / value_range
    
    def extract_image_tiles_with_overlap(size, stride, im, center_tile_width, overlap_percentage):
        """Cut ``im`` into ``size`` x ``size`` tiles whose origins are ``stride`` apart.

        Only the first ``CHANNELS`` channels of ``im`` are used. Tile origins
        stop at ``shape - size``, so every tile lies fully inside the image
        and no padding is ever needed.

        Returns
        -------
        idxs : ndarray
            One ``(i_start, i_end, j_start, j_end)`` row per tile.
        tiles : ndarray
            The tiles, stacked in the same order as ``idxs``.
        overlap : int
            ``center_tile_width * overlap_percentage / 100`` truncated to an
            integer; it is only computed and returned, not used for tiling.
        """
        im = im[:, :, :CHANNELS]
        w = h = size
        overlap = int(center_tile_width * overlap_percentage / 100)

        idxs = [(i, i + h, j, j + w)
                for i in range(0, im.shape[0] - h + 1, stride)
                for j in range(0, im.shape[1] - w + 1, stride)]
        tiles_asarrays = [im[i_start:i_end, j_start:j_end, ...]
                          for i_start, i_end, j_start, j_end in idxs]

        return np.array(idxs), np.array(tiles_asarrays), overlap
    
    def build_model(height, width, channels):
        """Assemble the convolutional Keras model that is applied to each tile.

        Every convolution uses ``padding='same'``. The network ends in a 1x1
        convolution with a single filter and no activation.
        """

        def conv_bn_relu(tensor, filters):
            # 3x3 'same' convolution -> batch normalisation -> ReLU
            tensor = Conv2D(filters, 3, padding='same')(tensor)
            tensor = BatchNormalization()(tensor)
            return Activation('relu')(tensor)

        inputs = Input((height, width, channels))

        f1 = conv_bn_relu(inputs, 32)
        f2 = conv_bn_relu(f1, 16)
        f3 = conv_bn_relu(f2, 16)

        # Skip connection: element-wise sum of the second and third stage outputs.
        addition = Add()([f2, f3])

        # This convolution is not followed by normalisation or an activation.
        f4 = Conv2D(32, 3, padding='same')(addition)

        f5 = conv_bn_relu(f4, 16)
        f6 = conv_bn_relu(f5, 16)

        output = Conv2D(1, 1, padding='same')(f6)

        return Model(inputs, output)
    
    
    # Load data
    # E1 and E3 are resized to 1408x1408, converted from BGR to RGB and stored
    # as uint8.
    img = cv2.imread('images/E1.tif', cv2.IMREAD_UNCHANGED)
    img = cv2.resize(img, (1408, 1408), interpolation=cv2.INTER_AREA)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = np.array(img, np.uint8)
    # plt.imshow(img)
    img3 = cv2.imread('images/E3.tif', cv2.IMREAD_UNCHANGED)
    img3 = cv2.resize(img3, (1408, 1408), interpolation=cv2.INTER_AREA)
    img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
    img3 = np.array(img3, np.uint8)
    
    # extract tiles from images
    # idxs, tiles = extract_image_tiles(WIDTH, img)
    # idxs2, tiles3 = extract_image_tiles(WIDTH, img3)
    
    # extract tiles from images with overlap
    # Tiles are WIDTH wide with origins WIDTH // 2 apart, so neighbours share
    # half a tile. `overlap` comes out as int(32 * 60 / 100) = 19 pixels.
    center_tile_width = WIDTH  # Adjust as needed
    overlap_percentage = 60  # Adjust as needed
    idxs, tiles, overlap = extract_image_tiles_with_overlap(WIDTH, WIDTH // 2, img, center_tile_width, overlap_percentage)
    
    
    
    # split to train and test data
    # The first 90% of the tiles, in extraction order, are used for training
    # and the remainder for validation.
    split_idx = int(tiles.shape[0] * 0.9)
    
    train = tiles[:split_idx]
    val = tiles[split_idx:]
    
    # NOTE(review): the targets are taken from `tiles` (E1), i.e. the same
    # data as the inputs; `img3` (E3) is loaded above but never tiled or used.
    # Confirm whether tiles extracted from img3 were intended here.
    y_train = tiles[:split_idx]
    y_val = tiles[split_idx:]
    
    
    # Build or load model
    # A previously saved model is reloaded and run for zero further epochs;
    # otherwise a fresh model is built and trained for 50 epochs. Everything
    # else (compile, scaling, fit, save) is the same for both cases, so it is
    # written once instead of being repeated in each branch.
    model_already_saved = os.path.exists(MODEL_SAVE_PATH)
    if model_already_saved:
        model = tf.keras.models.load_model(MODEL_SAVE_PATH)
    else:
        model = build_model(HEIGHT, WIDTH, CHANNELS)

    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                  loss=tf.keras.losses.Huber(),
                  metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])

    # scale data before training
    train = train / 255.
    val = val / 255.

    y_train = y_train / 255.
    y_val = y_val / 255.

    # train
    history = model.fit(train,
                        y_train,
                        validation_data=(val, y_val),
                        epochs=0 if model_already_saved else 50)  # Adjust epochs as needed

    # Save the model
    model.save(MODEL_SAVE_PATH)
    
    # predict on E2
    # E2 gets the same preprocessing as the training image.
    img2 = cv2.imread('images/E2.tif', cv2.IMREAD_UNCHANGED)
    img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
    img2 = np.array(img2, np.uint8)
    
    # extract tiles from images
    idxs, tiles2, overlap = extract_image_tiles_with_overlap(WIDTH, WIDTH // 2, img2, center_tile_width, overlap_percentage)
    
    # scale data
    tiles2 = tiles2 / 255.
    
    preds = model.predict(tiles2)
    
    # Check model output range
    print("Max prediction value:", np.max(preds))
    print("Min prediction value:", np.min(preds))
    
    # Invert colors in predictions
    inverted_preds = 1.0 - preds
    
    # Ensure values are within valid range
    inverted_preds = np.clip(inverted_preds, 0, 1)
    
    # Reconstruct inverted predictions with cropping center part of the tile
    reconstructed = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    
    # Reconstruction process
    # Tiles are written with plain assignment in extraction order, so where
    # tiles overlap the later tile replaces what an earlier one wrote.
    for (y_start, y_end, x_start, x_end), tile in zip(idxs, inverted_preds[:, :, :, -1]):
        y_end = min(y_end, img.shape[0])
        x_end = min(x_end, img.shape[1])
    
        # Calculate the crop size based on the size of the original tile without overlap
        crop_size = WIDTH
    
        # Rescale the tile to the original size (WIDTH + overlap)
        # NOTE(review): the predicted tile is WIDTH x WIDTH, so this enlarges
        # it; together with the crop below the tile's central region ends up
        # stretched over the whole WIDTH x WIDTH slot — confirm this is intended.
        rescaled_tile = cv2.resize(tile, (WIDTH + overlap, WIDTH + overlap))
    
        # Crop the center part of the rescaled tile
        center_crop = rescaled_tile[overlap//2:overlap//2+crop_size, overlap//2:overlap//2+crop_size]
    
        # Update the reconstructed image directly with the cropped tile
        # (values are scaled from 0..1 back to 0..255 and cast to uint8)
        reconstructed[y_start:y_end, x_start:x_end] = (center_crop * 255).astype(np.uint8)
    
    # Apply non-local means denoising
    denoised_reconstructed = cv2.fastNlMeansDenoising(reconstructed, h=10, templateWindowSize=7, searchWindowSize=21)
    
    
    
    # Show the denoised result at INIT_SIZE.
    im = Image.fromarray(denoised_reconstructed)
    im = im.resize(INIT_SIZE)
    im.show()
    

    The updated answer ends here.


    So I think this is mainly about post processing and visualization of your data:

    enter image description here

    I visualize like this now:

    # predict on E2
    img2 = cv2.imread('images/E2.tif', cv2.IMREAD_UNCHANGED)
    img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
    img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
    img2 = np.array(img2, np.uint8)
    
    # extract tiles from images
    # (non-overlapping tiles, as in the question's first listing)
    idxs, tiles2 = extract_image_tiles(WIDTH, img2)
    
    # scale data
    tiles2 = tiles2 / 255.
    
    preds = model.predict(tiles2)
    
    # Check model output range
    print("Max prediction value:", np.max(preds))
    print("Min prediction value:", np.min(preds))
    
    # Invert colors in predictions
    inverted_preds = 1.0 - preds
    
    # Ensure values are within valid range
    inverted_preds = np.clip(inverted_preds, 0, 1)
    
    # Reconstruct inverted predictions
    reconstructed = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
    
    # Reconstruction process
    # The channel axis is dropped with [:, :, :, -1], and each tile is scaled
    # from 0..1 to 0..255 before it is written into the uint8 canvas.
    for tile, (y_start, y_end, x_start, x_end) in zip(inverted_preds[:, :, :, -1], idxs):
        y_end = min(y_end, img.shape[0])
        x_end = min(x_end, img.shape[1])
        reconstructed[y_start:y_end, x_start:x_end] = (tile * 255).astype(np.uint8)
    
    # Show the stitched result at INIT_SIZE.
    im = Image.fromarray(reconstructed)
    im = im.resize(INIT_SIZE)
    im.show()
    

    I can already clearly see the black square. So next I will try some thresholding to enhance that visibility in post-processing.

    I am thinking of something like that:

    # Threshold value (adjust as needed)
    # NOTE(review): `reconstructed` is a uint8 image holding 0..255 values in
    # the listing above, so comparing it against 0.9 marks every non-zero
    # pixel; confirm whether the threshold was meant for the 0..1 scale.
    threshold = 0.9 #0.45
    
    # Thresholding
    # Pixels at or above the threshold become 255, all others 0.
    binary_output = (reconstructed >= threshold).astype(np.uint8) * 255
    
    # Second visualization
    im_binary = Image.fromarray(binary_output)
    im_binary = im_binary.resize(INIT_SIZE)
    im_binary.show()
    

    Which leaves me with this:

    enter image description here

    Not sure how good this scales across your full dataset, but this is definitely in the ball-park for some morphological operators in post-processing.