I have two images, E1 and E3, and I am training a CNN model. To train the model, I use E1 as train and E3 as y_train, and I extract tiles from these images so the model is trained on tiles.
The model does not have an activation on its output layer, so the output can take any value; the predictions preds, for example, have values around preds.max() = 2.35 and preds.min() = -1.77.
My problem is that I can't reconstruct the image at the end from preds, and I think the problem is the scaling and un-scaling of the preds values. If I just do np.uint8(preds), the result is almost entirely zeros, since preds only has small values. The reconstructed image should look as close as possible to the E2 image.
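To show what I mean, here is a minimal illustration (the values only mimic the range of preds reported above):
import numpy as np

preds_like = np.array([-1.77, 0.3, 0.9, 2.35], dtype=np.float32)
print(np.uint8(np.clip(preds_like, 0, 255)))      # [0 0 0 2] -> almost everything collapses to 0, 1 or 2
print(np.uint8(np.clip(preds_like, 0, 1) * 255))  # [0 76 229 255] -> mapped back to the full 0-255 range
My full code is below: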
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
Input, Add
from tensorflow.keras.models import Model
from PIL import Image
CHANNELS = 1
HEIGHT = 32
WIDTH = 32
INIT_SIZE = ((1429, 1416))
def NormalizeData(data):
return (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-6)
def extract_image_tiles(size, im):
im = im[:, :, :CHANNELS]
w = h = size
idxs = [(i, (i + h), j, (j + w)) for i in range(0, im.shape[0], h) for j in range(0, im.shape[1], w)]
tiles_asarrays = []
count = 0
for k, (i_start, i_end, j_start, j_end) in enumerate(idxs):
tile = im[i_start:i_end, j_start:j_end, ...]
if tile.shape[:2] != (h, w):
tile_ = tile
tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
tile = np.zeros(tile_size, dtype=tile.dtype)
tile[:tile_.shape[0], :tile_.shape[1], ...] = tile_
count += 1
tiles_asarrays.append(tile)
return np.array(idxs), np.array(tiles_asarrays)
def build_model(height, width, channels):
inputs = Input((height, width, channels))
f1 = Conv2D(32, 3, padding='same')(inputs)
f1 = BatchNormalization()(f1)
f1 = Activation('relu')(f1)
f2 = Conv2D(16, 3, padding='same')(f1)
f2 = BatchNormalization()(f2)
f2 = Activation('relu')(f2)
f3 = Conv2D(16, 3, padding='same')(f2)
f3 = BatchNormalization()(f3)
f3 = Activation('relu')(f3)
addition = Add()([f2, f3])
f4 = Conv2D(32, 3, padding='same')(addition)
f5 = Conv2D(16, 3, padding='same')(f4)
f5 = BatchNormalization()(f5)
f5 = Activation('relu')(f5)
f6 = Conv2D(16, 3, padding='same')(f5)
f6 = BatchNormalization()(f6)
f6 = Activation('relu')(f6)
output = Conv2D(1, 1, padding='same')(f6)
model = Model(inputs, output)
return model
# Load data
img = cv2.imread('E1.tif', cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, (1408, 1408), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img, np.uint8)
#plt.imshow(img)
img3 = cv2.imread('E3.tif', cv2.IMREAD_UNCHANGED)
img3 = cv2.resize(img3, (1408, 1408), interpolation=cv2.INTER_AREA)
img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
img3 = np.array(img3, np.uint8)
# extract tiles from images
idxs, tiles = extract_image_tiles(WIDTH, img)
idxs2, tiles3 = extract_image_tiles(WIDTH, img3)
# split to train and test data
split_idx = int(tiles.shape[0] * 0.9)
train = tiles[:split_idx]
val = tiles[split_idx:]
y_train = tiles3[:split_idx]
y_val = tiles3[split_idx:]
# build model
model = build_model(HEIGHT, WIDTH, CHANNELS)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss = tf.keras.losses.Huber(),
metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
# scale data before training
train = train / 255.
val = val / 255.
y_train = y_train / 255.
y_val = y_val / 255.
# train
history = model.fit(train,
y_train,
validation_data=(val, y_val),
epochs=50)
# predict on E2
img2 = cv2.imread('E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)
# extract tiles from images
idxs, tiles2 = extract_image_tiles(WIDTH, img2)
#scale data
tiles2 = tiles2 / 255.
preds = model.predict(tiles2)
#preds = NormalizeData(preds)
#preds = np.uint8(preds)
# reconstruct predictions
reconstructed = np.zeros((img.shape[0],
img.shape[1]),
dtype=np.uint8)
# reconstruction process
for tile, (y_start, y_end, x_start, x_end) in zip(preds[:, :, -1], idxs):
y_end = min(y_end, img.shape[0])
x_end = min(x_end, img.shape[1])
reconstructed[y_start:y_end, x_start:x_end] = tile[:(y_end - y_start), :(x_end - x_start)]
im = Image.fromarray(reconstructed)
im = im.resize(INIT_SIZE)
im.show()
You can find the data here.
If I use:
def normalize_arr_to_uint8(arr):
the_min = arr.min()
the_max = arr.max()
the_max -= the_min
arr = ((arr - the_min) / the_max) * 255.
return arr.astype(np.uint8)
preds = model.predict(tiles2)
preds = normalize_arr_to_uint8(preds)
then I receive an image which seems right, but with lines all over it.
Here is the image I get:
This is the image I should get (as close as possible to E2). Note that I only use a small CNN for this example, so I can't recover much detail in the image. But even when I try a better model, I still get horizontal and/or vertical lines:
UPDATE
I found this.
In the code above, in the reconstruction process
# reconstruction process
for tile, (y_start, y_end, x_start, x_end) in zip(preds[:, :, -1], idxs):
the slice preds[:, :, -1] is wrong. I must use preds[:, :, :, -1], because the shape of preds is (1936, 32, 32, 1). So if I use preds[:, :, -1], I get the image I posted above; if I use preds[:, :, :, -1], which is correct, I get a new image where, in addition to the horizontal lines, I also get vertical lines!
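A quick shape check makes the difference clear (the zeros array is just a stand-in with the same shape as my predictions):
import numpy as np

preds_demo = np.zeros((1936, 32, 32, 1), dtype=np.float32)
print(preds_demo[:, :, -1].shape)     # (1936, 32, 1)  -> iterating over this gives 32x1 slices, not tiles
print(preds_demo[:, :, :, -1].shape)  # (1936, 32, 32) -> iterating gives one 32x32 tile per prediction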
UPDATE 2
I am adding new code that uses different patch-extraction and reconstruction functions, but it produces the same result (a slightly better picture).
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
Input, Add
from tensorflow.keras.models import Model
from PIL import Image
# gpu setup
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
tf.config.experimental.set_memory_growth(gpu, True)
CHANNELS = 1
HEIGHT = 1408
WIDTH = 1408
PATCH_SIZE = 32
STRIDE = PATCH_SIZE//2
INIT_SIZE = ((1429, 1416))
def normalize_arr_to_uint8(arr):
the_min = arr.min()
the_max = arr.max()
the_max -= the_min
arr = ((arr - the_min) / the_max) * 255.
return arr.astype(np.uint8)
def NormalizeData(data):
return (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-6)
def recon_im(patches: np.ndarray, im_h: int, im_w: int, n_channels: int, stride: int):
"""Reconstruct the image from all patches.
Patches are assumed to be square and overlapping depending on the stride. The image is constructed
by filling in the patches from left to right, top to bottom, averaging the overlapping parts.
Parameters
-----------
patches: 4D ndarray with shape (patch_number,patch_height,patch_width,channels)
Array containing extracted patches. If the patches contain colour information,
channels are indexed along the last dimension: RGB patches would
have `n_channels=3`.
im_h: int
original height of image to be reconstructed
im_w: int
original width of image to be reconstructed
n_channels: int
number of channels the image has. For RGB image, n_channels = 3
stride: int
desired patch stride
Returns
-----------
reconstructedim: ndarray with shape (height, width, channels)
or ndarray with shape (height, width) if output image only has one channel
Reconstructed image from the given patches
"""
patch_size = patches.shape[1] # patches assumed to be square
# Assign output image shape based on patch sizes
rows = ((im_h - patch_size) // stride) * stride + patch_size
cols = ((im_w - patch_size) // stride) * stride + patch_size
if n_channels == 1:
reconim = np.zeros((rows, cols))
divim = np.zeros((rows, cols))
else:
reconim = np.zeros((rows, cols, n_channels))
divim = np.zeros((rows, cols, n_channels))
p_c = (cols - patch_size + stride) / stride # number of patches needed to fill out a row
totpatches = patches.shape[0]
initr, initc = 0, 0
# extract each patch and place in the zero matrix and sum it with existing pixel values
reconim[initr:patch_size, initc:patch_size] = patches[0]# fill out top left corner using first patch
divim[initr:patch_size, initc:patch_size] = np.ones(patches[0].shape)
patch_num = 1
while patch_num <= totpatches - 1:
initc = initc + stride
reconim[initr:initr + patch_size, initc:patch_size + initc] += patches[patch_num]
divim[initr:initr + patch_size, initc:patch_size + initc] += np.ones(patches[patch_num].shape)
if np.remainder(patch_num + 1, p_c) == 0 and patch_num < totpatches - 1:
initr = initr + stride
initc = 0
reconim[initr:initr + patch_size, initc:patch_size] += patches[patch_num + 1]
divim[initr:initr + patch_size, initc:patch_size] += np.ones(patches[patch_num].shape)
patch_num += 1
patch_num += 1
# Average out pixel values
reconstructedim = reconim / divim
return reconstructedim
def get_patches(GT, stride, patch_size):
"""Extracts square patches from an image of any size.
Parameters
-----------
GT : ndarray
n-dimensional array containing the image from which patches are to be extracted
stride : int
desired patch stride
patch_size : int
patch size
Returns
-----------
patches: ndarray
array containing all patches
im_h: int
height of image to be reconstructed
im_w: int
width of image to be reconstructed
n_channels: int
number of channels the image has. For RGB image, n_channels = 3
"""
hr_patches = []
for i in range(0, GT.shape[0] - patch_size + 1, stride):
for j in range(0, GT.shape[1] - patch_size + 1, stride):
hr_patches.append(GT[i:i + patch_size, j:j + patch_size])
im_h, im_w = GT.shape[0], GT.shape[1]
if len(GT.shape) == 2:
n_channels = 1
else:
n_channels = GT.shape[2]
patches = np.asarray(hr_patches)
return patches, im_h, im_w, n_channels
def build_model(height, width, channels):
inputs = Input((height, width, channels))
f1 = Conv2D(32, 3, padding='same')(inputs)
f1 = BatchNormalization()(f1)
f1 = Activation('relu')(f1)
f2 = Conv2D(16, 3, padding='same')(f1)
f2 = BatchNormalization()(f2)
f2 = Activation('relu')(f2)
f3 = Conv2D(16, 3, padding='same')(f2)
f3 = BatchNormalization()(f3)
f3 = Activation('relu')(f3)
addition = Add()([f2, f3])
f4 = Conv2D(32, 3, padding='same')(addition)
f5 = Conv2D(16, 3, padding='same')(f4)
f5 = BatchNormalization()(f5)
f5 = Activation('relu')(f5)
f6 = Conv2D(16, 3, padding='same')(f5)
f6 = BatchNormalization()(f6)
f6 = Activation('relu')(f6)
output = Conv2D(1, 1, padding='same')(f6)
model = Model(inputs, output)
return model
# Load data
img = cv2.imread('E1.tif', cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img, np.uint8)
img3 = cv2.imread('E3.tif', cv2.IMREAD_UNCHANGED)
img3 = cv2.resize(img3, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
img3 = np.array(img3, np.uint8)
# extract tiles from images
tiles, H, W, C = get_patches(img[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)
tiles3, H, W, C = get_patches(img3[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)
# split to train and test data
split_idx = int(tiles.shape[0] * 0.9)
train = tiles[:split_idx]
val = tiles[split_idx:]
y_train = tiles3[:split_idx]
y_val = tiles3[split_idx:]
# build model
model = build_model(PATCH_SIZE, PATCH_SIZE, CHANNELS)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss = tf.keras.losses.Huber(),
metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
# scale data before training
train = train / 255.
val = val / 255.
y_train = y_train / 255.
y_val = y_val / 255.
# train
history = model.fit(train,
y_train,
validation_data=(val, y_val),
batch_size=16,
epochs=20)
# predict on E2
img2 = cv2.imread('E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (HEIGHT, WIDTH), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)
# extract tiles from images
tiles2, H, W, CHANNELS = get_patches(img2[:, :, :CHANNELS], stride=STRIDE, patch_size=PATCH_SIZE)
#scale data
tiles2 = tiles2 / 255.
preds = model.predict(tiles2)
preds = normalize_arr_to_uint8(preds)
reconstructed = recon_im(preds[:, :, :, -1], HEIGHT, WIDTH, CHANNELS, stride=STRIDE)
im = Image.fromarray(reconstructed)
im = im.resize(INIT_SIZE)
im.show()
and the image produced:
UPDATE 3
After @Lescurel's comment, I tried this architecture:
from tensorflow.keras.layers import Conv2DTranspose  # Conv2DTranspose also needs to be imported for this version

def build_model(height, width, channels):
inputs = Input((height, width, channels))
f1 = Conv2D(32, 3, padding='valid')(inputs)
f1 = BatchNormalization()(f1)
f1 = Activation('relu')(f1)
f2 = Conv2D(16, 3, strides=2,padding='valid')(f1)
f2 = BatchNormalization()(f2)
f2 = Activation('relu')(f2)
f3 = Conv2D(16, 3, padding='same')(f2)
f3 = BatchNormalization()(f3)
f3 = Activation('relu')(f3)
addition = Add()([f2, f3])
f4 = Conv2D(32, 3, padding='valid')(addition)
f5 = Conv2D(16, 3, padding='valid')(f4)
f5 = BatchNormalization()(f5)
f5 = Activation('relu')(f5)
f6 = Conv2D(16, 3, padding='valid')(f5)
f6 = BatchNormalization()(f6)
f6 = Activation('relu')(f6)
f7 = Conv2DTranspose(16, 3, strides=2,padding='same')(f6)
f7 = BatchNormalization()(f7)
f7 = Activation('relu')(f7)
f8 = Conv2DTranspose(16, 3, strides=2,padding='same')(f7)
f8 = BatchNormalization()(f8)
f8 = Activation('relu')(f8)
output = Conv2D(1,1, padding='same', activation='sigmoid')(f8)
model = Model(inputs, output)
return model
which uses both valid and same padding, and this is the image I receive:
So the square tiles changed dimensions and shape. The problem, then, is how I can use my original architecture and still get rid of these tiles!
Now that I understand this:
My problem is not the small black box; that is just a bad pixel. My problem is the square tile border lines.
This is a stitching problem in my opinion. The border lines come from your approach of cropping the full image into tiles and then stitching them back together later. Given the padding of the convolutions in your model architecture, border artifacts are to be expected.
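To see where those border artifacts come from, here is a small, self-contained sketch (not taken from your code) that runs a 'same'-padded 3x3 mean filter over a constant tile:
import numpy as np
import tensorflow as tf

x = tf.ones((1, 8, 8, 1))  # a constant 8x8 tile
mean_filter = tf.keras.layers.Conv2D(
    1, 3, padding='same', use_bias=False,
    kernel_initializer=tf.keras.initializers.Constant(1.0 / 9.0))  # 3x3 averaging kernel
y = mean_filter(x).numpy()[0, :, :, 0]
print(np.round(y, 2))
# the interior stays at 1.0, but the zero padding pulls the edges down to about 0.67 and the
# corners to about 0.44 -- every tile gets a darker rim, which shows up as grid lines after stitching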
There are two things you can do now:
- Train and process on larger tiles and then crop the center of each tile. This does away with the padding issue by "cutting away the problematic parts":
def extract_image_tiles_with_overlap(size, stride, im, center_tile_width, overlap_percentage):
im = im[:, :, :CHANNELS]
w = h = size
s = stride
overlap = int(center_tile_width * overlap_percentage / 100)
idxs = [(i, (i + h), j, (j + w)) for i in range(0, im.shape[0] - h + 1, s) for j in
range(0, im.shape[1] - w + 1, s)]
tiles_asarrays = []
for k, (i_start, i_end, j_start, j_end) in enumerate(idxs):
tile = im[i_start:i_end, j_start:j_end, ...]
if tile.shape[:2] != (h, w):
tile_ = tile
tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
tile = np.zeros(tile_size, dtype=tile.dtype)
tile[:tile_.shape[0], :tile_.shape[1], ...] = tile_
tiles_asarrays.append(tile)
return np.array(idxs), np.array(tiles_asarrays), overlap
Of course this does not solve the stitching problem entirely. As a next step you can experiment with aggregating (averaging) the overlapping segments of the tiles to make this smoother; a minimal averaging sketch is shown just before the full code below.
- This smoothing through aggregation can also be done as a post-processing step, e.g. via Gaussian filtering or morphological operations:
from scipy.ndimage import binary_dilation, binary_erosion

iterations = 10  # Adjust the number of iterations as needed
reconstructed_dilated = binary_dilation(reconstructed, iterations=iterations)
reconstructed_smoothed = binary_erosion(reconstructed_dilated, iterations=iterations)
Or you can use some OpenCV denoising (make sure to further tune the hyperparameters):
denoised_reconstructed = cv2.fastNlMeansDenoising(reconstructed, h=10, templateWindowSize=7, searchWindowSize=21)
Result now is this:
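Before the full code, here is the minimal averaging sketch for the first option (my own helper, not part of the code below; it assumes preds is still in the 0-1 range and idxs holds rows of (y_start, y_end, x_start, x_end)):
import numpy as np

def stitch_average(tiles, idxs, out_h, out_w):
    # sum every prediction into an accumulator, count how often each pixel is covered,
    # then divide so overlapping tile regions are averaged instead of overwritten
    acc = np.zeros((out_h, out_w), dtype=np.float64)
    cnt = np.zeros((out_h, out_w), dtype=np.float64)
    for tile, (y0, y1, x0, x1) in zip(tiles, idxs):
        y1, x1 = min(y1, out_h), min(x1, out_w)
        acc[y0:y1, x0:x1] += tile[:y1 - y0, :x1 - x0]
        cnt[y0:y1, x0:x1] += 1.0
    return acc / np.maximum(cnt, 1.0)

# usage sketch:
# stitched = stitch_average(preds[:, :, :, -1], idxs, img2.shape[0], img2.shape[1])
# reconstructed = np.uint8(np.clip(stitched, 0, 1) * 255)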
Full code to reproduce (I changed a few other things, e.g. model saving):
import cv2
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, \
Input, Add
from tensorflow.keras.models import Model
from PIL import Image
import os
from scipy.ndimage import gaussian_filter
from scipy.ndimage import binary_dilation, binary_erosion
CHANNELS = 1
HEIGHT = 32
WIDTH = 32
INIT_SIZE = ((1429, 1416))
MODEL_SAVE_PATH = 'my_croptile_model.h5' # Change this path to your desired location
def NormalizeData(data):
return (data - np.min(data)) / (np.max(data) - np.min(data) + 1e-6)
def extract_image_tiles_with_overlap(size, stride, im, center_tile_width, overlap_percentage):
im = im[:, :, :CHANNELS]
w = h = size
s = stride
overlap = int(center_tile_width * overlap_percentage / 100)
idxs = [(i, (i + h), j, (j + w)) for i in range(0, im.shape[0] - h + 1, s) for j in
range(0, im.shape[1] - w + 1, s)]
tiles_asarrays = []
for k, (i_start, i_end, j_start, j_end) in enumerate(idxs):
tile = im[i_start:i_end, j_start:j_end, ...]
if tile.shape[:2] != (h, w):
tile_ = tile
tile_size = (h, w) if tile.ndim == 2 else (h, w, tile.shape[2])
tile = np.zeros(tile_size, dtype=tile.dtype)
tile[:tile_.shape[0], :tile_.shape[1], ...] = tile_
tiles_asarrays.append(tile)
return np.array(idxs), np.array(tiles_asarrays), overlap
def build_model(height, width, channels):
inputs = Input((height, width, channels))
f1 = Conv2D(32, 3, padding='same')(inputs)
f1 = BatchNormalization()(f1)
f1 = Activation('relu')(f1)
f2 = Conv2D(16, 3, padding='same')(f1)
f2 = BatchNormalization()(f2)
f2 = Activation('relu')(f2)
f3 = Conv2D(16, 3, padding='same')(f2)
f3 = BatchNormalization()(f3)
f3 = Activation('relu')(f3)
addition = Add()([f2, f3])
f4 = Conv2D(32, 3, padding='same')(addition)
f5 = Conv2D(16, 3, padding='same')(f4)
f5 = BatchNormalization()(f5)
f5 = Activation('relu')(f5)
f6 = Conv2D(16, 3, padding='same')(f5)
f6 = BatchNormalization()(f6)
f6 = Activation('relu')(f6)
output = Conv2D(1, 1, padding='same')(f6)
model = Model(inputs, output)
return model
# Load data
img = cv2.imread('images/E1.tif', cv2.IMREAD_UNCHANGED)
img = cv2.resize(img, (1408, 1408), interpolation=cv2.INTER_AREA)
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = np.array(img, np.uint8)
# plt.imshow(img)
img3 = cv2.imread('images/E3.tif', cv2.IMREAD_UNCHANGED)
img3 = cv2.resize(img3, (1408, 1408), interpolation=cv2.INTER_AREA)
img3 = cv2.cvtColor(img3, cv2.COLOR_BGR2RGB)
img3 = np.array(img3, np.uint8)
# extract tiles from images
# idxs, tiles = extract_image_tiles(WIDTH, img)
# idxs2, tiles3 = extract_image_tiles(WIDTH, img3)
# extract tiles from images with overlap
center_tile_width = WIDTH # Adjust as needed
overlap_percentage = 60 # Adjust as needed
idxs, tiles, overlap = extract_image_tiles_with_overlap(WIDTH, WIDTH // 2, img, center_tile_width, overlap_percentage)
idxs3, tiles3, overlap3 = extract_image_tiles_with_overlap(WIDTH, WIDTH // 2, img3, center_tile_width, overlap_percentage)
# split to train and test data
split_idx = int(tiles.shape[0] * 0.9)
train = tiles[:split_idx]
val = tiles[split_idx:]
y_train = tiles3[:split_idx]
y_val = tiles3[split_idx:]
# Build or load model
if os.path.exists(MODEL_SAVE_PATH):
model = tf.keras.models.load_model(MODEL_SAVE_PATH)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=tf.keras.losses.Huber(),
metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
# scale data before training
train = train / 255.
val = val / 255.
y_train = y_train / 255.
y_val = y_val / 255.
# train
history = model.fit(train,
y_train,
validation_data=(val, y_val),
epochs=0) # Adjust epochs as needed
# Save the model
model.save(MODEL_SAVE_PATH)
else:
model = build_model(HEIGHT, WIDTH, CHANNELS)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
loss=tf.keras.losses.Huber(),
metrics=[tf.keras.metrics.RootMeanSquaredError(name='rmse')])
# scale data before training
train = train / 255.
val = val / 255.
y_train = y_train / 255.
y_val = y_val / 255.
# train
history = model.fit(train,
y_train,
validation_data=(val, y_val),
epochs=50) # Adjust epochs as needed
# Save the model
model.save(MODEL_SAVE_PATH)
# predict on E2
img2 = cv2.imread('images/E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)
# extract tiles from images
idxs, tiles2, overlap = extract_image_tiles_with_overlap(WIDTH, WIDTH // 2, img2, center_tile_width, overlap_percentage)
# scale data
tiles2 = tiles2 / 255.
preds = model.predict(tiles2)
# Check model output range
print("Max prediction value:", np.max(preds))
print("Min prediction value:", np.min(preds))
# Invert colors in predictions
inverted_preds = 1.0 - preds
# Ensure values are within valid range
inverted_preds = np.clip(inverted_preds, 0, 1)
# Reconstruct inverted predictions with cropping center part of the tile
reconstructed = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
# Reconstruction process
for (y_start, y_end, x_start, x_end), tile in zip(idxs, inverted_preds[:, :, :, -1]):
y_end = min(y_end, img.shape[0])
x_end = min(x_end, img.shape[1])
# Calculate the crop size based on the size of the original tile without overlap
crop_size = WIDTH
# Rescale the tile to the original size (WIDTH + overlap)
rescaled_tile = cv2.resize(tile, (WIDTH + overlap, WIDTH + overlap))
# Crop the center part of the rescaled tile
center_crop = rescaled_tile[overlap//2:overlap//2+crop_size, overlap//2:overlap//2+crop_size]
# Update the reconstructed image directly with the cropped tile
reconstructed[y_start:y_end, x_start:x_end] = (center_crop * 255).astype(np.uint8)
# Apply binary dilation and erosion to enhance the tile boundaries
# Apply non-local means denoising
denoised_reconstructed = cv2.fastNlMeansDenoising(reconstructed, h=10, templateWindowSize=7, searchWindowSize=21)
im = Image.fromarray(denoised_reconstructed)
im = im.resize(INIT_SIZE)
im.show()
The updated answer ends here.
So I think this is mainly about post-processing and visualization of your data.
I visualize like this now:
# predict on E2
img2 = cv2.imread('images/E2.tif', cv2.IMREAD_UNCHANGED)
img2 = cv2.resize(img2, (1408, 1408), interpolation=cv2.INTER_AREA)
img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
img2 = np.array(img2, np.uint8)
# extract tiles from images
idxs, tiles2 = extract_image_tiles(WIDTH, img2)
# scale data
tiles2 = tiles2 / 255.
preds = model.predict(tiles2)
# Check model output range
print("Max prediction value:", np.max(preds))
print("Min prediction value:", np.min(preds))
# Invert colors in predictions
inverted_preds = 1.0 - preds
# Ensure values are within valid range
inverted_preds = np.clip(inverted_preds, 0, 1)
# Reconstruct inverted predictions
reconstructed = np.zeros((img.shape[0], img.shape[1]), dtype=np.uint8)
# Reconstruction process
for tile, (y_start, y_end, x_start, x_end) in zip(inverted_preds[:, :, :, -1], idxs):
y_end = min(y_end, img.shape[0])
x_end = min(x_end, img.shape[1])
reconstructed[y_start:y_end, x_start:x_end] = (tile * 255).astype(np.uint8)
im = Image.fromarray(reconstructed)
im = im.resize(INIT_SIZE)
im.show()
I can already clearly see the black square. So next I will try some thresholding to enhance that visibility in post-processing.
I am thinking of something like this:
# Threshold value as a fraction of the 0-255 range (adjust as needed)
threshold = 0.9  # 0.45
# Thresholding (reconstructed is uint8, so compare against threshold * 255)
binary_output = (reconstructed >= threshold * 255).astype(np.uint8) * 255
# Second visualization
im_binary = Image.fromarray(binary_output)
im_binary = im_binary.resize(INIT_SIZE)
im_binary.show()
Which leaves me with this:
Not sure how good this scales across your full dataset, but this is definitely in the ball-park for some morphological operators in post-processing.
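For example, a rough post-processing sketch along those lines (the kernel size and iteration counts are guesses to be tuned; it operates on the binary_output from above):
import cv2
from PIL import Image

kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (5, 5))
# closing fills small holes inside the detected square, opening then removes isolated speckles
closed = cv2.morphologyEx(binary_output, cv2.MORPH_CLOSE, kernel, iterations=2)
cleaned = cv2.morphologyEx(closed, cv2.MORPH_OPEN, kernel, iterations=1)
im_clean = Image.fromarray(cleaned).resize(INIT_SIZE)
im_clean.show()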