
How to find the K most frequent elements using the Java Collections Framework?


I have an array of integers and need to find the K most frequent elements efficiently. Ideally, the solution should run in O(N log K) time using the Java Collections Framework. Sorting all elements takes O(N log N) time, which is inefficient when N is large, so I am looking for a more optimized approach.
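For reference, writing M for the number of distinct values (M ≤ N), the costs break down roughly as follows. The heap figure assumes the heap is trimmed back to K entries after every insertion, so each offer/poll costs O(log K):

```latex
T_{\text{sort}} = \underbrace{O(N)}_{\text{counting}} + \underbrace{O(M \log M)}_{\text{sorting entries}},
\qquad
T_{\text{heap}} = \underbrace{O(N)}_{\text{counting}} + \underbrace{O(M \log K)}_{\text{bounded heap}},
\qquad M \le N
```

With M = N and K ≥ 2, these reduce to the O(N log N) and O(N log K) bounds mentioned above.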

What I Tried and What I Expected:

I first used a HashMap<Integer, Integer> to count the frequency of each element. Then I sorted the entries by frequency (the code below uses List.sort(), which is what Collections.sort() delegates to), but this approach runs in O(N log N) time, which is not optimal. I explored using a PriorityQueue (min-heap) or a TreeMap, but I'm struggling to implement them correctly. Specifically, I am unsure how to efficiently keep only the K most frequent elements while iterating over the frequency map.

What I Tried (code):

import java.util.*;

public class KMostFrequent {
    public static List<Integer> topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.put(num, frequencyMap.getOrDefault(num, 0) + 1);
        }

        List<Map.Entry<Integer, Integer>> entryList = new ArrayList<>(frequencyMap.entrySet());
        entryList.sort((a, b) -> Integer.compare(b.getValue(), a.getValue()));  // Sort by frequency, descending

        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < Math.min(k, entryList.size()); i++) {  // guard against k > number of distinct values
            result.add(entryList.get(i).getKey());
        }
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 2, 2, 3};
        int k = 2;
        System.out.println(topKFrequent(nums, k));  // Expected Output: [1, 2]
    }
}

Issue with This Code:

The sorting step (entryList.sort()) costs O(M log M), where M is the number of distinct values; in the worst case M = N, so it is O(N log N), which is inefficient for large N. Instead of sorting, I believe a min-heap (PriorityQueue) or a TreeMap could improve performance.

Additionally, I have searched for existing solutions but found mostly Collections.sort()-based approaches, which do not meet the desired efficiency. Some posts suggest using a min-heap, but I am struggling to maintain the correct heap structure while iterating over the map.
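The TreeMap idea mentioned above can also work if the ordered structure is kept bounded, exactly like a heap. A minimal sketch, assuming Java 8+, using a TreeSet (the JDK's TreeSet is backed by a TreeMap); the class name KMostFrequentTreeSet is hypothetical. The one pitfall is that a TreeSet treats comparator ties as duplicates, so the comparator here breaks frequency ties by value (an arbitrary choice: on ties, larger values are kept).

```java
import java.util.*;

public class KMostFrequentTreeSet {
    // Hypothetical sketch: bounded ordered set of (value, frequency) entries, O(N + M log K)
    public static List<Integer> topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.merge(num, 1, Integer::sum); // count occurrences
        }

        // Order by frequency, then by value. Without the tie-breaker, two values with
        // the same frequency compare as 0 and the TreeSet would silently drop one of them.
        Comparator<Map.Entry<Integer, Integer>> byFreqThenKey =
                Comparator.comparingInt((Map.Entry<Integer, Integer> e) -> e.getValue())
                          .thenComparingInt(e -> e.getKey());

        TreeSet<Map.Entry<Integer, Integer>> topK = new TreeSet<>(byFreqThenKey);
        for (Map.Entry<Integer, Integer> entry : frequencyMap.entrySet()) {
            topK.add(entry);
            if (topK.size() > k) {
                topK.pollFirst(); // evict the smallest (least frequent) entry
            }
        }

        // descendingSet() iterates from the most frequent entry to the least frequent
        List<Integer> result = new ArrayList<>();
        for (Map.Entry<Integer, Integer> entry : topK.descendingSet()) {
            result.add(entry.getKey());
        }
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 2, 2, 3};
        System.out.println(topKFrequent(nums, 2)); // [1, 2]
    }
}
```

Compared with a PriorityQueue, the TreeSet gives sorted iteration for free, at the cost of needing a total order (hence the tie-breaker).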


Solution

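  • We can push the entries of your frequency map into a min-heap ordered by frequency and evict the least frequent entry whenever the heap grows beyond k elements, so it never holds more than k + 1 entries. Each offer/poll then costs O(log k), for O(N + M log k) overall, where M is the number of distinct values. The following method does that; details of how it works are in the comments.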

    import java.util.*;
    
    public class KMostFrequent {
        public static List<Integer> topKFrequent(int[] nums, int k) {
            // Step 1: Count the frequency of each element
            Map<Integer, Integer> frequencyMap = new HashMap<>();
            for (int num : nums) {
                frequencyMap.put(num, frequencyMap.getOrDefault(num, 0) + 1);
            }
    
            // Step 2: Use a MinHeap (PriorityQueue) to keep the top K elements
            PriorityQueue<Map.Entry<Integer, Integer>> minHeap = new PriorityQueue<>(
                Comparator.comparingInt(Map.Entry::getValue) // MinHeap based on frequency
            );
    
            // Step 3: Iterate over the frequency map and maintain only the top K elements
            for (Map.Entry<Integer, Integer> entry : frequencyMap.entrySet()) {
                minHeap.offer(entry);
                if (minHeap.size() > k) {
                    minHeap.poll(); // Remove the least frequent element
                }
            }
    
            // Step 4: Extract elements from the MinHeap
            List<Integer> result = new ArrayList<>();
            while (!minHeap.isEmpty()) {
                result.add(minHeap.poll().getKey());
            }
            Collections.reverse(result); // Optional: the heap yields the least frequent first, so reverse for descending frequency
    
            return result;
        }
    
        public static void main(String[] args) {
            int[] nums = {1, 1, 1, 2, 2, 3};
            int k = 2;
            System.out.println(topKFrequent(nums, k)); // Expected Output: [1, 2]
        }
    }
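If the log factor matters, a different technique, bucket sort by frequency, avoids the heap entirely: no value can occur more than nums.length times, so frequencies can index an array of buckets. A minimal sketch, assuming Java 8+; the class name KMostFrequentBuckets is hypothetical. It runs in O(N) time but uses O(N) extra space for the buckets, and the order among values with equal frequency follows HashMap iteration order.

```java
import java.util.*;

public class KMostFrequentBuckets {
    // Hypothetical sketch of bucket sort by frequency: O(N) time, O(N) extra space
    public static List<Integer> topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.merge(num, 1, Integer::sum); // count occurrences
        }

        // buckets.get(f) holds every value that occurs exactly f times (1 <= f <= nums.length)
        List<List<Integer>> buckets = new ArrayList<>(nums.length + 1);
        for (int f = 0; f <= nums.length; f++) {
            buckets.add(new ArrayList<>());
        }
        for (Map.Entry<Integer, Integer> entry : frequencyMap.entrySet()) {
            buckets.get(entry.getValue()).add(entry.getKey());
        }

        // Walk buckets from the highest frequency down until k values are collected
        List<Integer> result = new ArrayList<>();
        for (int f = nums.length; f > 0 && result.size() < k; f--) {
            for (int value : buckets.get(f)) {
                if (result.size() == k) {
                    break;
                }
                result.add(value);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 2, 2, 3};
        System.out.println(topKFrequent(nums, 2)); // [1, 2]
    }
}
```

The heap version above is usually the better fit when k is small and memory is tight; the bucket version trades that memory for a linear-time bound.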