
KD Trees (3-D): Nearest Neighbour Search


I am looking at the Nearest Neighbour Search section of the Wikipedia article on k-d trees.

The pseudocode given there works when the points are 2-D (x, y).

What changes should I make when the points are 3-D (x, y, z)?

I searched extensively, including similar questions on Stack Overflow, but did not find a 3-D implementation anywhere. All the previous questions take 2-D points as input, not the 3-D points I am looking for.

The Wikipedia pseudocode for building the k-d tree is:

function kdtree (list of points pointList, int depth)
{
    // Select axis based on depth so that axis cycles through all valid values
    var int axis := depth mod k;

    // Sort point list and choose median as pivot element
    select median by axis from pointList;

    // Create node and construct subtrees
    var tree_node node;
    node.location := median;
    node.leftChild := kdtree(points in pointList before median, depth+1);
    node.rightChild := kdtree(points in pointList after median, depth+1);
    return node;
}
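The pseudocode is already dimension-independent: `k` is the number of dimensions, so with k = 3 the expression `depth mod k` makes the split axis cycle x, y, z, x, ... and nothing else has to change. Below is a minimal Java sketch of the same build, assuming points are stored as `double[]` arrays of length 3 (axis 0 = x, 1 = y, 2 = z). The class name `KdBuild3D` and the constant `DIMS` are illustrative, not from Wikipedia, and the sketch adds the empty-list base case that the pseudocode leaves implicit.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class KdBuild3D {
    static final int DIMS = 3; // "k" in the pseudocode: x = 0, y = 1, z = 2

    static final class Node {
        final double[] location;
        final Node leftChild, rightChild;

        Node(double[] location, Node leftChild, Node rightChild) {
            this.location = location;
            this.leftChild = leftChild;
            this.rightChild = rightChild;
        }
    }

    static Node kdtree(List<double[]> pointList, int depth) {
        if (pointList.isEmpty()) return null;              // base case: empty subtree
        int axis = depth % DIMS;                            // cycles x, y, z, x, ...
        List<double[]> sorted = new ArrayList<>(pointList); // copy, then sort by the current axis
        sorted.sort(Comparator.comparingDouble((double[] p) -> p[axis]));
        int median = sorted.size() / 2;                     // median becomes the pivot
        return new Node(sorted.get(median),
                kdtree(sorted.subList(0, median), depth + 1),
                kdtree(sorted.subList(median + 1, sorted.size()), depth + 1));
    }

    public static void main(String[] args) {
        List<double[]> points = List.of(
                new double[]{2, 3, 3}, new double[]{5, 4, 2}, new double[]{9, 6, 7},
                new double[]{4, 7, 9}, new double[]{8, 1, 5}, new double[]{7, 2, 6});
        Node root = kdtree(points, 0);
        System.out.println(Arrays.toString(root.location)); // [7.0, 2.0, 6.0]
    }
}
```

Sorting a copy at every level costs O(n log² n) overall; a linear-time selection of the median would be faster, but sorting keeps the sketch close to the pseudocode.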

How do I find the nearest neighbour once the k-d tree is built?

Thanks!


Solution

  • I recently coded up a KD tree for nearest neighbour search in 3-D space and ran into the same problems understanding the NNS, particularly section 3.2 of the Wikipedia article. I ended up using this algorithm, which seems to work in all my tests:

    Here is the initial leaf search:

    public Collection<T> nearestNeighbourSearch(int K, T value) {
        if (value==null) return null;
    
        //Results sorted by distance to value (closest first, farthest last)
        TreeSet<KdNode> results = new TreeSet<KdNode>(new EuclideanComparator(value));
    
        //Find the closest leaf node
        KdNode prev = null;
        KdNode node = root;
        while (node!=null) {
            if (KdNode.compareTo(node.depth, node.k, node.id, value)<0) {
                //Greater
                prev = node;
                node = node.greater;
            } else {
                //Lesser
                prev = node;
                node = node.lesser;
            }
        }
        KdNode leaf = prev;
    
        if (leaf!=null) {
            //Used to not re-examine nodes
            Set<KdNode> examined = new HashSet<KdNode>();
    
            //Go up the tree, looking for better solutions
            node = leaf;
            while (node!=null) {
                //Search node
                searchNode(value,node,K,results,examined);
                node = node.parent;
            }
        }
    
        //Load up the collection of the results
        Collection<T> collection = new ArrayList<T>(K);
        for (KdNode kdNode : results) {
            collection.add((T)kdNode.id);
        }
        return collection;
    }
    

    Here is the recursive search which starts at the closest leaf node:

    private static final <T extends KdTree.XYZPoint> void searchNode(T value, KdNode node, int K, TreeSet<KdNode> results, Set<KdNode> examined) {
        examined.add(node);
    
        //Search node
        KdNode lastNode = null;
        Double lastDistance = Double.MAX_VALUE;
        if (results.size()>0) {
            lastNode = results.last();
            lastDistance = lastNode.id.euclideanDistance(value);
        }
        Double nodeDistance = node.id.euclideanDistance(value);
        if (nodeDistance.compareTo(lastDistance)<0) {
            if (results.size()==K && lastNode!=null) results.remove(lastNode);
            results.add(node);
        } else if (nodeDistance.equals(lastDistance)) {
            results.add(node);
        } else if (results.size()<K) {
            results.add(node);
        }
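        //Until K results have been collected the search radius is unbounded;
        //after that it is the distance to the farthest result kept so far
        lastNode = results.last();
        lastDistance = (results.size()<K) ? Double.POSITIVE_INFINITY : lastNode.id.euclideanDistance(value);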
    
        int axis = node.depth % node.k;
        KdNode lesser = node.lesser;
        KdNode greater = node.greater;
    
        //Search children branches, if axis aligned distance is less than current distance
        if (lesser!=null && !examined.contains(lesser)) {
            examined.add(lesser);
    
            double nodePoint = Double.MIN_VALUE;
            double valuePlusDistance = Double.MIN_VALUE;
            if (axis==X_AXIS) {
                nodePoint = node.id.x;
                valuePlusDistance = value.x-lastDistance;
            } else if (axis==Y_AXIS) {
                nodePoint = node.id.y;
                valuePlusDistance = value.y-lastDistance;
            } else {
                nodePoint = node.id.z;
                valuePlusDistance = value.z-lastDistance;
            }
            boolean lineIntersectsCube = (valuePlusDistance<=nodePoint);
    
            //Continue down lesser branch
            if (lineIntersectsCube) searchNode(value,lesser,K,results,examined);
        }
        if (greater!=null && !examined.contains(greater)) {
            examined.add(greater);
    
            double nodePoint = Double.MIN_VALUE;
            double valuePlusDistance = Double.MIN_VALUE;
            if (axis==X_AXIS) {
                nodePoint = node.id.x;
                valuePlusDistance = value.x+lastDistance;
            } else if (axis==Y_AXIS) {
                nodePoint = node.id.y;
                valuePlusDistance = value.y+lastDistance;
            } else {
                nodePoint = node.id.z;
                valuePlusDistance = value.z+lastDistance;
            }
            boolean lineIntersectsCube = (valuePlusDistance>=nodePoint);
    
            //Continue down greater branch
            if (lineIntersectsCube) searchNode(value,greater,K,results,examined);
        }
    }
    

    The full Java source can be found here.
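For comparison, the descend-then-unwind procedure from the Wikipedia section can also be written recursively, without parent pointers or an `examined` set. This is a sketch, not the answerer's code: it assumes `double[]` points of length 3, and the names `KdNearest3D`, `nearest` and `distSq` are hypothetical. Using an axis index (`depth % DIMS`) replaces the X/Y/Z branches, and the far side of a split is searched only when the squared distance from the query to the splitting plane is smaller than the best squared distance found so far.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

public class KdNearest3D {
    static final int DIMS = 3; // x = 0, y = 1, z = 2

    static final class Node {
        final double[] point;
        final Node left, right;

        Node(double[] point, Node left, Node right) {
            this.point = point;
            this.left = left;
            this.right = right;
        }
    }

    // Median split on axis depth % DIMS, as in the Wikipedia build pseudocode
    static Node build(List<double[]> pts, int depth) {
        if (pts.isEmpty()) return null;
        int axis = depth % DIMS;
        List<double[]> sorted = new ArrayList<>(pts);
        sorted.sort(Comparator.comparingDouble((double[] p) -> p[axis]));
        int m = sorted.size() / 2;
        return new Node(sorted.get(m),
                build(sorted.subList(0, m), depth + 1),
                build(sorted.subList(m + 1, sorted.size()), depth + 1));
    }

    static double distSq(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < DIMS; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    // Returns the closer of 'best' and the nearest point in this subtree
    static double[] nearest(Node node, double[] q, int depth, double[] best) {
        if (node == null) return best;
        if (best == null || distSq(node.point, q) < distSq(best, q)) best = node.point;
        int axis = depth % DIMS;
        double diff = q[axis] - node.point[axis]; // signed distance to the splitting plane
        Node near = diff < 0 ? node.left : node.right;
        Node far = diff < 0 ? node.right : node.left;
        best = nearest(near, q, depth + 1, best); // go down the query's side first
        // On the way back up: the far side can only hold a closer point if the
        // sphere of radius |q - best| around q crosses the splitting plane
        if (diff * diff < distSq(best, q)) best = nearest(far, q, depth + 1, best);
        return best;
    }

    static double[] nearest(Node root, double[] q) {
        return nearest(root, q, 0, null);
    }

    public static void main(String[] args) {
        List<double[]> points = List.of(
                new double[]{2, 3, 3}, new double[]{5, 4, 2}, new double[]{9, 6, 7},
                new double[]{4, 7, 9}, new double[]{8, 1, 5}, new double[]{7, 2, 6});
        Node root = build(points, 0);
        System.out.println(Arrays.toString(nearest(root, new double[]{8, 2, 5}))); // [8.0, 1.0, 5.0]
    }
}
```

Comparing squared distances avoids a square root per node and gives the same ordering.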
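The answer's `nearestNeighbourSearch(int K, T value)` returns the K closest points. A comparable self-contained sketch keeps the candidates in a bounded max-heap (`java.util.PriorityQueue` with a reversed distance comparator), so the farthest kept candidate is always at the head, and it never prunes until K candidates have been collected. This is an illustration under assumptions, not the answerer's implementation: points are `double[]` of length 3, and `KdKNearest3D`, `kNearest` and `count` (playing the role of the answer's `K`) are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

public class KdKNearest3D {
    static final int DIMS = 3; // x = 0, y = 1, z = 2

    static final class Node {
        final double[] point;
        final Node left, right;

        Node(double[] point, Node left, Node right) {
            this.point = point;
            this.left = left;
            this.right = right;
        }
    }

    static Node build(List<double[]> pts, int depth) {
        if (pts.isEmpty()) return null;
        int axis = depth % DIMS;
        List<double[]> sorted = new ArrayList<>(pts);
        sorted.sort(Comparator.comparingDouble((double[] p) -> p[axis]));
        int m = sorted.size() / 2;
        return new Node(sorted.get(m),
                build(sorted.subList(0, m), depth + 1),
                build(sorted.subList(m + 1, sorted.size()), depth + 1));
    }

    static double distSq(double[] a, double[] b) {
        double s = 0;
        for (int i = 0; i < DIMS; i++) {
            double d = a[i] - b[i];
            s += d * d;
        }
        return s;
    }

    // 'heap' is a max-heap on distance to q, so heap.peek() is the farthest kept candidate
    static void search(Node node, double[] q, int depth, int count, PriorityQueue<double[]> heap) {
        if (node == null) return;
        heap.offer(node.point);
        if (heap.size() > count) heap.poll(); // keep only the 'count' closest
        int axis = depth % DIMS;
        double diff = q[axis] - node.point[axis];
        Node near = diff < 0 ? node.left : node.right;
        Node far = diff < 0 ? node.right : node.left;
        search(near, q, depth + 1, count, heap);
        // Until 'count' candidates exist the search radius is unbounded
        if (heap.size() < count || diff * diff < distSq(heap.peek(), q)) {
            search(far, q, depth + 1, count, heap);
        }
    }

    // Returns up to 'count' points, closest first
    static List<double[]> kNearest(Node root, double[] q, int count) {
        List<double[]> result = new ArrayList<>();
        if (count <= 0) return result;
        Comparator<double[]> byDist = Comparator.comparingDouble((double[] p) -> distSq(p, q));
        PriorityQueue<double[]> heap = new PriorityQueue<>(byDist.reversed());
        search(root, q, 0, count, heap);
        result.addAll(heap);
        result.sort(byDist);
        return result;
    }

    public static void main(String[] args) {
        List<double[]> points = List.of(
                new double[]{2, 3, 3}, new double[]{5, 4, 2}, new double[]{9, 6, 7},
                new double[]{4, 7, 9}, new double[]{8, 1, 5}, new double[]{7, 2, 6});
        Node root = build(points, 0);
        for (double[] p : kNearest(root, new double[]{8, 2, 5}, 2)) {
            System.out.println(Arrays.toString(p)); // [8.0, 1.0, 5.0] then [7.0, 2.0, 6.0]
        }
    }
}
```

A heap bounded at `count` elements gives O(log count) insertion and constant-time access to the current search radius, which is the role `results.last()` plays in the answer's `TreeSet`.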