java algorithm performance file-processing union-find

How to Correctly Group Rows by Column Values Using Union-Find in Java?


I have a task that I can't seem to solve. I am given a file with a million rows. Each row may contain an arbitrary number of elements in the following format:

"111";"123";"222"
"200";"123";"100"
"300";"";"100"

Invalid rows (to be ignored) are formatted as:

"8383"200000741652251"
"79855053897"83100000580443402";"200000133000191"
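A minimal sketch of parsing and validating one line, assuming a valid column is a double-quoted value with no quotes inside (such as "123" or the empty ""). The class `RowParser` and its `parse` method are hypothetical names. The sketch also strips the quotes so later stages compare plain values.

```java
import java.util.regex.Pattern;

class RowParser {
    // Assumption: a valid column is a double-quoted value with no quotes inside.
    private static final Pattern COLUMN = Pattern.compile("\"[^\"]*\"");

    // Returns the column values without quotes, or null if the line is invalid.
    static String[] parse(String line) {
        String[] columns = line.split(";");
        String[] values = new String[columns.length];
        for (int i = 0; i < columns.length; i++) {
            if (!COLUMN.matcher(columns[i]).matches()) {
                return null; // e.g. "8383"200000741652251" has a stray quote
            }
            // Matched columns are at least 2 chars long, so this is safe; "" becomes empty.
            values[i] = columns[i].substring(1, columns[i].length() - 1);
        }
        return values;
    }

    public static void main(String[] args) {
        String[] lines = {
            "\"300\";\"\";\"100\"",
            "\"8383\"200000741652251\"",
            "\"79855053897\"83100000580443402\";\"200000133000191\""
        };
        for (String line : lines) {
            String[] values = parse(line);
            // prints 300||100, then invalid, then invalid
            System.out.println(values == null ? "invalid" : String.join("|", values));
        }
    }
}
```

Rejecting the whole line as soon as one column fails keeps invalid rows out of the grouping step entirely.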

The goal is to group rows by the following rule: if two rows have the same non-empty value in the same column (in at least one column), they should belong to the same group. For example:

"111";"123";"222"
"200";"123";"100"
"300";"";"100"

These all belong to the same group because the first two rows have the same value 123 in the second column, and the last two rows have the same value 100 in the third column.

However, rows like:

"100";"200";"300"
"200";"300";"100"

should not be in the same group: although they share the values 100, 200, and 300, none of those values appears in the same column in both rows.
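The rule can be written as a direct check between two already-parsed rows (quotes stripped, "" for an empty column). `ColumnMatch.matches` is a hypothetical helper for illustration only: it tests a direct match, whereas groups are the transitive closure of it (the first and third rows end up together through the second). Checking every pair of a million rows this way would take about 5 × 10^11 comparisons, which is why a union-find over shared keys is used instead.

```java
class ColumnMatch {
    // Two rows match when some column index holds the same non-empty value
    // in both rows. Equal values in different columns do not count.
    static boolean matches(String[] a, String[] b) {
        int n = Math.min(a.length, b.length);
        for (int j = 0; j < n; j++) {
            if (!a[j].isEmpty() && a[j].equals(b[j])) {
                return true;
            }
        }
        return false;
    }

    public static void main(String[] args) {
        String[] r1 = {"111", "123", "222"};
        String[] r2 = {"200", "123", "100"};
        String[] r3 = {"300", "", "100"};
        String[] r4 = {"100", "200", "300"};
        String[] r5 = {"200", "300", "100"};
        System.out.println(matches(r1, r2)); // true: "123" in the second column
        System.out.println(matches(r2, r3)); // true: "100" in the third column
        System.out.println(matches(r1, r3)); // false, yet same group via r2
        System.out.println(matches(r4, r5)); // false: shared values are in different columns
    }
}
```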

The program should complete in 30 seconds and must work within 1GB of memory (-Xmx1G).

An AI assistant suggested using the union-find (disjoint-set) algorithm and provided some initial code. After some modifications, here is my current implementation:

package com.test;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.*;

public class UniqueLineGrouper {
    static class UnionFind {
        private int[] parent;
        private int[] rank;

        public UnionFind(int size) {
            parent = new int[size];
            rank = new int[size];
            for (int i = 0; i < size; i++) {
                parent[i] = i;
                rank[i] = 0;
            }
        }

        public int find(int x) {
            if (parent[x] != x) {
                parent[x] = find(parent[x]);
            }
            return parent[x];
        }

        public void union(int x, int y) {
            int rootX = find(x);
            int rootY = find(y);
            if (rootX != rootY) {
                if (rank[rootX] > rank[rootY]) {
                    parent[rootY] = rootX;
                } else if (rank[rootX] < rank[rootY]) {
                    parent[rootX] = rootY;
                } else {
                    parent[rootY] = rootX;
                    rank[rootX]++;
                }
            }
        }
    }

    public static void main(String[] args) {
        List<String[]> rows = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(new FileReader(args[0]))) {
            String line;
            while ((line = br.readLine()) != null) {
                String[] columns = line.split(";");
                boolean isValid = true;
                for (String column : columns) {
                    // A valid column is a quoted value without inner quotes, e.g. "123" or ""
                    if (!column.matches("\"[^\"]*\"")) {
                        isValid = false;
                        break;
                    }
                }
                if (isValid) {
                    rows.add(columns);
                }
            }
        } catch (IOException e) {
            e.printStackTrace();
        }

//This part has a bug
        UnionFind uf = new UnionFind(rows.size());
        Map<String, Integer> columnValueMap = new HashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            String[] row = rows.get(i);
            for (int j = 0; j < row.length; j++) {
                String value = row[j].trim();
                if (!value.isEmpty() && !value.equals("\"\"")) {
                    if (columnValueMap.containsKey(value)) {
                        int prevRowIdx = columnValueMap.get(value);
                        uf.union(i, prevRowIdx);
                    } else {
                        columnValueMap.put(value, i);
                    }
                }
            }
        }

        Map<Integer, List<Integer>> groups = new HashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            int group = uf.find(i);
            groups.computeIfAbsent(group, k -> new ArrayList<>()).add(i);
        }

        for (List<Integer> group : groups.values()) {
            System.out.println("Group:");
            for (int idx : group) {
                System.out.println(Arrays.toString(rows.get(idx)));
            }
            System.out.println();
        }
    }
}

The grouping runs, but there's a bug: rows are grouped even when the matching values are in different columns. For example, all five of the following rows end up in one group. The fifth row does belong with the first three (it has "200" in the first column, like the second row), but the fourth row shares no value in the same column with any other row and should form a group of its own:

"111";"123";"222"
"200";"123";"100"
"300";"";"100"
"100";"200";"300"
"200";"300";"100"
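To see the bug concretely, the following sketch (hypothetical class `KeyDemo`, with a simplified union-find that uses path halving and no ranks) runs the same row-merging loop twice on the five rows above, already parsed without quotes: once keyed by the value alone, as `columnValueMap` does, and once keyed by column index plus value.

```java
import java.util.HashMap;
import java.util.Map;

class KeyDemo {
    // Simplified union-find over row indices (path halving, no ranks).
    static int find(int[] parent, int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]];
            x = parent[x];
        }
        return x;
    }

    // Counts groups. With columnAware == false the key is the value alone,
    // mirroring the question's columnValueMap.
    static int countGroups(String[][] rows, boolean columnAware) {
        int[] parent = new int[rows.length];
        for (int i = 0; i < parent.length; i++) parent[i] = i;
        Map<String, Integer> firstRow = new HashMap<>();
        for (int i = 0; i < rows.length; i++) {
            for (int j = 0; j < rows[i].length; j++) {
                String value = rows[i][j];
                if (value.isEmpty()) continue;
                String key = columnAware ? j + ":" + value : value;
                Integer prev = firstRow.putIfAbsent(key, i); // null if key is new
                if (prev != null) {
                    parent[find(parent, i)] = find(parent, prev);
                }
            }
        }
        int groups = 0;
        for (int i = 0; i < parent.length; i++) if (find(parent, i) == i) groups++;
        return groups;
    }

    public static void main(String[] args) {
        String[][] rows = {
            {"111", "123", "222"}, {"200", "123", "100"}, {"300", "", "100"},
            {"100", "200", "300"}, {"200", "300", "100"}
        };
        System.out.println(countGroups(rows, false)); // 1
        System.out.println(countGroups(rows, true));  // 2
    }
}
```

Prefixing the key with the column index is the smallest change to the question's loop; the answer below reaches the same grouping by building sets over (column, value) pairs.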

I was also thinking of using Collections.disjoint() to skip rows that don't share any elements, but I'm not sure whether this would improve performance.
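`Collections.disjoint(c1, c2)` returns true when the two collections have no elements in common. A sketch of why it does not fit here, assuming rows are already parsed into plain values (`DisjointDemo` and `columnTokens` are hypothetical names): on raw values it reports an overlap for the fourth and fifth rows, which is exactly the cross-column match to avoid, and only after tagging each value with its column index does it agree with the rule. It is also a pairwise test, so it would still have to be applied to every pair of rows.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

class DisjointDemo {
    // Tags each non-empty value with its column index, so equal values
    // in different columns become different tokens.
    static List<String> columnTokens(String[] row) {
        List<String> tokens = new ArrayList<>();
        for (int j = 0; j < row.length; j++) {
            if (!row[j].isEmpty()) {
                tokens.add(j + ":" + row[j]);
            }
        }
        return tokens;
    }

    public static void main(String[] args) {
        String[] r4 = {"100", "200", "300"};
        String[] r5 = {"200", "300", "100"};
        // Raw values: both rows contain "100", "200" and "300".
        System.out.println(Collections.disjoint(Arrays.asList(r4), Arrays.asList(r5))); // false
        // Column-tagged tokens: no (column, value) pair is shared.
        System.out.println(Collections.disjoint(columnTokens(r4), columnTokens(r5)));   // true
    }
}
```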

Can anyone help me resolve this issue?


Solution

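  • Union-find is indeed the best way to solve this problem, but not the way you're using it: your columnValueMap is keyed by the value alone, so a value in one column is matched against the same value in any other column.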

    You should initially make a set for each unique (column, value) pair that occurs in the problem. For example, if you have a row like "300";"";"100", you should make sets for (0, 300) and (2, 100), using 0-based column indices and skipping the empty column.

    Then, for each row, merge the sets for all the pairs that occur in the row.

    All of the (column, value) pairs of each row will then be in the same set, which identifies the row's group, and any two rows that share a (column, value) pair will end up in the same group.
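A sketch of this approach, assuming rows are already parsed into unquoted values ("" for an empty column); `PairGrouper`, `idOf` and `group` are hypothetical names. Each distinct (column, value) pair gets an integer id when first seen, the ids within a row are unioned, and a row's group is the root of its first pair. The parent array grows on demand and uses path halving instead of union by rank for brevity; it has not been measured against the 30-second or 1 GB limits.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

class PairGrouper {
    private final Map<String, Integer> pairIds = new HashMap<>(); // "column:value" -> set id
    private int[] parent = new int[16]; // union-find parent pointers over set ids
    private int size = 0;               // number of set ids handed out so far

    // Returns the set id for (column, value), creating a singleton set on first sight.
    private int idOf(int column, String value) {
        return pairIds.computeIfAbsent(column + ":" + value, k -> {
            if (size == parent.length) {
                parent = Arrays.copyOf(parent, size * 2);
            }
            parent[size] = size;
            return size++;
        });
    }

    private int find(int x) {
        while (parent[x] != x) {
            parent[x] = parent[parent[x]]; // path halving
            x = parent[x];
        }
        return x;
    }

    private void union(int a, int b) {
        int ra = find(a);
        int rb = find(b);
        if (ra != rb) {
            parent[ra] = rb;
        }
    }

    // Groups row indices; a row is an array of unquoted values, "" meaning empty.
    static List<List<Integer>> group(List<String[]> rows) {
        PairGrouper uf = new PairGrouper();
        int[] firstPair = new int[rows.size()];
        Arrays.fill(firstPair, -1);
        for (int i = 0; i < rows.size(); i++) {
            String[] row = rows.get(i);
            for (int j = 0; j < row.length; j++) {
                if (row[j].isEmpty()) {
                    continue; // empty columns never create or join a set
                }
                int id = uf.idOf(j, row[j]);
                if (firstPair[i] < 0) {
                    firstPair[i] = id;
                } else {
                    uf.union(firstPair[i], id); // merge all pairs of this row
                }
            }
        }
        List<List<Integer>> result = new ArrayList<>();
        Map<Integer, List<Integer>> byRoot = new LinkedHashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            if (firstPair[i] < 0) {
                result.add(new ArrayList<>(List.of(i))); // row with no values: own group
            } else {
                byRoot.computeIfAbsent(uf.find(firstPair[i]), k -> new ArrayList<>()).add(i);
            }
        }
        result.addAll(byRoot.values());
        return result;
    }

    public static void main(String[] args) {
        List<String[]> rows = List.of(
                new String[] {"111", "123", "222"},
                new String[] {"200", "123", "100"},
                new String[] {"300", "", "100"},
                new String[] {"100", "200", "300"},
                new String[] {"200", "300", "100"});
        System.out.println(group(rows)); // [[0, 1, 2, 4], [3]]
    }
}
```

Rows with no non-empty column share no pair with anything, so the sketch places each of them in a group of its own.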