Tags: java, algorithm, comparator, priority-queue, max-heap

java - Max heap property not obeyed by PriorityQueue


I was trying out the Top K Frequent Elements problem. I am not asking how it's implemented, but what is going wrong in my implementation.

My implementation idea:

  1. Declare a HashMap to count the frequency of each number.
  2. Declare a max heap with a PriorityQueue to keep track of the most frequent K elements.
  3. Add elements to the HashMap.
  4. Use a custom Comparator in the PriorityQueue so that it sorts its elements on the basis of the frequency from the HashMap.

My code is as follows:

import java.util.Arrays;
import java.util.HashMap;
import java.util.PriorityQueue;

public class TopKFrequentElements {
    private final HashMap<Integer, Integer> freqMap = new HashMap<>();
    private final PriorityQueue<Integer> priorityQueue = new PriorityQueue<>(
            (a, b) -> {

                System.out.printf("a = %d b = %d, Freq(%d) = %d Freq(%d) = %d, Freq(%d).compareTo(Freq(%d)) = %d\n", a, b,
                        a, freqMap.get(a), b, freqMap.get(b), b, a, freqMap.get(b).compareTo(freqMap.get(a)));

                return freqMap.get(b).compareTo(freqMap.get(a));

            }
    );

    private void add(int num) {
        freqMap.put(num, freqMap.getOrDefault(num, 0) + 1);
        System.out.println("num = " + num);
        priorityQueue.add(num);
        System.out.println("--------------");
    }

    private int[] getK(int k) {
        int[] res = new int[k];

        int ct = 0;

        while (ct < k) {
            res[ct] = priorityQueue.peek();
            freqMap.put(res[ct], freqMap.get(res[ct]) - 1);
            priorityQueue.poll();

            while (!priorityQueue.isEmpty() &&
                    priorityQueue.peek()==res[ct]) {
                freqMap.put(res[ct], freqMap.get(res[ct]) - 1);
                priorityQueue.poll();
            }

            ct++;
        }
        return res;
    }

    public int[] topKFrequent(int[] nums, int k) {

        for (int num : nums) {
            add(num);
        }
        return getK(k);
    }

    public static void main(String[] args) {
        TopKFrequentElements elements = new TopKFrequentElements();

        System.out.println(Arrays.toString(
                elements.topKFrequent(new int[]{3,0,1,0}, 1))
        );
    }
}

The code works as expected for all test cases except this really interesting one:

  • {3,0,1,0} k = 1

Here the output is [3] when it should be [0].

I don't understand why this is happening. When the last 0 is added, the Comparator of the PriorityQueue should preserve the invariant, but it does not.

Digging deeper, I printed out each call to the Comparator and saw that the new element is not compared with all the other elements. I don't understand why.

num = 3
--------------
num = 0
a = 0 b = 3, Freq(0) = 1 Freq(3) = 1, Freq(3).compareTo(Freq(0)) = 0
--------------
num = 1
a = 1 b = 3, Freq(1) = 1 Freq(3) = 1, Freq(3).compareTo(Freq(1)) = 0
--------------
num = 0
a = 0 b = 0, Freq(0) = 2 Freq(0) = 2, Freq(0).compareTo(Freq(0)) = 0

As you can see, the last addition of 0 calls the Comparator just once, when my intuition says it should be checked against the other values as well.

Where are the calls for 3 and 1?

What am I doing wrong? I also cannot replicate this problem for any other example.

What have I missed about the PriorityQueue implementation?


Solution

  • If the priority of an item is going to change on the fly, you need to remove the item from the queue and add it back, so that the PriorityQueue will reflect the change.

    private void add(int num) {
        freqMap.put(num, freqMap.getOrDefault(num, 0) + 1);
        priorityQueue.remove(num);
        priorityQueue.add(num);
    }
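
To see why the original version returns [3], here is a minimal, self-contained sketch. The class name StalePriorityDemo and the helper headAfterAdding are made up for illustration and are not part of the question's code. When an element is added, PriorityQueue compares it only against its ancestors in the heap and never re-examines elements that are already stored. In the question's trace, the second 0 lands under the first 0, which was placed below 3 back when both had frequency 1, so 3 stays at the head.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class StalePriorityDemo {
    // Builds a max-heap ordered by a frequency map that keeps changing.
    // When reinsert is true, a number is removed before it is added again,
    // so the queue holds each number once, placed by its current frequency.
    static int headAfterAdding(int[] nums, boolean reinsert) {
        Map<Integer, Integer> freq = new HashMap<>();
        PriorityQueue<Integer> pq = new PriorityQueue<>(
                (a, b) -> freq.get(b).compareTo(freq.get(a)));
        for (int num : nums) {
            freq.merge(num, 1, Integer::sum);
            if (reinsert) {
                pq.remove(num); // num is boxed, so this is remove(Object)
            }
            pq.add(num);
        }
        return pq.peek();
    }

    public static void main(String[] args) {
        int[] nums = {3, 0, 1, 0};
        System.out.println("stale:    " + headAfterAdding(nums, false));
        System.out.println("reinsert: " + headAfterAdding(nums, true));
    }
}
```

For {3, 0, 1, 0} the stale queue reports 3 as its head, while the remove-then-add variant reports 0.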
    

    A more efficient approach is to calculate all the statistics for the priorities in advance, and only then populate the priority queue. But you are also altering the priorities with each item you remove from the queue, so while this won't work in your case, it would work in a similar case where the priorities are static.

    // accumulate the statistics
    var freqMap = new HashMap<Integer, Integer>();
    for (int num : nums) freqMap.merge(num, 1, Math::addExact);
    
    // let's not refer to the mutable map on the fly
    record PriorityInteger(int value, int priority)
            implements Comparable<PriorityInteger> {
        @Override
        public int compareTo(PriorityInteger that) {
            if (this.priority != that.priority) {
                // sort highest priority values first
                return Integer.compare(that.priority, this.priority);
            }
            // for values of equal priority, sort in ascending order
            return Integer.compare(this.value, that.value);
        }
    }
    
    // create and populate the priority queue
    var pq = new PriorityQueue<PriorityInteger>();
    freqMap.entrySet().stream()
        .map(e -> new PriorityInteger(e.getKey(), e.getValue()))
        .forEach(pq::add);
    
    // pull out the values and display them
    while (!pq.isEmpty()) System.out.println(pq.poll().value());
    

    Example

    For a given set of data

    int[] nums = { 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 6, 6 };
    

    The four 2s have the highest priority. The three 3s come next. At two apiece, the 1s and 6s are level pegging. The 4 has the lowest priority, with only a single entry.

    Output

    2
    3
    1
    6
    4
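
If only the k most frequent values are needed, as in the original problem, the same count-first idea can be wrapped in a method. This is a minimal sketch, assuming Java 16 or later for records. The names TopKSketch, Counted and topKFrequent are illustrative. Ties are broken by ascending value, matching the comparator above, and k is assumed not to exceed the number of distinct values.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;

public class TopKSketch {
    // Immutable snapshot of a value and how often it occurred (hypothetical helper).
    record Counted(int value, int count) {}

    static int[] topKFrequent(int[] nums, int k) {
        // Count everything first so no priority changes while inside the queue.
        Map<Integer, Integer> freq = new HashMap<>();
        for (int num : nums) freq.merge(num, 1, Integer::sum);

        // Highest count first; equal counts are ordered by ascending value.
        PriorityQueue<Counted> pq = new PriorityQueue<>(
                (x, y) -> x.count() != y.count()
                        ? Integer.compare(y.count(), x.count())
                        : Integer.compare(x.value(), y.value()));
        freq.forEach((value, count) -> pq.add(new Counted(value, count)));

        // Assumption: k is at most the number of distinct values in nums.
        int[] res = new int[k];
        for (int i = 0; i < k; i++) res[i] = pq.poll().value();
        return res;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 6, 6};
        System.out.println(Arrays.toString(topKFrequent(nums, 2)));
    }
}
```

With the example data and k = 2 this yields [2, 3], and for the question's input {3, 0, 1, 0} with k = 1 it yields [0].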