python 3d nearest-neighbor

How to find at most 6 nearest neighbors in 3D space?


I need to find at most 6 nearest neighbors (depending on the node position) for each point in 3D space. If you imagine a 3x3x3 cube, each corner point has 3 neighbors, each edge point 4, each point on a face 5, and the point inside the cube 6. In general, I need to find the nearest neighbor on each side (left, right, top, bottom, front and back), if one exists.

I can achieve that by dividing the space around each point into buckets and searching for the nearest point in each bucket. But the code below is far too slow for larger data (it would take ~250 minutes for ~70k points).

Example data:

points = [
    (0, 0, 0),
    (0, 0, 1),
    (0, 0, 2),
    (0, 1, 0),
    (0, 1, 1),
    (0, 1, 2),
    (0, 2, 0),
    (0, 2, 1),
    (0, 2, 2),
    (1, 0, 0),
    (1, 0, 1),
    (1, 0, 2),
    (1, 1, 0),
    (1, 1, 1),
    (1, 1, 2),
    (1, 2, 0),
    (1, 2, 1),
    (1, 2, 2),
    (2, 0, 0),
    (2, 0, 1),
    (2, 0, 2),
    (2, 1, 0),
    (2, 1, 1),
    (2, 1, 2),
    (2, 2, 0),
    (2, 2, 1),
    (2, 2, 2)
]

Grouping points:

for i, (x1, y1, z1) in enumerate(points):
    f_xy = lambda x, y: (x - x1) + (y - y1)
    g_xy = lambda x, y: (x - x1) - (y - y1)
    f_xz = lambda x, z: (x - x1) + (z - z1)
    g_xz = lambda x, z: (x - x1) - (z - z1)
    f_yz = lambda y, z: (y - y1) + (z - z1)
    g_yz = lambda y, z: (y - y1) - (z - z1)

    groups = {'right': [], 'left': [], 'front': [], 'back': [], 'top': [], 'bottom': []}

    for j, (x2, y2, z2) in enumerate(points):
        if i != j:
            if f_xy(x2, y2) >= 0 and g_xy(x2, y2) >= 0 and f_xz(x2, z2) >= 0 and g_xz(x2, z2) >= 0: groups['right'].append((x2, y2, z2))
            if f_xy(x2, y2) <= 0 and g_xy(x2, y2) <= 0 and f_xz(x2, z2) <= 0 and g_xz(x2, z2) <= 0: groups['left'].append((x2, y2, z2))


            if f_xy(x2, y2) > 0 and g_xy(x2, y2) < 0 and f_yz(y2, z2) >= 0 and g_yz(y2, z2) >= 0: groups['front'].append((x2, y2, z2))
            if f_xy(x2, y2) < 0 and g_xy(x2, y2) > 0 and f_yz(y2, z2) <= 0 and g_yz(y2, z2) <= 0: groups['back'].append((x2, y2, z2))


            if f_xz(x2, z2) < 0 and g_xz(x2, z2) > 0 and f_yz(y2, z2) < 0 and g_yz(y2, z2) > 0: groups['bottom'].append((x2, y2, z2))
            if f_xz(x2, z2) > 0 and g_xz(x2, z2) < 0 and f_yz(y2, z2) > 0 and g_yz(y2, z2) < 0: groups['top'].append((x2, y2, z2))


    print(f'Vertex: {(x1, y1, z1)}')
    print(len(groups['right']) + len(groups['left']) + len(groups['front']) + len(groups['back']) + len(groups['bottom']) + len(groups['top']))
    print('Left', groups['left'])
    print('Right', groups['right'])
    print('Front', groups['front'])
    print('Back', groups['back'])
    print('Top', groups['top'])
    print('Bottom', groups['bottom'])
    print('\n')

Output:

Vertex: (0, 0, 0)
26
Left []
Right [(1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1), (2, 0, 0), (2, 0, 1), (2, 0, 2), (2, 1, 0), (2, 1, 1), (2, 1, 2), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
Front [(0, 1, 0), (0, 1, 1), (0, 2, 0), (0, 2, 1), (0, 2, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]
Back []
Top [(0, 0, 1), (0, 0, 2), (0, 1, 2), (1, 0, 2), (1, 1, 2)]
Bottom []


Vertex: (0, 0, 1)
26
Left []
Right [(1, 0, 0), (1, 0, 1), (1, 0, 2), (1, 1, 0), (1, 1, 1), (1, 1, 2), (2, 0, 0), (2, 0, 1), (2, 0, 2), (2, 1, 0), (2, 1, 1), (2, 1, 2), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
Front [(0, 1, 0), (0, 1, 1), (0, 1, 2), (0, 2, 0), (0, 2, 1), (0, 2, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]
Back []
Top [(0, 0, 2)]
Bottom [(0, 0, 0)]


Vertex: (0, 0, 2)
26
Left []
Right [(1, 0, 1), (1, 0, 2), (1, 1, 1), (1, 1, 2), (2, 0, 0), (2, 0, 1), (2, 0, 2), (2, 1, 0), (2, 1, 1), (2, 1, 2), (2, 2, 0), (2, 2, 1), (2, 2, 2)]
Front [(0, 1, 1), (0, 1, 2), (0, 2, 0), (0, 2, 1), (0, 2, 2), (1, 2, 0), (1, 2, 1), (1, 2, 2)]
Back []
Top []
Bottom [(0, 0, 0), (0, 0, 1), (0, 1, 0), (1, 0, 0), (1, 1, 0)]

...

I tried to use KD-trees, but if the points are distributed unevenly, a query may return all 6 neighbors from the left side.

Do you have any ideas? Thanks!


Solution

  • One way to find nearest neighbors is to use numpy to compute the distances between points, sort them, and take the indices of the smallest ones. For this you have to convert your list of points to a numpy array with points=np.array(points). For the sake of this example I will create my own points.

    
    #CREATE POINTS
    import numpy as np
    points = np.random.randn(100,3)
    


    #Find distances and nearests:
    #select target point to find nearest
    target_point = points[10]
    n_nearest = 6
    
    #calc distance
    distances = ((points-target_point)**2).sum(1)**0.5
    #find nearest
    nearest_p = points[np.argsort(distances)[1:1+n_nearest]]
    
    


    Finally, wrap this in a function that finds the nearest neighbors of one point, and measure the time it takes over all points:

    
    def find_n_nearest(points,target_point,n_nearest):
        distances = ((points-target_point)**2).sum(1)**0.5
        nearest_p = points[np.argsort(distances)[1:1+n_nearest]]
        return nearest_p
    
    n = int(2e4)
    p = np.random.randn(n,3)
    n_nearest= 6
    
    
    #measure time
    import time
    
    t0 = time.time()
    for i,target_point in enumerate(p):
        if not i%1000:
            print(i)
        find_n_nearest(p,target_point,n_nearest)
    t1 = time.time()
    print(t1-t0)
    
    

    which prints:

            82.79649305343628 seconds for 30k points (3e4)
    
            35.90334725379944 seconds for 20k
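    The loop above calls np.argsort on all n distances only to keep 6 of them, which is likely a large part of the run time. A possible variation (not part of the measurements above) is np.argpartition, which selects the smallest entries without a full sort. The function name find_n_nearest_partition is made up for this sketch, and it assumes that target_point is itself a row of points, so the closest hit (distance 0) is dropped:

```python
import numpy as np

def find_n_nearest_partition(points, target_point, n_nearest):
    # Hypothetical variant of find_n_nearest; requires n_nearest < len(points).
    # Squared distances are enough for ranking, so the square root is skipped.
    sq_dist = ((points - target_point) ** 2).sum(axis=1)
    # Indices of the n_nearest + 1 smallest distances, in no particular order.
    candidates = np.argpartition(sq_dist, n_nearest)[:n_nearest + 1]
    # Sort only those few candidates, then drop the first one (the target itself).
    ordered = candidates[np.argsort(sq_dist[candidates])]
    return points[ordered[1:]]

rng = np.random.default_rng(0)
pts = rng.standard_normal((500, 3))
nearest = find_n_nearest_partition(pts, pts[10], 6)
print(nearest.shape)
```

    This returns the same points as the argsort version whenever the distances are distinct. It still scans every point for every target, so the overall cost remains quadratic in the number of points.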
    

    As I said in the comments, there is no magic way: it is a trade-off between time and memory. With 70k points it is hard to cache the distance values without quickly running out of memory (~2.4e9 floats if you store only half of the distance matrix).
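    To put a number on that memory estimate, here is a small back-of-the-envelope calculation. It assumes 8-byte float64 values (numpy's default float type) and, as an alternative, 4-byte float32; the variable names are only illustrative:

```python
n = 70_000                      # number of points
n_pairs = n * (n - 1) // 2      # entries in half of the distance matrix, diagonal excluded
bytes_float64 = n_pairs * 8     # 8 bytes per float64 value
bytes_float32 = n_pairs * 4     # 4 bytes per float32 value

print(f"{n_pairs:.2e} distances")
print(f"{bytes_float64 / 1e9:.1f} GB as float64")
print(f"{bytes_float32 / 1e9:.1f} GB as float32")
```

    Under these assumptions, half of the matrix holds about 2.45e9 distances, which is roughly 19.6 GB as float64 or 9.8 GB as float32.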


    EDIT: for 70k points (7e4) it took 478.503 seconds.


    EDIT 2

    Added detection of the nearest point in every direction, then tested it with 70k points; it took 594.5 seconds.

    
    def find_nearest_dir(points,target_point):
        """Return the nearest point on each of the six sides of target_point."""
        # Euclidean distance from every point to the target
        distances = ((points-target_point)**2).sum(1)**0.5
        # one boolean mask per side; the strict inequalities exclude the target itself
        sides = [
                    points[:,0] > target_point[0],#right
                    points[:,0] < target_point[0],#left
                    points[:,1] > target_point[1],#front
                    points[:,1] < target_point[1],#back
                    points[:,2] > target_point[2],#top
                    points[:,2] < target_point[2],#bottom
                ]
        nearest = []
        for s in sides:
            dif_filt = distances[s]
            poi_filt = points[s]
            if len(dif_filt):
                nearest.append(poi_filt[np.argmin(dif_filt)])
            else:
                # no point lies on this side
                nearest.append([])
        return nearest
        
        
    n =int(7e4)
    p = np.random.randn(n,3)
    n_nearest= 6
    
    
    #measure time
    import time
    
    t0 = time.time()
    q=[]
    for i,target_point in enumerate(p):
        if not i%1000:
            print(i)
        q = find_nearest_dir(p,target_point)
    t1 = time.time()
    print(t1-t0)
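
    The question mentions that a plain KD-tree query can return all 6 neighbors from the same side. One possible workaround, sketched here under assumptions rather than measured, is to ask the tree for more candidates than needed and then apply the same side test as find_nearest_dir to those candidates only. The sketch assumes SciPy is installed; the function name nearest_per_direction and the candidate count k are made up for illustration. If k is too small for very uneven data, a side can be reported as empty (-1) even though a farther point exists there, and the sides are half-spaces as in find_nearest_dir, not the diagonal buckets from the question:

```python
import itertools
import numpy as np
from scipy.spatial import cKDTree

def nearest_per_direction(points, k=32):
    """Hypothetical helper: for each point, the index of the nearest candidate
    on each side (+x, -x, +y, -y, +z, -z), or -1 if none of the k nearest
    candidates lies on that side."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    k = min(k + 1, n)                     # +1: every point is its own nearest neighbor
    tree = cKDTree(points)
    _, idx = tree.query(points, k=k)      # each row is sorted by increasing distance
    idx = idx.reshape(n, -1)              # keep a 2D shape even when k == 1
    result = np.full((n, 6), -1, dtype=int)
    for i in range(n):
        cand = idx[i][idx[i] != i]        # drop the point itself
        delta = points[cand] - points[i]  # offsets of the candidates from point i
        for axis in range(3):
            masks = (delta[:, axis] > 0, delta[:, axis] < 0)
            for side, mask in enumerate(masks):
                hits = np.flatnonzero(mask)
                if hits.size:
                    # candidates are distance-sorted, so the first hit is the nearest
                    result[i, 2 * axis + side] = cand[hits[0]]
    return result

# The 3x3x3 grid from the question, in the same order.
grid = np.array(list(itertools.product(range(3), repeat=3)))
neighbors = nearest_per_direction(grid)
counts = (neighbors >= 0).sum(axis=1)
print(counts[0], counts[13])              # corner (0, 0, 0) and center (1, 1, 1)
```

    On the 3x3x3 grid from the question this reports 3 neighbors for each corner, 4 for each edge point, 5 for each face point and 6 for the center. How large k must be for real data depends on how unevenly the points are spread, so it would need to be checked against the brute-force result on a sample.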