Search code examples
Tags: python, tensorflow, opencv, conv-neural-network, density-plot

Density plot to predict object count


I'm trying to count the number of objects (larvae in this case) in a video.

With this as the overall goal, I first applied a density map to all the frames of the video. Looking at the density plots more closely, I noticed that the algorithm was counting objects that are not even larvae, so I decided to narrow the problem down to a single frame for now. On that frame, I applied k-means clustering to the colours present in the image. After some manipulation, I identified the colour of the larvae and masked the image to keep only that colour. After these modifications, I applied the density map to the masked image for the colour of interest.

The count now comes out on the order of 1e3, which is wrong: the actual count is about 50. Looking closely at the density maps, the model is predicting (counting) the density of point-like objects, which is why the count is so far off.

So the question is: how do I change the density map so that it does not count point objects and instead gives a count close to the actual one?

The code for the density map is:

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.losses import MeanSquaredError
import cv2

# Create the density map estimation model
def create_density_map_model(input_shape):
    """Build a small fully-convolutional network that regresses a density map.

    Two 3x3 ReLU convolutions (64 filters each, 'same' padding, stride 1) are
    followed by a 1x1 linear convolution with a single output channel. The
    output therefore has the same height and width as the input.

    Args:
        input_shape: Shape of one sample without the batch axis, as accepted
            by ``tensorflow.keras.layers.Input`` (e.g. ``(H, W, 1)``).

    Returns:
        An uncompiled, untrained ``tensorflow.keras.models.Model``.
    """
    image_tensor = Input(shape=input_shape)

    features = image_tensor
    for _ in range(2):
        features = Conv2D(64, (3, 3), activation='relu', padding='same')(features)

    # 1x1 projection down to one channel; linear, so values are unconstrained.
    density = Conv2D(1, (1, 1), activation='linear', padding='same')(features)

    return Model(inputs=image_tensor, outputs=density)

# Load the image. IMREAD_UNCHANGED keeps the alpha channel produced by the
# colour-masking step.
image_path = 'color_17.png'
image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)

# cv2.imread does not raise on failure; it returns None.
if image is None:
    raise FileNotFoundError(f"Could not read image: {image_path}")
# The masking logic below needs a 4-channel (BGRA) image.
if image.ndim != 3 or image.shape[2] != 4:
    raise ValueError(
        f"Expected a 4-channel BGRA image with an alpha mask, got shape {image.shape}"
    )

# Extract the alpha channel from the image
alpha_channel = image[:, :, 3]

# Binary mask: non-transparent pixels -> 255, fully transparent pixels -> 0.
_, binary_mask = cv2.threshold(alpha_channel, 0, 255, cv2.THRESH_BINARY)

# Get the height and width of the image
image_height, image_width = image.shape[:2]

# Create and compile the model.
# NOTE(review): the model is never trained (there is no model.fit call), so its
# weights are random and the predicted density/count carries no information
# about the larvae. This is the root cause of the ~1e3 count. Train it on frames
# paired with ground-truth density maps before trusting the number below.
input_shape = (image_height, image_width, 1)
model = create_density_map_model(input_shape)
model.compile(optimizer='adam', loss=MeanSquaredError())

# Preprocess: BGRA -> grayscale scaled to [0, 1], shaped (1, H, W, 1).
image_gray = cv2.cvtColor(image, cv2.COLOR_BGRA2GRAY)
image_normalized = image_gray / 255.0
image_input = np.expand_dims(image_normalized, axis=-1)
image_input = np.expand_dims(image_input, axis=0)

# Predict the density map and drop the batch axis -> (H, W, 1).
predicted_density_map = model.predict(image_input)
predicted_density_map = np.reshape(predicted_density_map, (image_height, image_width, 1))

# The final layer is linear, so raw outputs can be negative. A density cannot
# be negative, and negative values would cancel positive ones in the sum.
predicted_density_map = np.clip(predicted_density_map, 0.0, None)

# Per-pixel size factor. It is all ones, so this is currently a placeholder
# (a no-op); replace it with a real weighting if one is needed.
size_factor_map = np.ones_like(predicted_density_map)
predicted_density_map_filtered = predicted_density_map * size_factor_map

# Keep density only where the image is not transparent (mask scaled to 0/1).
binary_mask_reshaped = np.expand_dims(binary_mask, axis=-1)
predicted_density_map_filtered = predicted_density_map_filtered * (binary_mask_reshaped / 255.0)

# Save a visualization. Normalize by the peak so the full 0-255 range is used:
# per-pixel densities are usually far below 1, and astype(np.uint8) does not
# saturate, so unscaled or out-of-range values would wrap around.
peak_density = float(predicted_density_map_filtered.max())
if peak_density > 0:
    density_visual = (predicted_density_map_filtered / peak_density * 255.0).astype(np.uint8)
else:
    density_visual = np.zeros(predicted_density_map_filtered.shape, dtype=np.uint8)
output_path = 'density_plot_filtered.jpg'
if not cv2.imwrite(output_path, density_visual):
    raise OSError(f"Could not write {output_path}")

# The count is the integral (sum) of the density map; round, don't truncate.
num_larvae_predicted = int(round(float(np.sum(predicted_density_map_filtered))))

# Print the number of larvae predicted
print(f"Number of larvae predicted: {num_larvae_predicted}")

The video I used for the first step (the k-means clustering that identifies the 20 most frequent colours) can be found here: https://iitk-my.sharepoint.com/:v:/g/personal/abint21_iitk_ac_in/EexjUqun0pZFmzRTxTBHGFoB8ML2hX5iZ6luH9QVWpLjKA?e=mLQLfb

The first frame, without any filtering, is shown here: [image not preserved]

In addition, the colour of interest, viewed with an RGBA palette, looks like this: [image not preserved]

The density plot for this image looks something like this: [image not preserved]

I hope this makes the details clear. Please let me know if more details are needed.


Solution

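  • The density-map model in the question is never trained (there is no `model.fit` call), so its output is effectively random and summing it cannot give a meaningful count. A classical approach needs no training data: segment the larvae, find their contours, and separate real larvae from small specks using shape features. Try this code: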

    import cv2
    import numpy as np
    from sklearn.cluster import KMeans
    import matplotlib.pyplot as plt
    
    # Load the image as single-channel grayscale.
    image_path = 'larvae_image.jpg'
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Apply Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(image, (5, 5), 0)

    # Adaptive threshold (inverted): pixels darker than their local Gaussian-
    # weighted mean minus 2 become foreground (255).
    thresh = cv2.adaptiveThreshold(blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, cv2.THRESH_BINARY_INV, 11, 2)

    # Find the outer contours of the segmented blobs
    contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    # One feature row per contour: [area, perimeter]
    features = np.array(
        [[cv2.contourArea(c), cv2.arcLength(c, True)] for c in contours],
        dtype=np.float64,
    )
    if len(features) < 2:
        raise ValueError(f"Need at least 2 contours to form 2 clusters, found {len(features)}")

    # Standardize each feature so area (px^2) does not dominate perimeter (px)
    # in k-means' Euclidean distance. Constant columns are left unscaled.
    feature_std = features.std(axis=0)
    feature_std[feature_std == 0] = 1.0
    scaled_features = (features - features.mean(axis=0)) / feature_std

    # Split the contours into two groups
    kmeans = KMeans(n_clusters=2, n_init=10, random_state=42).fit(scaled_features)
    labels = kmeans.labels_

    # K-means cluster ids are arbitrary. Assume the larvae are the cluster with
    # the larger mean contour area and the other cluster is small specks/noise
    # (verify on your data).
    mean_area = [
        features[labels == k, 0].mean() if np.any(labels == k) else -np.inf
        for k in range(2)
    ]
    larvae_cluster = int(np.argmax(mean_area))
    larvae_count = int(np.count_nonzero(labels == larvae_cluster))

    # Visualize the results
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.imshow(image, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 2, 2)
    # Draw on an RGB copy. On the single-channel image OpenCV uses only the
    # first colour component, so both (0,255,0) and (0,0,255) would draw as 0.
    annotated = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    for contour, label in zip(contours, labels):
        # RGB order (matplotlib): green for larvae, red for non-larvae
        color = (0, 255, 0) if label == larvae_cluster else (255, 0, 0)
        cv2.drawContours(annotated, [contour], 0, color, 2)
    plt.imshow(annotated)
    plt.title(f'Larvae Count: {larvae_count}')
    plt.axis('off')
    plt.tight_layout()
    plt.show()
    

    A slightly altered version that counts clusters using ridge (Frangi) filtering, skeletonization, and DBSCAN:

    import cv2
    import numpy as np
    from sklearn.cluster import DBSCAN
    from skimage.filters import frangi
    from skimage.morphology import skeletonize, remove_small_objects
    import matplotlib.pyplot as plt
    
    image_path = 'larvae_image.jpg'
    image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if image is None:  # cv2.imread returns None instead of raising
        raise FileNotFoundError(f"Could not read image: {image_path}")

    # Enhance elongated, tube-like structures over several scales. frangi's
    # black_ridges defaults to True (dark ridges); pass black_ridges=False if
    # the larvae are brighter than the background.
    frangi_filtered = frangi(image, sigmas=range(1, 10, 2))

    # Thin each ridge to a 1-pixel centreline and drop fragments < 64 pixels.
    skeleton = skeletonize(frangi_filtered > 0.5)
    skeleton = remove_small_objects(skeleton, min_size=64, connectivity=2)

    # cv2.connectedComponents returns (number_of_labels, label_image); label 0
    # is the background, so real components are 1 .. num_labels - 1.
    num_labels, component_map = cv2.connectedComponents(skeleton.astype(np.uint8))
    component_ids = list(range(1, num_labels))

    features = []
    for comp_id in component_ids:
        component = (component_map == comp_id).astype(np.uint8)
        # The components are skeleton pieces, so the pixel count is the
        # skeleton length (the original's separate "area" feature was identical).
        skeleton_length = float(np.sum(component))
        component_contours, _ = cv2.findContours(component, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        perimeter = cv2.arcLength(component_contours[0], True)
        features.append([skeleton_length, perimeter])

    if features:
        feature_array = np.asarray(features, dtype=np.float64)
        # Standardize so eps=0.5 is measured in standard deviations rather
        # than raw pixels (in pixel units almost everything became noise).
        feature_std = feature_array.std(axis=0)
        feature_std[feature_std == 0] = 1.0
        scaled_features = (feature_array - feature_array.mean(axis=0)) / feature_std
        cluster_labels = DBSCAN(eps=0.5, min_samples=5).fit(scaled_features).labels_
    else:
        # No components survived filtering; DBSCAN cannot fit an empty array.
        cluster_labels = np.empty(0, dtype=int)

    # DBSCAN marks outliers with -1; every other component counts as a larva.
    larvae_count = int(np.count_nonzero(cluster_labels != -1))

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 3, 1)
    plt.imshow(image, cmap='gray')
    plt.title('Original Image')
    plt.axis('off')
    plt.subplot(1, 3, 2)
    plt.imshow(skeleton, cmap='gray')
    plt.title('Skeletonized Image')
    plt.axis('off')
    plt.subplot(1, 3, 3)
    colored_image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    # One random colour per DBSCAN cluster (seeded for reproducible plots);
    # plain Python ints because cv2 may reject numpy integer colour values.
    rng = np.random.default_rng(0)
    cluster_colors = {}
    for comp_id, cluster in zip(component_ids, cluster_labels):
        if cluster == -1:
            continue
        if cluster not in cluster_colors:
            cluster_colors[cluster] = tuple(int(v) for v in rng.integers(0, 256, size=3))
        # Build the mask from the component's pixels in the label image (the
        # original used the 1-D DBSCAN label vector here, which is not an image).
        mask = (component_map == comp_id).astype(np.uint8) * 255
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        cv2.drawContours(colored_image, contours, -1, cluster_colors[cluster], 2)
    plt.imshow(colored_image)
    plt.title(f'Larvae Count: {larvae_count}')
    plt.axis('off')
    plt.tight_layout()
    plt.show()