Tags: python, image-processing, pytorch, cluster-analysis, knn

How to cluster PyTorch predictions


I'm trying to detect road lanes in road images. So far, I've trained a model that finds the lanes, but many of its predictions are scattered. The dots in the images below are the model's predictions of where the road lanes might be, and I'm trying to cluster these PyTorch predictions.

Predictions shape: [1, 1, 80, 120] (batch, channel, height, width)
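Before any post-processing with NumPy or OpenCV, a prediction of this shape has to become a 2-D 8-bit mask. A minimal sketch, assuming the values are per-pixel probabilities (for example after a sigmoid) and that 0.5 is an acceptable cutoff; the function name `predictions_to_mask` and the threshold are illustrative assumptions, not part of the question.

```python
import numpy as np


def predictions_to_mask(pred, threshold=0.5):
    """Turn a [1, 1, H, W] prediction into an H x W uint8 mask (0 or 255).

    Assumption: `pred` holds per-pixel probabilities (e.g. after a sigmoid);
    use threshold=0.0 instead if the model outputs raw logits.
    """
    # A torch.Tensor may live on the GPU and track gradients, so detach it and
    # copy it to host memory first. NumPy arrays pass straight through.
    if hasattr(pred, "detach"):
        pred = pred.detach().cpu().numpy()
    arr = np.asarray(pred, dtype=np.float32)
    if arr.ndim != 4 or arr.shape[:2] != (1, 1):
        raise ValueError(f"expected shape [1, 1, H, W], got {list(arr.shape)}")
    # Drop the batch and channel dimensions: [1, 1, H, W] -> [H, W]
    probs = arr[0, 0]
    return np.where(probs > threshold, 255, 0).astype(np.uint8)


# Demo with a fake prediction that has a single confident pixel
fake_pred = np.zeros((1, 1, 80, 120), dtype=np.float32)
fake_pred[0, 0, 5, 7] = 0.9
mask = predictions_to_mask(fake_pred)
print(mask.shape, mask.dtype, mask[5, 7])  # (80, 120) uint8 255
```

With a real model, something like `predictions_to_mask(torch.sigmoid(output))` would apply, where `output` is a hypothetical name for the model's raw [1, 1, 80, 120] output.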

Here's the image of predictions:

[Image: raw lane predictions with many scattered dots]

Here's what I want to achieve (I edited the image and deleted the scattered dots):

[Image: desired result, manually edited to remove the scattered dots]

As you can see, I deleted the scattered dots (predictions) from the image. I want nearby dots to be clustered together so that the scattered ones can be separated out. How can I achieve this? I tried KNN (k-nearest neighbors), but it didn't work.
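KNN (k-nearest neighbors) is a supervised classifier: it labels a point from already-labelled neighbors, so on its own it does not discover groups in unlabelled dots. A clustering step that groups nearby foreground pixels fits this goal better. Below is a minimal sketch, not taken from the question or the answer: a plain single-linkage grouping with a distance cutoff, similar in spirit to DBSCAN. The function names and the `radius` and `min_size` defaults are illustrative assumptions that would need tuning on real masks.

```python
from collections import deque

import numpy as np


def cluster_mask_points(mask, radius=2.0, min_size=5):
    """Group the foreground pixels of a 2-D mask into clusters.

    Two pixels share a cluster if a chain of pixels, each at most `radius`
    apart, connects them (single-linkage grouping). Clusters with fewer than
    `min_size` pixels are treated as scattered noise and labelled -1.
    """
    points = np.argwhere(mask > 0)                 # (N, 2) array of (row, col)
    labels = np.full(len(points), -2, dtype=int)   # -2 means "not visited yet"
    next_label = 0
    for start in range(len(points)):
        if labels[start] != -2:
            continue
        # Breadth-first search over every point reachable within `radius`
        members = [start]
        labels[start] = next_label
        queue = deque([start])
        while queue:
            i = queue.popleft()
            dist = np.linalg.norm(points - points[i], axis=1)
            for j in np.flatnonzero((dist <= radius) & (labels == -2)):
                labels[j] = next_label
                members.append(j)
                queue.append(j)
        if len(members) < min_size:
            labels[members] = -1                   # too small: mark as noise
        else:
            next_label += 1
    return points, labels


def remove_noise(mask, radius=2.0, min_size=5):
    """Return a uint8 mask that keeps only pixels belonging to real clusters."""
    points, labels = cluster_mask_points(mask, radius, min_size)
    cleaned = np.zeros_like(mask)
    keep = points[labels >= 0]
    cleaned[keep[:, 0], keep[:, 1]] = 255
    return cleaned


# Demo: one vertical "lane" of 30 pixels plus one isolated dot
mask = np.zeros((80, 120), dtype=np.uint8)
mask[10:40, 30] = 255
mask[60, 100] = 255
points, labels = cluster_mask_points(mask)
print(sorted(set(labels.tolist())))  # [-1, 0]
```

Each non-negative label is one cluster (ideally one lane), and -1 marks the scattered dots. The distance computation is quadratic in the number of foreground pixels, which is fine for an 80 x 120 mask; for larger inputs, scikit-learn's `sklearn.cluster.DBSCAN` (parameters `eps` and `min_samples`) is a ready-made alternative.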


Solution

  • If you only want to remove the scattered dots, you can post-process your mask with a morphological operation such as opening (erosion followed by dilation).

    The resulting mask without dots:

    [Image: post-processed mask without the scattered dots]

    Code:

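    import cv2
    import numpy as np
    
    # Load the predicted mask as a single-channel 8-bit image
    mask = cv2.imread('road_mask.jpg', cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError("could not read road_mask.jpg")
    # Match the model's output size; cv2.resize takes (width, height)
    mask = cv2.resize(mask, (120, 80))
    
    # Erosion deletes specks smaller than the kernel; dilation regrows what survived
    mask = cv2.erode(mask, np.ones((2, 2), np.uint8))
    mask = cv2.dilate(mask, np.ones((3, 3), np.uint8))
    # Binarize: pixels brighter than 10 become 255, everything else 0
    mask = ((mask > 10) * 255).astype(np.uint8)
    
    cv2.imwrite("postprocessed_mask.png", mask)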
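To see why opening removes isolated dots while keeping wider lane regions, here is a minimal NumPy-only sketch of binary opening; it is an illustration, not the OpenCV code above. It uses the same odd k x k square for both erosion and dilation, which is the textbook definition of opening, whereas the code above erodes with a 2 x 2 kernel and dilates with a 3 x 3 one. It assumes NumPy 1.20 or newer (for `sliding_window_view`); the helper names are illustrative.

```python
import numpy as np


def _window_reduce(mask, k, reducer, pad_value):
    """Apply `reducer` (np.min or np.max) over the k x k window centred on each pixel."""
    pad = k // 2
    padded = np.pad(mask, pad, mode="constant", constant_values=pad_value)
    windows = np.lib.stride_tricks.sliding_window_view(padded, (k, k))
    return reducer(windows, axis=(-2, -1))


def binary_opening(mask, k=3):
    """Erode then dilate a binary mask with a k x k square (k must be odd here)."""
    if k % 2 == 0:
        raise ValueError("k must be odd in this sketch")
    fg = np.asarray(mask) > 0
    # Erosion: a pixel survives only if its whole window is foreground.
    # Padding with True keeps the image border from eroding.
    eroded = _window_reduce(fg, k, np.min, True)
    # Dilation: a pixel turns on if any pixel in its window survived erosion.
    opened = _window_reduce(eroded, k, np.max, False)
    return (opened * 255).astype(np.uint8)


# Demo: a 5-pixel-wide "lane" survives, a single dot does not
demo = np.zeros((80, 120), dtype=np.uint8)
demo[10:50, 30:35] = 255
demo[60, 100] = 255
opened = binary_opening(demo, k=3)
print(opened[60, 100], int(opened.sum() // 255))  # 0 200
```

Any foreground blob too small to contain a full k x k square disappears during erosion and has nothing left to regrow, while regions at least k pixels wide are restored by the dilation. The trade-off is that thin lane segments narrower than the kernel are removed as well, so the kernel size has to stay below the typical lane width in the mask.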