I have an array of integers and need to find the K most frequent elements efficiently. The solution should ideally run in O(N log K) time using Java Collections, where N is the length of the array. Sorting all elements takes O(N log N) time, which is inefficient when N is large. I am looking for a more optimized approach.
What I Tried and What I Expected:
I first used a HashMap<Integer, Integer> to count the frequency of each element. Then I sorted the map entries by frequency (entryList.sort(...) in the code below), but that step is O(N log N) in the worst case, which is not optimal. I explored using a PriorityQueue (min-heap) or a TreeMap, but I'm struggling to implement them correctly. Specifically, I am unsure how to efficiently maintain the K most frequent elements while iterating through the frequency map.
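For the counting step on its own, Map.merge is a compact alternative to the getOrDefault/put pair. This is a minimal sketch; the class name FrequencyCount is only for illustration. merge(key, 1, Integer::sum) stores 1 for a new key and otherwise adds 1 to the existing count.

```java
import java.util.HashMap;
import java.util.Map;

public class FrequencyCount {
    // Counts occurrences of each value in nums.
    // merge() stores 1 for an unseen key, or combines the old count with 1 via Integer::sum.
    public static Map<Integer, Integer> countFrequencies(int[] nums) {
        Map<Integer, Integer> freq = new HashMap<>();
        for (int num : nums) {
            freq.merge(num, 1, Integer::sum);
        }
        return freq;
    }

    public static void main(String[] args) {
        Map<Integer, Integer> freq = countFrequencies(new int[]{1, 1, 1, 2, 2, 3});
        System.out.println(freq); // {1=3, 2=2, 3=1}
    }
}
```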
What I Tried (code):
import java.util.*;

public class KMostFrequent {
    public static List<Integer> topKFrequent(int[] nums, int k) {
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.put(num, frequencyMap.getOrDefault(num, 0) + 1);
        }
        List<Map.Entry<Integer, Integer>> entryList = new ArrayList<>(frequencyMap.entrySet());
        entryList.sort((a, b) -> b.getValue() - a.getValue()); // Sorting in descending order
        List<Integer> result = new ArrayList<>();
        for (int i = 0; i < k; i++) {
            result.add(entryList.get(i).getKey());
        }
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 2, 2, 3};
        int k = 2;
        System.out.println(topKFrequent(nums, k)); // Expected Output: [1, 2]
    }
}
Issue with This Code:
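The sorting step (entryList.sort()) takes O(M log M) time for M distinct values, which is O(N log N) in the worst case when all N values are distinct, and that is inefficient for large N. Instead of sorting, I believe using a min-heap (PriorityQueue) or a TreeMap can improve performance.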
Additionally, I have searched for existing solutions but found mostly Collections.sort()-based approaches, which do not meet the desired efficiency. Some posts suggest using a min-heap, but I am struggling to maintain the correct heap structure while iterating over the map.
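The core invariant can be seen in isolation on plain integers. This minimal sketch (hypothetical class BoundedMinHeapDemo) keeps the k largest values seen so far in a PriorityQueue, which is a min-heap under natural ordering. The root is always the smallest of the kept values, so after each offer, a single poll when the size exceeds k evicts exactly the value that no longer belongs in the top k. The frequency case works the same way once the heap holds map entries ordered by their counts.

```java
import java.util.*;

public class BoundedMinHeapDemo {
    // Returns the k largest values, largest first.
    public static List<Integer> kLargest(int[] values, int k) {
        PriorityQueue<Integer> minHeap = new PriorityQueue<>(); // natural order: smallest at the root
        for (int v : values) {
            minHeap.offer(v);
            if (minHeap.size() > k) {
                minHeap.poll(); // evict the current minimum
            }
        }
        // The heap's iteration order is not sorted, so sort the k survivors explicitly.
        List<Integer> result = new ArrayList<>(minHeap);
        result.sort(Collections.reverseOrder());
        return result;
    }

    public static void main(String[] args) {
        System.out.println(kLargest(new int[]{5, 1, 9, 3, 7}, 3)); // [9, 7, 5]
    }
}
```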
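We can put the entries from your frequency map into a min-heap ordered by frequency and trim it back to k entries whenever it grows larger, so the heap never holds more than k + 1 entries. Counting takes O(N), and each of the M distinct values (M ≤ N) costs O(log k) of heap work, for a total of O(N + M log k), which meets the O(N log K) goal. The following method does that; details of how it works are in the comments.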
import java.util.*;

public class KMostFrequent {
    public static List<Integer> topKFrequent(int[] nums, int k) {
        // Step 1: Count the frequency of each element
        Map<Integer, Integer> frequencyMap = new HashMap<>();
        for (int num : nums) {
            frequencyMap.put(num, frequencyMap.getOrDefault(num, 0) + 1);
        }
        // Step 2: Use a MinHeap (PriorityQueue) to keep the top K elements
        PriorityQueue<Map.Entry<Integer, Integer>> minHeap = new PriorityQueue<>(
            Comparator.comparingInt(Map.Entry::getValue) // MinHeap based on frequency
        );
        // Step 3: Iterate over the frequency map and maintain only the top K elements
        for (Map.Entry<Integer, Integer> entry : frequencyMap.entrySet()) {
            minHeap.offer(entry);
            if (minHeap.size() > k) {
                minHeap.poll(); // Remove the least frequent element
            }
        }
        // Step 4: Extract elements from the MinHeap
        List<Integer> result = new ArrayList<>();
        while (!minHeap.isEmpty()) {
            result.add(minHeap.poll().getKey());
        }
        Collections.reverse(result); // Optional: to maintain descending order of frequency
        return result;
    }

    public static void main(String[] args) {
        int[] nums = {1, 1, 1, 2, 2, 3};
        int k = 2;
        System.out.println(topKFrequent(nums, k)); // Expected Output: [1, 2]
    }
}
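With a comparator on frequency alone, entries with equal counts compare as equal, so when several values tie for the k-th place the result can depend on HashMap iteration order. If a stable answer matters, one possible variant adds a secondary key. The class name KMostFrequentTieBreak and the rule "the smaller value wins a tie" are assumptions for illustration; the question does not specify either.

```java
import java.util.*;

public class KMostFrequentTieBreak {
    public static List<Integer> topKFrequent(int[] nums, int k) {
        // Count occurrences (merge stores 1 or adds 1 to the existing count).
        Map<Integer, Integer> freq = new HashMap<>();
        for (int num : nums) {
            freq.merge(num, 1, Integer::sum);
        }

        // Orders entries from "worst" to "best": lower frequency first and, among
        // equal frequencies, the larger value first. The heap root is the one to evict.
        Comparator<Map.Entry<Integer, Integer>> worstFirst = (a, b) -> {
            int byFreq = Integer.compare(a.getValue(), b.getValue());
            return byFreq != 0 ? byFreq : Integer.compare(b.getKey(), a.getKey());
        };

        PriorityQueue<Map.Entry<Integer, Integer>> minHeap = new PriorityQueue<>(worstFirst);
        for (Map.Entry<Integer, Integer> entry : freq.entrySet()) {
            minHeap.offer(entry);
            if (minHeap.size() > k) {
                minHeap.poll(); // evict the current worst candidate
            }
        }

        // Polling yields worst-to-best, so reverse for best-first output.
        List<Integer> result = new ArrayList<>();
        while (!minHeap.isEmpty()) {
            result.add(minHeap.poll().getKey());
        }
        Collections.reverse(result);
        return result;
    }

    public static void main(String[] args) {
        System.out.println(topKFrequent(new int[]{1, 1, 1, 2, 2, 3}, 2));    // [1, 2]
        System.out.println(topKFrequent(new int[]{4, 4, 1, 1, 2, 2, 3}, 2)); // [1, 2]
    }
}
```

Because eviction and output ordering use the same comparator, the set of kept values and their printed order are consistent with each other.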