
Optimize search of closest four elements in two 3D arrays


I have two numpy arrays filled with 3D coordinates (x, y, z). For each point of the first array (the "target" array) I need to find the 4 closest points in the second array (the "source" array). I have no trouble getting the correct results with different methods, but I want to speed up the process as much as I can.

I need this because I am working on a Maya tool that transfers information stored in each vertex of a mesh to a second mesh, and the two meshes might have a different number of vertices.

At this point, though, it becomes more of a Python problem than a Maya one, since my main bottleneck is the time spent looking for the vertex matches.

The number of elements can vary from a few hundred to hundreds of thousands, and I want to make sure that I find the best way to speed up the search. I would like my tool to be as quick as possible, since it might be used very often, and waiting minutes every single time it runs would be quite annoying.

I found some useful answers that got me going in the right direction:

One of them introduced me to KDTrees and different search algorithms, and another had some useful considerations on multithreading.

Here is some code that simulates the kind of scenario I would be working with and a few of the solutions I tried.

import timeit
import numpy as np
from multiprocessing.pool import ThreadPool
from scipy import spatial

# Brute force: for each source point, sort the distances to every target
def bruteForce():
    results = []
    for point in sources:
        dists = ((targets - point) ** 2).sum(axis=1)  # squared distances to all targets
        ndx = dists.argsort()  # indirect sort
        results.append(list(zip(ndx[:4], dists[ndx[:4]])))
    return results

# Thread pool implementation
def threaded():
    def worker(point):
        dists = ((targets - point) ** 2).sum(axis=1)  # squared distances to all targets
        ndx = dists.argsort()  # indirect sort
        return list(zip(ndx[:4], dists[ndx[:4]]))

    pool = ThreadPool()
    try:
        return pool.map(worker, sources)
    finally:
        pool.close()

# KDTree implementation
def kdTree():
    tree = spatial.KDTree(targets, leafsize=50)
    return [tree.query(point, k=4) for point in sources]

# define the number of points for the two arrays
n_targets = 40000
n_sources = 40000

# pick some random points
targets = np.random.rand(n_targets, 3) * 100
sources = np.random.rand(n_sources, 3) * 100

print('KDTree:     %s' % timeit.Timer(lambda: kdTree()).repeat(1, 1)[0])
print('bruteforce: %s' % timeit.Timer(lambda: bruteForce()).repeat(1, 1)[0])
print('threaded:   %s' % timeit.Timer(lambda: threaded()).repeat(1, 1)[0])
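The brute-force loop above fully sorts all 40,000 distances for every point just to keep the first 4. A minimal sketch of a cheaper variant, swapping argsort for np.argpartition (linear-time selection) and computing distances for a whole chunk of points at once with scipy.spatial.distance.cdist. The function name brute_force_knn and the chunk size of 128 (about 41 MB of temporary memory for 40,000 targets) are my own assumptions. It is still O(n·m) overall, so for large meshes a KD-tree should remain much faster.

```python
import numpy as np
from scipy.spatial.distance import cdist


def brute_force_knn(targets, sources, k=4, chunk=128):
    """For each source point, return the indices and squared distances of its
    k nearest targets (same convention as bruteForce() above).
    `chunk` is a hypothetical tuning knob that bounds the temporary
    (chunk, len(targets)) distance matrix. Requires k <= len(targets)."""
    n = len(sources)
    idx = np.empty((n, k), dtype=np.intp)
    dist = np.empty((n, k))
    for start in range(0, n, chunk):
        block = sources[start:start + chunk]
        # Squared Euclidean distances, shape (len(block), len(targets)).
        d2 = cdist(block, targets, 'sqeuclidean')
        # Linear-time selection of the k smallest per row (in arbitrary order).
        part = np.argpartition(d2, k - 1, axis=1)[:, :k]
        part_d = np.take_along_axis(d2, part, axis=1)
        # Sort only those k candidates by distance.
        order = np.argsort(part_d, axis=1)
        idx[start:start + chunk] = np.take_along_axis(part, order, axis=1)
        dist[start:start + chunk] = np.take_along_axis(part_d, order, axis=1)
    return idx, dist
```

Like the original bruteForce(), this returns squared distances; take np.sqrt(dist) if you need the Euclidean distances that KDTree.query reports.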

My results are:

KDTree:       10.724864464  seconds
bruteforce:   211.427750433 seconds
threaded:     47.3280865123 seconds

The most promising method seems to be the KDTree. At first I thought that by using threads to split the work of the KDTree into separate tasks, I could speed up the process even more. However, after a quick test using a basic threading.Thread implementation, it seemed to perform even worse when the KDTree was computed in a thread. Reading this scipy example, I can see that KDTrees are not really suitable for use in parallel threads, but I did not really understand why.
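One likely explanation for the threading result: the pure-Python KDTree spends its query time running Python bytecode, which holds the GIL, so threads mostly take turns instead of running in parallel. cKDTree.query does its search in compiled code and, as far as I am aware, releases the GIL while doing so, so splitting the query points across threads can help there. A minimal sketch, assuming the tree is built once and shared read-only between threads; the helper name threaded_query and the default of 4 threads are hypothetical:

```python
import numpy as np
from concurrent.futures import ThreadPoolExecutor
from scipy.spatial import cKDTree


def threaded_query(tree, points, k=4, n_threads=4):
    """Find the k nearest neighbours of every row of `points` in `tree`,
    splitting the rows across a thread pool (n_threads is an assumed value)."""
    chunks = np.array_split(points, n_threads)
    with ThreadPoolExecutor(max_workers=n_threads) as ex:
        # Each call returns a (distances, indices) pair for its chunk.
        results = list(ex.map(lambda c: tree.query(c, k=k), chunks))
    # Stitch the per-chunk results back together in the original order.
    dists = np.concatenate([r[0] for r in results])
    idx = np.concatenate([r[1] for r in results])
    return dists, idx
```

Building the tree inside each thread, as in a naive threading.Thread split, would repeat the construction work per thread, which may also explain part of the slowdown.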

I was wondering, then, if there is any other way I could optimize this code to run faster, maybe by using multiprocessing or some other trick to process my arrays in parallel.
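For this particular workload, hand-rolled multiprocessing may not be needed at all: in SciPy 1.6 and later, cKDTree.query accepts a workers argument, and workers=-1 parallelizes the query over all available CPU cores (older releases called this parameter n_jobs). A minimal sketch, with random data standing in for the mesh vertices and a hypothetical helper name closest_four:

```python
import numpy as np
from scipy.spatial import cKDTree


def closest_four(targets, sources, workers=-1):
    """For every point in `sources`, find the 4 closest points in `targets`.
    Assumes SciPy >= 1.6 for the `workers` argument."""
    tree = cKDTree(targets)  # build the tree once over the array being searched
    # workers=-1 lets SciPy spread the query over all CPU cores.
    dists, idx = tree.query(sources, k=4, workers=workers)
    return dists, idx


rng = np.random.default_rng(0)
targets = rng.random((40000, 3)) * 100
sources = rng.random((40000, 3)) * 100
dists, idx = closest_four(targets, sources)
print(dists.shape, idx.shape)  # (40000, 4) (40000, 4)
```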

Thanks in advance for the help!


Solution

  • There is one very simple but extremely effective thing you can do: switch from KDTree to cKDTree. The latter is a Cython drop-in replacement for the former, which is implemented in pure Python. (Since SciPy 1.6, KDTree is itself built on cKDTree, so this gap applies to older SciPy releases.)

    Also note that .query is vectorized, so there is no need for a list comprehension.

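    import numpy as np
    import scipy.spatial as ss
    from timeit import timeit
    
    a = np.random.random((40000,3))
    b = np.random.random((40000,3))
    
    tree_py = ss.KDTree(a)
    tree_cy = ss.cKDTree(a)
    
    # cKDTree: average milliseconds per query call over 10 runs
    timeit(lambda: tree_cy.query(b, k=4), number=10)*100
    # 71.06744810007513
    # KDTree: milliseconds for a single query call
    timeit(lambda: tree_py.query(b, k=4), number=1)*1000
    # 13309.359921026044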
    

    So that is an almost 200x speedup for free.
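To make the vectorization point concrete, here is a small check (sizes chosen arbitrarily) that a single query call over all points gives the same result as querying point by point, with distances and indices each shaped (n_points, 4):

```python
import numpy as np
from scipy.spatial import cKDTree

rng = np.random.default_rng(42)
targets = rng.random((1000, 3))
sources = rng.random((200, 3))

tree = cKDTree(targets)

# Original style: one Python-level call per point.
looped = [tree.query(p, k=4) for p in sources]

# Vectorized: one call for all points; the loop runs in compiled code.
dists, idx = tree.query(sources, k=4)

print(dists.shape, idx.shape)  # (200, 4) (200, 4)
```

Besides avoiding Python overhead per point, the vectorized form returns two plain arrays instead of a list of tuples, which is easier to feed into the rest of a NumPy pipeline.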