Tags: java, heap, comparator, priority-queue

Implement a custom comparator for my PriorityQueue


I'm trying to solve the following LeetCode problem:

Given a sorted array, two integers k and x, find the k closest elements to x in the array. The result should also be sorted in ascending order. If there is a tie, the smaller elements are always preferred.

Example 1: Input: [1,2,3,4,5], k=4, x=3

Output: [1,2,3,4]

Example 2: Input: [1,2,3,4,5], k=4, x=-1

Output: [1,2,3,4]
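To make the tie-break rule concrete, here is a small sketch (not part of the original question) that ranks the elements of Example 1 by their distance to x and then by value, using Comparator.comparingInt and thenComparingInt from java.util.Comparator. The class name ClosestRanking and the helper rankByCloseness are hypothetical, chosen only for illustration.

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

class ClosestRanking {
    // Hypothetical helper: orders arr by |v - x|, preferring the smaller value on a tie.
    static List<Integer> rankByCloseness(int[] arr, int x) {
        List<Integer> ranked = new ArrayList<>();
        for (int v : arr) {
            ranked.add(v);
        }
        ranked.sort(Comparator.<Integer>comparingInt(v -> Math.abs(v - x))
                .thenComparingInt(v -> v));
        return ranked;
    }

    public static void main(String[] args) {
        int[] arr = {1, 2, 3, 4, 5};
        // Example 1 (x = 3): the distances are 2, 1, 0, 1, 2.
        List<Integer> ranked = rankByCloseness(arr, 3);
        // Keep the first k = 4 elements, then sort them in ascending order.
        List<Integer> answer = new ArrayList<>(ranked.subList(0, 4));
        answer.sort(Comparator.naturalOrder());
        System.out.println(ranked); // [3, 2, 4, 1, 5]
        System.out.println(answer); // [1, 2, 3, 4]
    }
}
```

For x = 3, the values 2 and 4 are both at distance 1, and the tie-break puts 2 first; likewise 1 comes before 5. That is why 5, not 1, is the element left out of the result.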

My current (incorrect) solution is the following:

import java.util.*;

class Solution {
    public List<Integer> findClosestElements(int[] arr, int k, int x) {
        PriorityQueue<Integer> pq = new PriorityQueue<>(arr.length, (a, b) -> a == b ? a - b : Math.abs(a - x) - Math.abs(b - x));

        for (int i = 0; i < arr.length; i++) {
            pq.add(arr[i]);
        }

        List<Integer> ints = new ArrayList<>();

        for (int i = 0; i < k; i++) {
            ints.add(pq.poll());
        }
        return ints;
    }
}

The problem is with the comparator I'm passing to the constructor. The idea is that the comparator should order the integers by their distance |i - x| from the input x, so that I can then poll the k closest elements from the queue. How can I implement a comparator that orders the elements that way?
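A quick way to see what goes wrong is to call the lambda from the question directly. In this sketch (the class and method names are hypothetical), original reproduces the question's lambda and tieBroken adds a value comparison for equal distances. For x = 3, the original lambda returns 0 for 2 and 4, so the queue may hand back either one first; nothing in it prefers the smaller value.

```java
import java.util.Comparator;

class ComparatorCheck {
    // The lambda from the question. Whichever branch runs, equal values give 0,
    // and different values at the same distance (e.g. 2 and 4 when x = 3) also give 0.
    // Note that on Integer objects, == compares references, not values.
    static Comparator<Integer> original(int x) {
        return (a, b) -> a == b ? a - b : Math.abs(a - x) - Math.abs(b - x);
    }

    // Distance first; when two distances tie, fall back to comparing the values.
    static Comparator<Integer> tieBroken(int x) {
        return (a, b) -> {
            int byDistance = Integer.compare(Math.abs(a - x), Math.abs(b - x));
            return byDistance != 0 ? byDistance : Integer.compare(a, b);
        };
    }

    public static void main(String[] args) {
        System.out.println(original(3).compare(2, 4));      // 0: 2 and 4 are treated as equal
        System.out.println(tieBroken(3).compare(2, 4) < 0); // true: 2 is preferred
    }
}
```

Separately, the elements polled from the queue come out in order of distance, not in ascending order, so they still need a final sort before being returned.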


Solution

  • I would take advantage of the built-in Integer.compare method. Basically, you want to first compare the absolute differences, and if they tie, fall back to a normal comparison of the values themselves.

    static int compare(int x, int a, int b) {
        int comp = Integer.compare(Math.abs(a - x), Math.abs(b - x));
        if (comp == 0) {
            return Integer.compare(a, b);
        }
        return comp;
    }
    

    This makes the actual priority queue implementation pretty clean to write:

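    static List<Integer> findClosestElements(int[] arr, int k, int x) {
        PriorityQueue<Integer> queue = new PriorityQueue<>(
            arr.length, (a, b) -> compare(x, a, b));
        Arrays.stream(arr).forEach(queue::add);
        // A PriorityQueue's iterator (and stream) is not in priority order,
        // so take the k closest elements with poll() instead.
        List<Integer> result = new ArrayList<>(k);
        for (int i = 0; i < k; i++) {
            result.add(queue.poll());
        }
        Collections.sort(result);
        return result;
    }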
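A variation that is not part of the answer above: instead of putting all n elements into the queue, keep at most k of them in a queue ordered by the reversed comparator, so the head is always the worst element kept and can be evicted with poll(). This sketch builds the same ordering as compare(x, a, b) with Comparator.comparingInt, thenComparingInt and reversed(); the class name BoundedHeap is hypothetical. It does O(n log k) heap work instead of O(n log n).

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.PriorityQueue;

class BoundedHeap {
    static List<Integer> findClosestElements(int[] arr, int k, int x) {
        // Same ordering as compare(x, a, b): distance to x, then the value itself.
        Comparator<Integer> closer = Comparator
                .<Integer>comparingInt(v -> Math.abs(v - x))
                .thenComparingInt(v -> v);
        // Reversed, so the head is the worst element currently kept.
        PriorityQueue<Integer> worstFirst = new PriorityQueue<>(k + 1, closer.reversed());
        for (int v : arr) {
            worstFirst.add(v);
            if (worstFirst.size() > k) {
                worstFirst.poll(); // evict the farthest element (the larger one on a tie)
            }
        }
        // The k survivors are in heap order, so sort them ascending.
        List<Integer> result = new ArrayList<>(worstFirst);
        Collections.sort(result);
        return result;
    }

    public static void main(String[] args) {
        System.out.println(findClosestElements(new int[] {1, 2, 3, 4, 5}, 4, 3)); // [1, 2, 3, 4]
    }
}
```

The design choice is which end of the ordering sits at the head: the answer polls the best elements from a queue ordered by compare, while this version evicts the worst from a queue ordered by its reverse.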