I am trying to implement a KD-tree for use with DBSCAN. For every point, I need to find all neighbours that lie within a given distance. The problem is that the nearestNeighbours method in my implementation does not return the same output as a naive O(n^2) search (which produces the desired output).
My implementation is adapted from a Python implementation. Here's what I've got so far:
package dbscan_gui;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashSet;
public class Point {
final HashSet<Point> neighbours = new HashSet<Point>();
int[] points;
boolean visited = false;
public Point(int... is) {
this.points = is;
public String toString() {
return Arrays.toString(points);
public double squareDistance(Point p) {
double sum = 0;
for (int i = 0;i < points.length;i++) {
sum += Math.pow(points[i] - p.points[i],2);
return sum;
public double distance(Point p) {
return Math.sqrt(squareDistance(p));
public void addNeighbours(ArrayList<Point> ps) {
public void addNeighbour(Point p) {
if (p != this)
package dbscan_gui;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.TreeSet;
public class KDTree {
KDTreeNode root;
PointComparator[] comps;
public KDTree(ArrayList<Point> list) {
int axes = list.get(0).points.length;
comps = new PointComparator[axes];
for (int i = 0; i < axes; i++) {
comps[i] = new PointComparator(i);
root = new KDTreeNode(list,0);
private class PointComparator implements Comparator<Point> {
private int axis;
public PointComparator(int axis) {
this.axis = axis;
public int compare(Point p1, Point p2) {
// Integer.compare avoids the overflow that plain subtraction can cause
return Integer.compare(p1.points[axis], p2.points[axis]);
/**
 * Adapted from https://code.google.com/p/python-kdtree/
 * Stores points in a tree, sorted by axis
 */
public class KDTreeNode {
KDTreeNode leftChild = null;
KDTreeNode rightChild = null;
Point location;
public KDTreeNode(ArrayList<Point> list, int depth) {
final int axis = depth % (list.get(0).points.length);
Collections.sort(list, comps[axis] );
int median = list.size()/2;
location = list.get(median);
List<Point> leftPoints = list.subList(0, median);
List<Point> rightPoints = list.subList(median+1, list.size());
leftChild = new KDTreeNode(new ArrayList<Point>(leftPoints), depth+1);
rightChild = new KDTreeNode(new ArrayList<Point>(rightPoints),depth+1);
/**
 * @return true if this node has no children
 */
public boolean isLeaf() {
return leftChild == null && rightChild == null;
/**
 * Finds the nearest neighbours of a point that fall within a given distance
 * @param queryPoint the point to find the neighbours of
 * @param epsilon the distance threshold
 * @return the list of points
 */
public ArrayList<Point> nearestNeighbours(Point queryPoint, int epsilon) {
KDNeighbours neighbours = new KDNeighbours(queryPoint);
nearestNeighbours_(root, queryPoint, 0, neighbours);
return neighbours.getBest(epsilon);
/**
 * @param node
 * @param queryPoint
 * @param depth
 * @param bestNeighbours
 */
private void nearestNeighbours_(KDTreeNode node, Point queryPoint, int depth, KDNeighbours bestNeighbours) {
if(node == null)
if(node.isLeaf()) {
int axis = depth % (queryPoint.points.length);
KDTreeNode nearSubtree = node.rightChild;
KDTreeNode farSubtree = node.leftChild;
if(queryPoint.points[axis] < node.location.points[axis]) {
nearSubtree = node.leftChild;
farSubtree = node.rightChild;
nearestNeighbours_(nearSubtree, queryPoint, depth+1, bestNeighbours);
if(node.location != queryPoint)
if(Math.pow(node.location.points[axis] - queryPoint.points[axis],2) <= bestNeighbours.largestDistance)
nearestNeighbours_(farSubtree, queryPoint, depth+1,bestNeighbours);
/**
 * Private data structure for holding the neighbours of a point
 */
private class KDNeighbours {
Point queryPoint;
double largestDistance = 0;
TreeSet<Tuple> currentBest = new TreeSet<Tuple>(new Comparator<Tuple>() {
public int compare(Tuple o1, Tuple o2) {
// Double.compare keeps fractional differences that an int cast would truncate
return Double.compare(o1.y, o2.y);
KDNeighbours(Point queryPoint) {
this.queryPoint = queryPoint;
public ArrayList<Point> getBest(int epsilon) {
ArrayList<Point> best = new ArrayList<Point>();
Iterator<Tuple> it = currentBest.iterator();
while(it.hasNext()) {
Tuple t =it.next();
if(t.y > epsilon*epsilon)
else if(t.x != queryPoint)
return best;
public void add(Point p) {
currentBest.add(new Tuple(p, p.squareDistance(queryPoint)));
largestDistance = currentBest.last().y;
private class Tuple {
Point x;
double y;
Tuple(Point x, double y) {
this.x = x;
this.y = y;
public static void main(String[] args) {
int epsilon = 3;
System.out.println("Epsilon: "+epsilon);
ArrayList<Point> points = new ArrayList<Point>();
Random r = new Random();
for (int i = 0; i < 10; i++) {
points.add(new Point(r.nextInt(10), r.nextInt(10)));
System.out.println("Points "+points );
System.out.println("Neighbouring Kd");
KDTree tree = new KDTree(points);
for (Point p : points) {
ArrayList<Point> neighbours = tree.nearestNeighbours(p, epsilon);
for (Point q : neighbours) {
System.out.println("Neighbouring O(n^2)");
for (int i = 0; i < points.size(); i++) {
for (int j = i + 1; j < points.size(); j++) {
Point p = points.get(i), q = points.get(j);
if (p.distance(q) <= epsilon) {
for (Point point : points) {
When I run this I get the following output (the second part, from the naive search, is the expected output):
Epsilon: 3
Points [[9, 5], [4, 7], [3, 1], [0, 0], [5, 7], [0, 1], [5, 5], [1, 2], [9, 2], [9, 9]]
Neighbouring Kd
Neighbours of [0, 0] are: [[0, 1]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 7]]
Neighbours of [5, 7] are: [[4, 7]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
Neighbouring O(n^2)
Neighbours of [0, 0] are: [[0, 1], [1, 2]]
Neighbours of [0, 1] are: [[1, 2], [0, 0], [3, 1]]
Neighbours of [1, 2] are: [[0, 1], [0, 0], [3, 1]]
Neighbours of [3, 1] are: [[0, 1], [1, 2]]
Neighbours of [4, 7] are: [[5, 5], [5, 7]]
Neighbours of [5, 7] are: [[4, 7], [5, 5]]
Neighbours of [5, 5] are: [[4, 7], [5, 7]]
Neighbours of [9, 5] are: [[9, 2]]
Neighbours of [9, 2] are: [[9, 5]]
Neighbours of [9, 9] are: []
I can't figure out why the neighbours aren't the same. It seems that the KD-tree search can find that b is a neighbour of a, but not that a is also a neighbour of b.
You may want to use ELKI, which includes DBSCAN and index structures such as the R*-tree for nearest-neighbor search. When parameterized correctly, it's really fast. I saw in the trac that the next version will also have a KD-tree.
From a quick look at your code, I have to agree with @ThomasJungblut - you do not backtrack and then try the other branch as necessary, which is why you miss a lot of neighbors. You may need to look at both branches!
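To illustrate what "look at both branches" could mean for an epsilon query, here is a minimal sketch of a fixed-radius search. It is not your code: the names RangeSearchSketch, Node, build and rangeQuery are made up for the example, and points are plain int[] arrays rather than your Point class. The assumption is that pruning should compare the distance to the splitting plane against epsilon itself, rather than against the largest distance found so far, so a subtree is skipped only when it cannot contain any point within epsilon.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical sketch of an epsilon (fixed-radius) query on a KD-tree.
class RangeSearchSketch {

    /** One tree node: a splitting point plus optional children. */
    static final class Node {
        final int[] point;
        final Node left;
        final Node right;

        Node(int[] point, Node left, Node right) {
            this.point = point;
            this.left = left;
            this.right = right;
        }
    }

    /** Builds the tree by splitting on the median of the current axis. */
    static Node build(List<int[]> pts, int depth) {
        if (pts.isEmpty()) {
            return null;
        }
        final int axis = depth % pts.get(0).length;
        List<int[]> sorted = new ArrayList<>(pts); // copy, so the caller's list is untouched
        sorted.sort(Comparator.comparingInt((int[] p) -> p[axis]));
        int median = sorted.size() / 2;
        return new Node(sorted.get(median),
                build(sorted.subList(0, median), depth + 1),
                build(sorted.subList(median + 1, sorted.size()), depth + 1));
    }

    /** Adds to out every stored point within eps of query, skipping the query instance itself. */
    static void rangeQuery(Node node, int[] query, double eps, int depth, List<int[]> out) {
        if (node == null) {
            return;
        }
        if (node.point != query && squaredDistance(node.point, query) <= eps * eps) {
            out.add(node.point);
        }
        int axis = depth % query.length;
        double diff = (double) query[axis] - node.point[axis];
        // Left subtree holds values <= the split value, right subtree holds values >= it.
        // Visit each side whose half-space intersects the epsilon ball; this may be both.
        if (diff <= eps) {
            rangeQuery(node.left, query, eps, depth + 1, out);
        }
        if (diff >= -eps) {
            rangeQuery(node.right, query, eps, depth + 1, out);
        }
    }

    static double squaredDistance(int[] a, int[] b) {
        double sum = 0;
        for (int i = 0; i < a.length; i++) {
            double d = (double) a[i] - b[i];
            sum += d * d;
        }
        return sum;
    }

    public static void main(String[] args) {
        int[][] raw = {{9, 5}, {4, 7}, {3, 1}, {0, 0}, {5, 7}, {0, 1}, {5, 5}, {1, 2}, {9, 2}, {9, 9}};
        List<int[]> pts = new ArrayList<>(Arrays.asList(raw));
        Node root = build(pts, 0);
        List<int[]> found = new ArrayList<>();
        rangeQuery(root, raw[3], 3, 0, found); // neighbours of [0, 0] with epsilon 3
        for (int[] p : found) {
            System.out.println(Arrays.toString(p));
        }
    }
}
```

Because both conditions are checked independently, a query near a splitting plane descends into both children, which is the case your current search misses. Whether this matches the structure you want for DBSCAN is your call; it is only meant to show the pruning rule.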