Tags: java, performance, optimization, memory-management, union-find

How to optimize a Java Union-Find program to avoid OutOfMemoryError when processing large datasets


This is a follow-up to my earlier question.

I've managed to implement a working solution:

package com.test;

import java.io.*;
import java.util.*;

public class LineGroupProcessor {

    private LineGroupProcessor() {
    }

    public static void main(String[] args) {
        validateArgs(args);
        List<String[]> validRows = readValidRows(args[0]);
        UnionFind unionFind = new UnionFind(validRows.size());
        Map<String, Integer> columnValueMap = new HashMap<>();
        for (int i = 0; i < validRows.size(); i++) {
            processRow(validRows, columnValueMap, unionFind, i);
        }
        writeOutput(groupAndSortRows(validRows, unionFind));
    }

    private static void validateArgs(String[] args) {
        if (args.length == 0) {
            throw new IllegalArgumentException("No input file provided. Please specify a text or CSV file.");
        }

        String filePath = args[0];
        if (!filePath.endsWith(".txt") && !filePath.endsWith(".csv")) {
            throw new IllegalArgumentException("Invalid file type. Please provide a text or CSV file.");
        }

        File file = new File(filePath);
        if (!file.exists() || !file.isFile()) {
            throw new IllegalArgumentException("File does not exist or is not a valid file: " + filePath);
        }
    }

    private static List<String[]> readValidRows(String filePath) {
        List<String[]> rows = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] columns = line.split(";");
                if (isValidRow(columns)) {
                    rows.add(columns);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
        return rows;
    }

    private static boolean isValidRow(String[] columns) {
        for (String column : columns) {
            // Reject the row if any non-empty column doesn't match the expected quoted-digits format.
            if (!column.isEmpty() && !column.matches("^\"\\d{11}\"$")) {
                return false;
            }
        }
        return true;
    }

    private static void processRow(List<String[]> rows, Map<String, Integer> columnValueMap, UnionFind uf, int rowIndex) {
        String[] row = rows.get(rowIndex);
        for (int j = 0; j < row.length; j++) {
            String value = row[j].trim();
            if (!value.isEmpty() && !value.equals("\"\"")) {
                StringBuilder keyBuilder = new StringBuilder();
                keyBuilder.append(j).append(",").append(value);
                String key = keyBuilder.toString();
                if (columnValueMap.containsKey(key)) {
                    int prevRowIdx = columnValueMap.get(key);
                    uf.union(rowIndex, prevRowIdx);
                } else {
                    columnValueMap.put(key, rowIndex);
                }
            }
        }
    }

    private static List<Set<String>> groupAndSortRows(List<String[]> rows, UnionFind uf) {
        Map<Integer, Set<String>> groups = new HashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            int group = uf.find(i);
            groups.computeIfAbsent(group, k -> new HashSet<>()).add(Arrays.toString(rows.get(i)));
        }

        List<Set<String>> sortedGroups = new ArrayList<>(groups.values());
        sortedGroups.sort((g1, g2) -> Integer.compare(g2.size(), g1.size()));
        return sortedGroups;
    }

    private static void writeOutput(List<Set<String>> sortedGroups) {
        long groupsWithMoreThanOneRow = sortedGroups.stream().filter(group -> group.size() > 1).count();
        try (PrintWriter writer = new PrintWriter("output.txt")) {
            writer.println("Total number of groups with more than one element: " + groupsWithMoreThanOneRow);
            writer.println();
            int groupNumber = 1;
            for (Set<String> group : sortedGroups) {
                writer.println("Group " + groupNumber);
                for (String row : group) {
                    writer.println(row);
                }
                writer.println();
                groupNumber++;
            }
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}

package com.test;

public class UnionFind {

    private final int[] parent;
    private final int[] rank;

    public UnionFind(int size) {
        parent = new int[size];
        rank = new int[size];
        for (int i = 0; i < size; i++) {
            parent[i] = i;
            rank[i] = 0;
        }
    }

    public int find(int index) {
        if (parent[index] != index) {
            parent[index] = find(parent[index]);
        }
        return parent[index];
    }

    public void union(int index1, int index2) {
        int element1 = find(index1);
        int element2 = find(index2);
        if (element1 != element2) {
            if (rank[element1] > rank[element2]) {
                parent[element2] = element1;
            } else if (rank[element1] < rank[element2]) {
                parent[element1] = element2;
            } else {
                parent[element2] = element1;
                rank[element1]++;
            }
        }
    }
}

The program has specific requirements: it should complete within 30 seconds and use a maximum of 1GB of memory (-Xmx1G).

When running test datasets of 1 million and 10 million rows, I get the following errors:

> Task :com.test.LineGroupProcessor.main()
Exception in thread "main" java.lang.OutOfMemoryError: Java heap space
    at com.test.LineGroupProcessor.lambda$groupAndSortRows$0(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor$$Lambda/0x000002779d000400.apply(Unknown Source)
    at java.base/java.util.HashMap.computeIfAbsent(HashMap.java:1228)
    at com.test.LineGroupProcessor.groupAndSortRows(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor.main(LineGroupProcessor.java:19)
    
> Task :com.test.LineGroupProcessor.main()
Exception in thread "main" java.lang.OutOfMemoryError: Java heap space: failed reallocation of scalar replaced objects
    at java.base/java.util.HashMap.computeIfAbsent(HashMap.java:1222)
    at com.test.LineGroupProcessor.groupAndSortRows(LineGroupProcessor.java:85)
    at com.test.LineGroupProcessor.main(LineGroupProcessor.java:19) 

How can I optimize the code to stay within the 1GB memory limit?


Solution

  • Your current approach (roughly)

    1. read the entire file into memory as a List<String[]> – so if it's 1 million lines, you'll get a List of 1 million elements. Also, instead of storing each line of text as a single String like "200";"123";"100", you're creating three (3) separate Strings ("200", "123", and "100") plus an array to hold them.
    2. iterate through the List to construct an instance of UnionFind
    3. iterate once more through the List to build a HashMap named "groups", which stores yet another String copy of each input line (one of the 1 million lines from the file)
    4. some sorting logic, then a final pass through the sorted data to print the results

    A few observations

    • Step 1 does not appear to require that you read the entire file all at once. If the file is large enough (coupled with your runtime memory constraints), you may not be able to proceed beyond this step.
    • Step 1 also doesn't appear to require that you parse each input line (a single String like "200";"123";"100") and store the parsed result as a String[].
    • Step 2 seems ok, though you may have room to rework this part while chasing further optimizations. I didn't look too closely here.
    • Step 3 is potentially another opportunity to process the data line by line.

    Suggestions

    • Changes to step 1
      1. Do not read every line (1 million or 10 million) all at once
        • Instead, read one line of data, then process that line by itself. In your posted code, this looks doable: incrementally update your UnionFind object with one input line at a time (see the sketch below).
        • Why? It will avoid loading the entire file contents into memory (along with object overhead of Strings and arrays which you're using to represent the data)
      2. Update your code to remove quotes in the data
        • processRow() keeps the quoted values of each line, so "111" instead of simply 111. Perhaps you need to retain the quotes in the data, but if not, you could reduce the size of each String somewhat by removing the leading and trailing " characters.
        • Why? You'll use less memory. For a 10 million line file, that's 20-60 million fewer characters (assuming 1 to 3 occurrences of data like "123" per line).
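
      To make items 1 and 2 concrete, here's a minimal sketch (not a drop-in replacement) of what the streaming pass could look like. It assumes it lives in your existing class next to isValidRow (so java.io.* and java.util.* are already imported), that a hypothetical cheap first pass has already counted the total number of lines (lineCount), and that rows are keyed by their physical line number so the file can be re-read later:

      static UnionFind buildGroupsStreaming(String filePath, int lineCount) throws IOException {
          UnionFind uf = new UnionFind(lineCount);
          Map<String, Integer> columnValueMap = new HashMap<>();
          try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
              String line;
              int lineNumber = 0;
              while ((line = br.readLine()) != null) {
                  String[] columns = line.split(";");
                  if (isValidRow(columns)) {
                      for (int j = 0; j < columns.length; j++) {
                          // Drop the surrounding quotes so the map keys stay short.
                          String value = columns[j].replace("\"", "").trim();
                          if (!value.isEmpty()) {
                              String key = j + "," + value;
                              Integer previous = columnValueMap.putIfAbsent(key, lineNumber);
                              if (previous != null) {
                                  uf.union(lineNumber, previous);
                              }
                          }
                      }
                  }
                  // Count every physical line, valid or not, so line numbers match the file.
                  lineNumber++;
              }
          }
          return uf;
      }
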
    • Change step 3
      • Rework this part of the code so that it processes one line at a time.
      • I didn't read the code closely enough to follow the full intent of the grouping logic, but I would look for a way to express it in terms of the line number (or some other lightweight representation of the data). For example, the posted code does groups.computeIfAbsent(group, k -> new HashSet<>()).add(Arrays.toString(rows.get(i))); which takes the String[] for line number "i" and builds yet another String out of it (Arrays.toString()). The intent seems to be to record that line "i" belongs to the group, and for that it isn't necessary to store a String copy of the contents of line "i". If you capture a Set<Integer> of line numbers instead of a Set<String>, you'll use far less memory (see the sketch below).
      • Something along these lines ought to help: store only what's necessary to group the data.
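
      For example, something along these lines, reusing the UnionFind built by the streaming pass sketched above and keeping only line numbers per group:

      static Collection<List<Integer>> groupLineNumbers(int lineCount, UnionFind uf) {
          Map<Integer, List<Integer>> groups = new HashMap<>();
          for (int i = 0; i < lineCount; i++) {
              int root = uf.find(i);
              // Each group is a list of line numbers: a few bytes per row instead of
              // a full String copy of the row's contents.
              groups.computeIfAbsent(root, k -> new ArrayList<>()).add(i);
          }
          // Note: lines you skipped as invalid show up here as singleton groups;
          // filter them out if that matters for your counts.
          return groups.values();
      }
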
    • Change step 4
      • Assuming you implement changes to Step 3 – such that you're storing only the line number of each line in the "groups" data – you would need to rework the sorting logic.
      • As before, you could read the file one line at a time (don't read the whole thing), and use your "group line number" data to decide which lines to read and print. Depending on your output needs, perhaps this step can be simplified so that you don't exhaustively print every element. If you do need to print each raw line of input, your "group" data (which contains line numbers, not raw lines) tells you which line should be printed next, and you can go back to the file for its text (see the sketch below).
      • This might perform poorly (lots of disk I/O), but I would try it and make actual observations rather than speculate. If it does perform poorly, you could revisit step 1 and reintroduce storing the raw data, but store each line of input only as a raw String rather than a parsed String[]. So step 1 could become: build an array of raw lines, then parse a line into a String[] only while processing it. I would still try stripping out the " characters from the raw input, so that you're storing one raw line like 200;123;100 instead of "200";"123";"100".
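
      A rough sketch of that output pass, assuming only groups with more than one row need their text printed (adjust if singleton groups matter to you) and that one extra sequential read of the file is acceptable:

      static void writeOutputFromLineNumbers(String filePath, Collection<List<Integer>> groups) throws IOException {
          // Keep only the groups that actually get printed, largest first.
          List<List<Integer>> printable = new ArrayList<>();
          for (List<Integer> g : groups) {
              if (g.size() > 1) {
                  printable.add(g);
              }
          }
          printable.sort((g1, g2) -> Integer.compare(g2.size(), g1.size()));

          // One more sequential pass over the input: remember the text of just
          // the rows that will be printed, keyed by line number.
          Map<Integer, String> neededLines = new HashMap<>();
          for (List<Integer> g : printable) {
              for (Integer lineNumber : g) {
                  neededLines.put(lineNumber, null);
              }
          }
          try (BufferedReader br = new BufferedReader(new FileReader(filePath))) {
              String line;
              int lineNumber = 0;
              while ((line = br.readLine()) != null) {
                  if (neededLines.containsKey(lineNumber)) {
                      neededLines.put(lineNumber, line);
                  }
                  lineNumber++;
              }
          }

          try (PrintWriter writer = new PrintWriter("output.txt")) {
              writer.println("Total number of groups with more than one element: " + printable.size());
              writer.println();
              int groupNumber = 1;
              for (List<Integer> g : printable) {
                  writer.println("Group " + groupNumber++);
                  for (Integer lineNumber : g) {
                      writer.println(neededLines.get(lineNumber));
                  }
                  writer.println();
              }
          }
      }
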
    • Explore compressing the data
      • If the above suggestions don't work for whatever reason, you could explore reading the entire file into memory as compressed data. This may or may not be beneficial with your actual data (some data doesn't compress all that well).
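
      If you try that, java.util.zip.Deflater / Inflater can keep each raw line in memory as compressed bytes and decompress only on demand. Below is a rough, hypothetical sketch; measure it before committing, since per-line compression adds CPU cost and per-object overhead, and short lines often barely compress:

      import java.nio.charset.StandardCharsets;
      import java.util.Arrays;
      import java.util.zip.DataFormatException;
      import java.util.zip.Deflater;
      import java.util.zip.Inflater;

      final class CompressedLine {
          private final byte[] compressed;
          private final int originalLength;

          CompressedLine(String line) {
              byte[] input = line.getBytes(StandardCharsets.UTF_8);
              originalLength = input.length;
              Deflater deflater = new Deflater(Deflater.BEST_SPEED);
              deflater.setInput(input);
              deflater.finish();
              // Small margin in case a short line doesn't compress at all.
              byte[] buffer = new byte[input.length + 32];
              int length = deflater.deflate(buffer);
              deflater.end();
              compressed = Arrays.copyOf(buffer, length);
          }

          String decompress() throws DataFormatException {
              Inflater inflater = new Inflater();
              inflater.setInput(compressed);
              byte[] output = new byte[originalLength];
              inflater.inflate(output);
              inflater.end();
              return new String(output, StandardCharsets.UTF_8);
          }
      }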