Search code examples
Tags: python, arrays, numpy, image-processing, vectorization

Vectorizing a custom Python function over a NumPy array


I'm not sure "vectorizing" is the correct term for this. I take a black-and-white image and first transform it so that every white pixel that borders a black pixel stays white and every other pixel turns black. That part of the program works and is done in `find_edges`. Next, I need to calculate the distance from each pixel in the image to the closest white pixel. Right now I do this with a Python for-loop that calls `find_nearest_edge` once per pixel, which is extremely slow. Is there a way to compute this with NumPy alone, without a for-loop calling the function on each element? Thanks.

####

from PIL import Image
import numpy as np
from scipy.ndimage import binary_erosion

####

def find_nearest_edge(arr, point):
    """Return the distance from *point* to the nearest non-zero pixel of *arr*.

    Parameters
    ----------
    arr : 2-D array
        Edge image indexed as ``arr[row, col]``. Any non-zero value counts
        as a white (edge) pixel.
    point : tuple
        Query position ``(x, y)``, where ``x`` is the column and ``y`` is
        the row.

    Returns
    -------
    float
        The smallest strictly positive Euclidean distance from *point* to a
        white pixel. A white pixel located at *point* itself (distance 0)
        is ignored, so for a white query point this is the distance to the
        nearest *other* white pixel. Returns ``inf`` when there is no such
        pixel.
    """
    rows, cols = arr.shape
    x, y = point
    # With the default 'xy' indexing, meshgrid(cols_range, rows_range)
    # yields (rows, cols)-shaped grids aligned with ``arr``. Passing the
    # sizes the other way round transposes the grids and breaks
    # non-square images.
    xcoords, ycoords = np.meshgrid(np.arange(cols), np.arange(rows))

    target = np.sqrt((xcoords - x)**2 + (ycoords - y)**2)
    # Black pixels can never be the nearest edge.
    target[arr == 0] = np.inf

    # Distance 0 occurs only at the query pixel itself; exclude it.
    candidates = target[target > 0.0]
    if candidates.size == 0:
        # Empty image, or a 1x1 image whose only pixel is the query point.
        return np.inf
    return np.min(candidates)

def find_edges(img):
    """Keep only the boundary pixels of the white regions of *img*.

    A pixel is kept when it is non-zero ("white") and at least one pixel of
    its 3x3 neighbourhood is zero or lies outside the image. Every other
    pixel becomes 0. Kept pixels retain their original grey value.

    Parameters
    ----------
    img : PIL.Image.Image or array-like
        Objects providing ``convert`` (PIL images) are first converted to
        8-bit greyscale (mode 'L'). Anything else is used as a 2-D array
        directly.

    Returns
    -------
    numpy.ndarray
        Non-negative array with the same shape as the image.
    """
    if hasattr(img, 'convert'):
        img = img.convert('L')
    img_np = np.array(img)

    kernel = np.ones((3, 3))
    white = img_np != 0
    # boundary = white AND NOT eroded(white). Masking (rather than
    # subtracting ``erosion * 255``) keeps interior pixels whose grey value
    # is not exactly 255 from becoming negative.
    boundary = white & ~binary_erosion(white, kernel)

    return np.where(boundary, img_np, 0)

a = Image.open('a.png')
x, y = a.size  # Pillow reports size as (width, height)

edges = find_edges(a)

# Write the edge map out so it can be inspected.
out = Image.fromarray(edges.astype('uint8'), 'L')
out.save('b.png')

# Outer loop over x, inner loop over y: the same ordering as two nested
# for-loops, producing one flat list of per-pixel distances.
dists = [
    find_nearest_edge(edges, (col, row))
    for col in range(x)
    for row in range(y)
]

print(dists)

Images:

(image not available)

(image not available)


Solution

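  • You can use `scipy.spatial.KDTree` to compute the distances quickly: build the tree once from the coordinates of the white pixels, then query every pixel position in a single vectorized call instead of looping in Python.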

    import numpy as np
    import matplotlib.pyplot as plt
    
    from scipy.ndimage import binary_erosion
    from scipy.spatial import KDTree
    
    
    def find_edges(img):
        """Keep only the boundary pixels of the non-zero regions of *img*.

        A pixel is kept when it is non-zero ("white") and at least one
        pixel of its 3x3 neighbourhood is zero or lies outside the image.
        Every other pixel becomes 0. Kept pixels retain their original
        value.

        Parameters
        ----------
        img : array-like
            2-D image; any non-zero value counts as white.

        Returns
        -------
        numpy.ndarray
            Non-negative array with the same shape as *img*.
        """
        img_np = np.array(img)

        kernel = np.ones((3, 3))
        white = img_np != 0
        # boundary = white AND NOT eroded(white). Masking (rather than
        # subtracting ``erosion * 255``) keeps interior pixels whose value
        # is not exactly 255 from becoming negative.
        boundary = white & ~binary_erosion(white, kernel)

        return np.where(boundary, img_np, 0)
    
    
    def find_closest_distance(img):
        """Return, for every element of *img*, the distance to the nearest white element.

        NOTE: assumes a binary image in which white is any non-zero value.
        Works for arrays of any dimensionality (2-D images, 3-D volumes, ...).

        Parameters
        ----------
        img : numpy.ndarray
            Binary image or volume.

        Returns
        -------
        numpy.ndarray
            Float array with the same shape as *img* holding Euclidean
            distances in pixel units; white elements get 0.
        """
        # One row per white element: its index tuple (row, col, ...).
        white_pixel_points = np.argwhere(img)
        tree = KDTree(white_pixel_points)
        # Array of shape (*img.shape, img.ndim) whose entry at each
        # position is that position's own index, i.e. the query point.
        grid_points = np.moveaxis(np.indices(img.shape), 0, -1)
        distances, _ = tree.query(grid_points)
        return distances
    
    test_image = np.zeros((200, 200))
    rectangle = np.ones((30, 80))
    # Four bars forming a hollow frame: top and bottom bars are horizontal,
    # left and right bars use the transposed (vertical) rectangle.
    for rows, cols, patch in (
        (slice(20, 50), slice(60, 140), rectangle),
        (slice(150, 180), slice(60, 140), rectangle),
        (slice(60, 140), slice(20, 50), rectangle.T),
        (slice(60, 140), slice(150, 180), rectangle.T),
    ):
        test_image[rows, cols] = patch
    test_image = test_image * 255
    edge_image = find_edges(test_image)
    distance_image = find_closest_distance(edge_image)

    # Show input, edge map and distance map side by side.
    fig, axes = plt.subplots(1, 3, figsize=(12, 5))
    for ax, panel in zip(axes, (test_image, edge_image, distance_image)):
        ax.imshow(panel, cmap='Greys_r')
    plt.show()
    

    (image not available: the figure produced by the code above)