
K nearest neighbours in a 2D plane


The problem statement is as follows:

Find the K closest points to the origin in a 2D plane, given an array containing N points. The output must be in non-decreasing order of distance.

Solution: I have solved this using a Comparator and a PriorityQueue. My code is below:

import java.util.Comparator;
import java.util.PriorityQueue;

class Point {
    double x;
    double y;

    public Point(double x, double y) {
        this.x = x;
        this.y = y;
    }
}

public class KClosestPoints {
    public Point[] getKNearestPoints(Point[] points, int k) {
        if (k == 0 || points.length == 0) {
            return new Point[0];
        }
        Point[] rValue = new Point[k];
        int index = k - 1;
        if (points.length < k) {
            index = points.length - 1;
        }
        final Point org = new Point(0, 0);
        // Max-heap on distance: the farthest point is at the head of the queue.
        PriorityQueue<Point> pq = new PriorityQueue<Point>(k,
                new Comparator<Point>() {
                    @Override
                    public int compare(Point o1, Point o2) {
                        Double d2 = getDistance(o2, org);
                        Double d1 = getDistance(o1, org);
                        if (d2 > d1) {
                            return 1;
                        } else if (d2 < d1) {
                            return -1;
                        } else {
                            return 0;
                        }
                    }
                });
        for (int i = 0; i < points.length; i++) {
            pq.offer(points[i]);
            if (pq.size() > k) {
                pq.poll(); // drop the farthest point
            }
        }
        // poll() returns the farthest point first, so fill the array from the back.
        while (!pq.isEmpty()) {
            rValue[index] = pq.poll();
            index--;
        }
        return rValue;
    }

    private static double getDistance(Point a, Point b) {
        return Math.sqrt(((a.x - b.x) * (a.x - b.x))
                + ((a.y - b.y) * (a.y - b.y)));
    }
}

My code works for all the test cases I have tried except this one:

Point[] test6 = new Point[3];
test6[0] = new Point(Double.MIN_VALUE, Double.MAX_VALUE);
test6[1] = new Point(Double.MIN_VALUE, Double.MIN_VALUE);
test6[2] = new Point(Double.MAX_VALUE, Double.MAX_VALUE);
new KClosestPoints().getKNearestPoints(test6, 2);

The answer should be test6[0] and test6[1], whereas this code returns test6[0] and test6[2]. Can you help me find the issue?

EDIT: I later noticed that it gives the wrong answer in all test cases with k = 2 when the points are positive and negative; one such test case is shown above.


Solution

  • The issue you are asking about comes from the distance calculation:

    Point # 0 = (0, 0)
    Point # 1 = (Double.MIN_VALUE, Double.MAX_VALUE) -> (4.9E-324, 1.7E308)
    Point # 2 = (Double.MIN_VALUE, Double.MIN_VALUE) -> (4.9E-324, 4.9E-324)
    Point # 3 = (Double.MAX_VALUE, Double.MAX_VALUE) -> (1.7E308, 1.7E308)
    

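    Note: Double.MIN_VALUE is a positive number (the smallest positive double, about 4.9E-324), not the most negative one; the most negative finite double is -Double.MAX_VALUE.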
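    These limits are easy to check directly. A small sketch (the class name DoubleLimitsDemo is hypothetical) that prints them, including the fact that squaring Double.MIN_VALUE underflows to 0.0:

```java
// Hypothetical demo class: prints a few facts about double's extreme values.
class DoubleLimitsDemo {
    public static void main(String[] args) {
        // Smallest positive (subnormal) double, not the most negative one.
        System.out.println(Double.MIN_VALUE);                    // 4.9E-324
        System.out.println(Double.MIN_VALUE > 0);                // true
        // The most negative finite double.
        System.out.println(-Double.MAX_VALUE);                   // -1.7976931348623157E308
        // Squaring MIN_VALUE underflows to zero.
        System.out.println(Double.MIN_VALUE * Double.MIN_VALUE); // 0.0
    }
}
```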

    The Euclidean distance d = Math.sqrt(((a.x - b.x) * (a.x - b.x)) + ((a.y - b.y) * (a.y - b.y))) returns Infinity both for Point # 0 and Point # 1 and for Point # 0 and Point # 3, because the squared terms overflow:

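    Point # 1

    (a.x - b.x) * (a.x - b.x) = (4.9E-324 - 0) * (4.9E-324 - 0) = 0.0 (underflow)
    (a.y - b.y) * (a.y - b.y) = (1.7E308 - 0) * (1.7E308 - 0) = 1.7E308 * 1.7E308 = Infinity
    Math.sqrt(0.0 + Infinity) = Infinity

    Point # 3

    (a.x - b.x) * (a.x - b.x) = 1.7E308 * 1.7E308 = Infinity
    (a.y - b.y) * (a.y - b.y) = 1.7E308 * 1.7E308 = Infinity
    Math.sqrt(Infinity + Infinity) = Infinity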
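    The arithmetic can be reproduced with a short sketch. The class OverflowDemo and its helper distanceFromOrigin are hypothetical; the helper uses the same formula as getDistance in the question, with the second point fixed at the origin:

```java
// Hypothetical demo: the question's distance formula, measured from (0, 0).
class OverflowDemo {
    static double distanceFromOrigin(double x, double y) {
        return Math.sqrt(x * x + y * y);
    }

    public static void main(String[] args) {
        double max = Double.MAX_VALUE;
        double min = Double.MIN_VALUE;
        System.out.println(max * max);                    // Infinity (overflow)
        System.out.println(min * min);                    // 0.0 (underflow)
        System.out.println(distanceFromOrigin(min, max)); // Infinity (Point # 1)
        System.out.println(distanceFromOrigin(min, min)); // 0.0 (Point # 2)
        System.out.println(distanceFromOrigin(max, max)); // Infinity (Point # 3)
    }
}
```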
    

    So the distance from Point # 0 to Point # 1 is Infinity, and the distance from Point # 0 to Point # 3 is also Infinity. When they are compared, Infinity == Infinity is true, so the Comparator returns 0 ("both points are equal"). The PriorityQueue therefore cannot tell which of the two is farther and does not order them as you intended.
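    The tie can be seen by running the question's comparison logic on two precomputed infinite distances. In this sketch, TieDemo and questionCompare are hypothetical names; questionCompare copies only the if/else chain from the question's compare method:

```java
// Hypothetical demo: the question's comparator logic applied to two distances.
class TieDemo {
    // Returns 1 if d2 > d1, -1 if d2 < d1, otherwise 0 (as in the question).
    static int questionCompare(double d1, double d2) {
        if (d2 > d1) {
            return 1;
        } else if (d2 < d1) {
            return -1;
        }
        return 0;
    }

    public static void main(String[] args) {
        double inf = Double.POSITIVE_INFINITY;
        System.out.println(inf == inf);                // true
        // 0: Point # 1 and Point # 3 look equally far away.
        System.out.println(questionCompare(inf, inf)); // 0
    }
}
```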

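    Squaring values as large as Double.MAX_VALUE overflows a double, so do the arithmetic in BigDecimal instead. The subtraction is done in BigDecimal as well, because a.x - b.x can itself overflow when the coordinates have opposite signs (for example, Double.MAX_VALUE - (-Double.MAX_VALUE)):

    import java.math.BigDecimal;

    // Returns the squared distance between a and b (no square root),
    // computed in BigDecimal so it cannot overflow to Infinity.
    private static BigDecimal getDistance(Point a, Point b) {
        BigDecimal dx = BigDecimal.valueOf(a.x).subtract(BigDecimal.valueOf(b.x));
        BigDecimal dy = BigDecimal.valueOf(a.y).subtract(BigDecimal.valueOf(b.y));
        return dx.pow(2).add(dy.pow(2));
    }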
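    To see that BigDecimal separates the two "infinitely far" points, here is a small sketch. The class BigDecimalDistanceDemo and the method squaredDistanceFromOrigin are hypothetical; like the answer, it builds the values with BigDecimal.valueOf:

```java
import java.math.BigDecimal;

// Hypothetical demo: squared distances from the origin computed in BigDecimal.
class BigDecimalDistanceDemo {
    static BigDecimal squaredDistanceFromOrigin(double x, double y) {
        BigDecimal bx = BigDecimal.valueOf(x);
        BigDecimal by = BigDecimal.valueOf(y);
        return bx.pow(2).add(by.pow(2));
    }

    public static void main(String[] args) {
        BigDecimal p1 = squaredDistanceFromOrigin(Double.MIN_VALUE, Double.MAX_VALUE);
        BigDecimal p3 = squaredDistanceFromOrigin(Double.MAX_VALUE, Double.MAX_VALUE);
        // No overflow: Point # 1 is now strictly closer than Point # 3.
        System.out.println(p1.compareTo(p3)); // -1
    }
}
```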
    

    Why not compute the actual distance with a square root? Because:

    • Before Java 9, BigDecimal had no square-root method (Java 9 added BigDecimal.sqrt(MathContext)).
    • It is not needed: for non-negative a and b, a > b implies sqrt(a) > sqrt(b), because the square root is strictly increasing. Comparing squared distances therefore gives the same order as comparing distances.
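    To illustrate the second point, a small sketch comparing two squared distances with and without the square root. SqrtOrderDemo, compareSquared and compareRoots are hypothetical names, and compareRoots needs Java 9+ for BigDecimal.sqrt:

```java
import java.math.BigDecimal;
import java.math.MathContext;

// Hypothetical demo: the order of squared distances matches the order of distances.
class SqrtOrderDemo {
    // Compare the squared values directly; no square root needed.
    static int compareSquared(BigDecimal a, BigDecimal b) {
        return a.compareTo(b);
    }

    // Compare the square roots, rounded to 16 significant digits (Java 9+).
    static int compareRoots(BigDecimal a, BigDecimal b) {
        MathContext mc = MathContext.DECIMAL64;
        return a.sqrt(mc).compareTo(b.sqrt(mc));
    }

    public static void main(String[] args) {
        BigDecimal a = new BigDecimal("2"); // squared distance of (1, 1)
        BigDecimal b = new BigDecimal("5"); // squared distance of (1, 2)
        System.out.println(compareSquared(a, b)); // -1
        System.out.println(compareRoots(a, b));   // -1
    }
}
```

    Skipping the square root also avoids rounding: a square root rounded to a fixed precision could, in principle, make two very close squared distances compare as equal.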

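    BigDecimal implements Comparable<BigDecimal>, so its compareTo method can be used inside the compare method of the Comparator. Swap the arguments (d2.compareTo(d1)) to keep the farthest point at the head of the queue, as the original comparator does; with d1.compareTo(d2), pq.poll() would discard the nearest point instead of the farthest:

    @Override
    public int compare(Point o1, Point o2) {
        BigDecimal d1 = getDistance(o1, org);
        BigDecimal d2 = getDistance(o2, org);
        // Reversed order: larger distance first, so poll() removes the farthest point.
        return d2.compareTo(d1);
    }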
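    Putting the pieces together, here is a minimal self-contained sketch of the whole method with the BigDecimal comparison. It is not the code from the question: getKNearestPoints is made static, the helper squaredDistance, Point.toString and main are hypothetical additions, the result array is sized min(k, points.length) so it holds no trailing nulls, and Comparator.comparing(...).reversed() replaces the anonymous class. Because the squares are computed in BigDecimal, negative coordinates work as well:

```java
import java.math.BigDecimal;
import java.util.Arrays;
import java.util.Comparator;
import java.util.PriorityQueue;

class Point {
    final double x;
    final double y;

    Point(double x, double y) {
        this.x = x;
        this.y = y;
    }

    @Override
    public String toString() {
        return "(" + x + ", " + y + ")";
    }
}

class KClosestPoints {
    // Hypothetical helper: squared distance from the origin in BigDecimal (no overflow).
    static BigDecimal squaredDistance(Point p) {
        BigDecimal x = BigDecimal.valueOf(p.x);
        BigDecimal y = BigDecimal.valueOf(p.y);
        return x.multiply(x).add(y.multiply(y));
    }

    static Point[] getKNearestPoints(Point[] points, int k) {
        int size = Math.min(k, points.length);
        if (size <= 0) {
            return new Point[0];
        }
        // Max-heap on distance: the farthest kept point sits at the head.
        Comparator<Point> byDistance = Comparator.comparing(KClosestPoints::squaredDistance);
        PriorityQueue<Point> pq = new PriorityQueue<>(size, byDistance.reversed());
        for (Point p : points) {
            pq.offer(p);
            if (pq.size() > size) {
                pq.poll(); // evict the farthest point
            }
        }
        // poll() yields the farthest point first, so fill the result from the back.
        Point[] result = new Point[size];
        for (int i = size - 1; i >= 0; i--) {
            result[i] = pq.poll();
        }
        return result;
    }

    public static void main(String[] args) {
        Point[] test6 = {
            new Point(Double.MIN_VALUE, Double.MAX_VALUE),
            new Point(Double.MIN_VALUE, Double.MIN_VALUE),
            new Point(Double.MAX_VALUE, Double.MAX_VALUE)
        };
        // Prints [(4.9E-324, 4.9E-324), (4.9E-324, 1.7976931348623157E308)]
        System.out.println(Arrays.toString(getKNearestPoints(test6, 2)));
    }
}
```

    Recomputing the BigDecimal distance on every comparison can be costly for values this large; caching each point's squared distance (for example in a field) would be a reasonable refinement.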