
Creating all multisets of subsets of NON DISTINCT values, in Java


Given an array of objects, I need to find, as efficiently as possible, all the different multisets of subsets of the given array that together use every element exactly once (i.e. all the ways to split the array into groups), where the values in the array may repeat.

For example, if the array is 1, 2, 1, 2, then I need to create the following multisets:

  1. {[1], [1], [2], [2]}
  2. {[1], [1], [2, 2]}
  3. {[1], [2], [1, 2]}
  4. {[1], [1, 2, 2]}
  5. {[1, 1], [2], [2]}
  6. {[1, 1], [2, 2]}
  7. {[1, 2], [1, 2]}
  8. {[1, 1, 2], [2]}
  9. {[1, 1, 2, 2]}

Please note that neither the order of the values inside a subset nor the order of the subsets inside a multiset matters. A multiset like {[1, 2, 2], [1]} is the same as #4, while {[2, 1], [2], [1]} is the same as #3.
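
To make this equivalence rule concrete, here is a minimal sketch (assuming `Integer` values; the class and method names `CanonicalForm`, `canonical` and `compareLists` are hypothetical) that reduces a multiset of subsets to a canonical form by sorting the values inside each subset and then sorting the subsets lexicographically. Two multisets are the same exactly when their canonical forms are equal.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class CanonicalForm {
    // Lexicographic comparison of two sorted lists; a proper prefix sorts first.
    public static int compareLists(List<Integer> x, List<Integer> y) {
        int n = Math.min(x.size(), y.size());
        for (int i = 0; i < n; i++) {
            int c = Integer.compare(x.get(i), y.get(i));
            if (c != 0) {
                return c;
            }
        }
        return Integer.compare(x.size(), y.size());
    }

    // Sort the values inside every subset, then sort the subsets themselves.
    public static List<List<Integer>> canonical(List<List<Integer>> groups) {
        List<List<Integer>> result = new ArrayList<>();
        for (List<Integer> group : groups) {
            List<Integer> sorted = new ArrayList<>(group);
            Collections.sort(sorted);
            result.add(sorted);
        }
        result.sort(CanonicalForm::compareLists);
        return result;
    }

    public static void main(String[] args) {
        // {[1, 2, 2], [1]} and #4 {[1], [1, 2, 2]} have the same canonical form.
        List<List<Integer>> a = List.of(List.of(1, 2, 2), List.of(1));
        List<List<Integer>> b = List.of(List.of(1), List.of(1, 2, 2));
        System.out.println(canonical(a).equals(canonical(b))); // prints true
    }
}
```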

The example here uses ints, but in practice I'll have to do it with objects.
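
One way to reuse an int-based algorithm for objects is to map each distinct object to a small integer index first and map the results back afterwards. This is a sketch that assumes the objects implement `equals` and `hashCode` consistently; `IndexMapping` and `toObjects` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class IndexMapping<T> {
    // Distinct objects in order of first appearance; position = assigned index.
    public final List<T> distinct = new ArrayList<>();
    // Index of each input element, usable as input to an int-based algorithm.
    public final int[] indices;

    public IndexMapping(List<? extends T> elements) {
        Map<T, Integer> indexOf = new HashMap<>();
        indices = new int[elements.size()];
        for (int i = 0; i < elements.size(); i++) {
            T e = elements.get(i);
            Integer idx = indexOf.get(e);
            if (idx == null) {
                // First time this (equals-)value is seen: give it the next index.
                idx = distinct.size();
                indexOf.put(e, idx);
                distinct.add(e);
            }
            indices[i] = idx;
        }
    }

    // Map one group of indices back to the original objects.
    public List<T> toObjects(List<Integer> group) {
        List<T> result = new ArrayList<>();
        for (int idx : group) {
            result.add(distinct.get(idx));
        }
        return result;
    }
}
```

Equal objects receive the same index, so an int-based partition routine can run on `indices`, and each resulting group can be translated back with `toObjects`.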

This should be as efficient as possible. Ideally, only the correct (non-repeating) multisets would be calculated, without checking whether one has already appeared, because the way they are generated would prevent duplicates from happening in the first place.

I know how to create all subsets using the binary representation. I combined that with recursion to calculate all multisets. That works perfectly when all values are distinct, but when values repeat it produces the same multiset more than once. Here is what I did so far:

(a is the list of given numbers, curr is the multiset currently being built, and b is the final collection of all multisets.)

public static void makeAll(ArrayList<Integer> a, 
                           ArrayList<ArrayList<Integer>> curr,
                           ArrayList<ArrayList<ArrayList<Integer>>> b) {

    ArrayList<ArrayList<Integer>> currCopy;
    ArrayList<Integer> thisGroup, restGroup;
    int currSize = 0, ii = 0;

    if (a.size() == 0)
        b.add(new ArrayList<ArrayList<Integer>>(curr));
    else {
        for (int i = 0; i < 1 << (a.size() - 1); i++) {
            thisGroup = new ArrayList<>();
            restGroup = new ArrayList<>();
            ii = (i << 1) + 1; // the first element is always included, which keeps uniqueness.

            for (int j = 0; j < a.size(); j++)
                if ((ii & 1 << j) > 0)
                    thisGroup.add(a.get(j));
                else
                    restGroup.add(a.get(j));

            currSize = curr.size();
            curr.add(new ArrayList<Integer>(thisGroup));

            makeAll(restGroup, curr, b);

            curr.subList(currSize, curr.size()).clear();
        }
    }
}
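
To get only the distinct multisets without any "already seen?" check, one option (a different technique from the bitmask recursion above) is to count how often each distinct value occurs and build the groups as count vectors, always choosing them in non-increasing lexicographic order. Each multiset of subsets has exactly one such ordering, so each one is produced exactly once. This is a minimal sketch assuming `int` values; `MultisetPartitions`, `addParts` and `choosePart` are hypothetical names. Some branches dead-end (a group is chosen that leaves values no smaller group can hold); they simply produce no output.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

public class MultisetPartitions {
    // Returns every multiset of subsets that uses all elements, each exactly once.
    public static List<List<List<Integer>>> partitions(int[] elts) {
        // Count occurrences of each distinct value (TreeMap keeps the values sorted).
        TreeMap<Integer, Integer> counts = new TreeMap<>();
        for (int e : elts) {
            counts.merge(e, 1, Integer::sum);
        }
        int[] values = new int[counts.size()];
        int[] remaining = new int[counts.size()];
        int k = 0;
        for (Map.Entry<Integer, Integer> entry : counts.entrySet()) {
            values[k] = entry.getKey();
            remaining[k] = entry.getValue();
            k++;
        }
        List<List<List<Integer>>> out = new ArrayList<>();
        addParts(values, remaining, null, new ArrayList<>(), out);
        return out;
    }

    // Adds groups until no elements remain; 'bound' is the previous group.
    private static void addParts(int[] values, int[] remaining, int[] bound,
                                 List<int[]> parts, List<List<List<Integer>>> out) {
        if (Arrays.stream(remaining).allMatch(c -> c == 0)) {
            // Convert each count vector back into a sorted list of values.
            List<List<Integer>> result = new ArrayList<>();
            for (int[] p : parts) {
                List<Integer> group = new ArrayList<>();
                for (int i = 0; i < p.length; i++) {
                    for (int c = 0; c < p[i]; c++) {
                        group.add(values[i]);
                    }
                }
                result.add(group);
            }
            out.add(result);
            return;
        }
        choosePart(values, remaining, bound, new int[remaining.length], 0, parts, out);
    }

    // Enumerates every count vector 'part' with 0 <= part[i] <= remaining[i] that is
    // non-empty and not lexicographically greater than the previous group.
    private static void choosePart(int[] values, int[] remaining, int[] bound, int[] part,
                                   int idx, List<int[]> parts, List<List<List<Integer>>> out) {
        if (idx == part.length) {
            boolean empty = Arrays.stream(part).allMatch(c -> c == 0);
            if (empty || (bound != null && Arrays.compare(part, bound) > 0)) {
                return;
            }
            int[] chosen = part.clone();
            for (int i = 0; i < chosen.length; i++) {
                remaining[i] -= chosen[i];
            }
            parts.add(chosen);
            addParts(values, remaining, chosen, parts, out);
            parts.remove(parts.size() - 1);          // undo the choice
            for (int i = 0; i < chosen.length; i++) {
                remaining[i] += chosen[i];
            }
            return;
        }
        for (int c = 0; c <= remaining[idx]; c++) {
            part[idx] = c;
            choosePart(values, remaining, bound, part, idx + 1, parts, out);
        }
        part[idx] = 0;
    }

    public static void main(String[] args) {
        for (List<List<Integer>> p : partitions(new int[] {1, 2, 1, 2})) {
            System.out.println(p);
        }
    }
}
```

For {1, 2, 1, 2} this yields the nine multisets listed in the question, each once. `Arrays.compare(int[], int[])` requires Java 9 or later.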

Thanks in advance!


Solution

  • This is the "power set" problem, generalized to more than the two usual sets. (In the usual version each element is either included in or excluded from a subset, hence the power set has 2^N elements.) In your version of the problem, any element could be part of any one of up to N subsets, so a brute-force search has to examine N^N assignments, which gets big very quickly.

    To find all the unique partitionings into this "N-way power set" given a list of N elements, you conceptually generate all N-digit numbers in base N; the value of each digit then gives the partition index for the corresponding element of the input. This means you will usually end up with some empty partitions (in every case except when the number of non-empty partitions equals N). Use these digit indices to group the elements that share the same index into lists, producing a list of lists.

    To detect duplicates, sort each sublist, then sort the list of lists, and add the sorted list of lists to a set. With this brute-force approach you can't avoid this last deduplication step, since your description allows duplicated elements in the input (and even with distinct elements, base-N numbers that differ only by a relabeling of the partition indices produce the same partitioning).
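
One way to shrink the brute-force loop (a variation, not the code below): instead of all N^N digit assignments, enumerate restricted growth strings, where element j either joins one of the groups already opened by earlier elements or opens exactly one new group. That visits each set partition of positions once (Bell(N) of them, e.g. 15 instead of 256 for N = 4), so the sort-and-deduplicate step is still needed for repeated values but runs on far fewer candidates. `RestrictedGrowth`, `assign` and `compareLists` are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class RestrictedGrowth {
    // Collects the distinct partitions of elts as canonical (sorted) lists of lists.
    public static Set<List<List<Integer>>> partitions(int[] elts) {
        Set<List<List<Integer>>> result = new HashSet<>();
        assign(elts, 0, new ArrayList<>(), result);
        return result;
    }

    // Places elts[j] into each existing group in turn, or into one new group.
    private static void assign(int[] elts, int j, List<List<Integer>> groups,
                               Set<List<List<Integer>>> result) {
        if (j == elts.length) {
            // Canonical form: sort each group, then sort the groups lexicographically.
            List<List<Integer>> canon = new ArrayList<>();
            for (List<Integer> g : groups) {
                List<Integer> sorted = new ArrayList<>(g);
                Collections.sort(sorted);
                canon.add(sorted);
            }
            canon.sort(RestrictedGrowth::compareLists);
            result.add(canon);
            return;
        }
        for (int g = 0; g <= groups.size(); g++) {
            boolean isNew = g == groups.size();
            if (isNew) {
                groups.add(new ArrayList<>());
            }
            groups.get(g).add(elts[j]);
            assign(elts, j + 1, groups, result);
            List<Integer> group = groups.get(g);
            group.remove(group.size() - 1);          // undo: remove by index
            if (isNew) {
                groups.remove(groups.size() - 1);
            }
        }
    }

    // Lexicographic comparison; a proper prefix sorts first.
    private static int compareLists(List<Integer> x, List<Integer> y) {
        for (int i = 0; i < Math.min(x.size(), y.size()); i++) {
            int c = Integer.compare(x.get(i), y.get(i));
            if (c != 0) {
                return c;
            }
        }
        return Integer.compare(x.size(), y.size());
    }
}
```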

    package main;
    
    import java.util.ArrayList;
    import java.util.Collections;
    import java.util.HashSet;
    import java.util.List;
    import java.util.Set;
    
    public class PrintPartitionings {
        /** A list of integers */
        public static class Partition extends ArrayList<Integer>
                implements Comparable<Partition> {
            // Lexicographic comparator
            @Override
            public int compareTo(Partition other) {
                for (int i = 0, ii = Math.min(this.size(),
                        other.size()); i < ii; i++) {
                    int c = this.get(i).compareTo(other.get(i));
                    if (c != 0) {
                        return c;
                    }
                }
                return Integer.compare(this.size(), other.size());
            }
        }
    
        /** A list of lists of integers */
        public static class Partitioning extends ArrayList<Partition>
                implements Comparable<Partitioning> {
            public Partitioning() {
                super();
            }
    
            public Partitioning(int N) {
                super(N);
                // Pre-allocate sub-lists for convenience
                for (int j = 0; j < N; j++) {
                    add(new Partition());
                }
            }
    
            // Lexicographic comparator
            @Override
            public int compareTo(Partitioning other) {
                for (int i = 0, ii = Math.min(this.size(),
                        other.size()); i < ii; i++) {
                    int c = this.get(i).compareTo(other.get(i));
                    if (c != 0) {
                        return c;
                    }
                }
                return Integer.compare(this.size(), other.size());
            }
        }
    
        /** Print all unique partitionings of the passed array of integers */
        public static void printPartitionings(int[] elts) {
            int N = elts.length;
            Set<Partitioning> setOfPartitionings = new HashSet<>();
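            // Generate integers in [0, N^N); N^N is computed with exact long
            // arithmetic because Math.pow's double result is inexact for N = 15
            long limit = 1;
            for (int k = 0; k < N; k++) {
                limit *= N;
            }
            for (long i = 0; i < limit; i++) {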
                // Create empty partitioning
                Partitioning partitioning = new Partitioning(N);
    
                // Assign each element to a partition based on base N digit
                long digits = i;
                for (int j = 0; j < N; j++) {
                    int digit = (int) (digits % N);
                    digits /= N;
                    partitioning.get(digit).add(elts[j]);
                }
    
                // Sort individual partitions, and remove empty partitions
                Partitioning partitioningSorted = new Partitioning();
                for (Partition partition : partitioning) {
                    if (!partition.isEmpty()) {
                        Collections.sort(partition);
                        partitioningSorted.add(partition);
                    }
                }
    
                // Sort the partitioning
                Collections.sort(partitioningSorted);
    
                // Add the result to the final set of partitionings
                setOfPartitionings.add(partitioningSorted);
            }
    
            // Sort lexicographically to make it easier to view the result
            List<Partitioning> setOfPartitioningsSorted = new ArrayList<>(
                    setOfPartitionings);
            Collections.sort(setOfPartitioningsSorted);
            for (Partitioning partitioning : setOfPartitioningsSorted) {
                System.out.println(partitioning);
            }
        }
    
        public static void main(String[] args) {
            printPartitionings(new int[] { 1, 2, 1, 2 });
        }
    }
    

    The implementation is not particularly optimized; it could be made faster in a number of ways. Also, the code as shown only works for moderately sized problems, where N^N < Long.MAX_VALUE, i.e. the maximum value of N is 15 (but you won't want to run problems of that size anyway, since the code would take forever to run).
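
The N = 15 limit can be checked with exact long arithmetic rather than `Math.pow`, whose `double` result cannot represent 15^15 exactly. This small sketch (the class name `MaxN` is hypothetical) uses `Math.multiplyExact`, which throws `ArithmeticException` on overflow, to find the largest N whose N^N still fits in a `long`.

```java
public class MaxN {
    // Returns N^N, or -1 if it does not fit in a long.
    public static long powOrMinusOne(int n) {
        long result = 1;
        try {
            for (int k = 0; k < n; k++) {
                result = Math.multiplyExact(result, (long) n);
            }
        } catch (ArithmeticException overflow) {
            return -1;
        }
        return result;
    }

    // Largest N such that N^N <= Long.MAX_VALUE (N^N grows monotonically).
    public static int largestN() {
        int n = 1;
        while (powOrMinusOne(n + 1) != -1) {
            n++;
        }
        return n;
    }

    public static void main(String[] args) {
        System.out.println(largestN());        // prints 15
        System.out.println(powOrMinusOne(15)); // prints 437893890380859375
    }
}
```

15^15 = 437893890380859375 fits, while 16^16 = 2^64 exceeds `Long.MAX_VALUE`.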

    Output:

    [[1], [1], [2], [2]]
    [[1], [1], [2, 2]]
    [[1], [1, 2], [2]]
    [[1], [1, 2, 2]]
    [[1, 1], [2], [2]]
    [[1, 1], [2, 2]]
    [[1, 1, 2], [2]]
    [[1, 1, 2, 2]]
    [[1, 2], [1, 2]]
    

    For input { 1, 2, 3, 2 } the output is:

    [[1], [2], [2], [3]]
    [[1], [2], [2, 3]]
    [[1], [2, 2], [3]]
    [[1], [2, 2, 3]]
    [[1, 2], [2], [3]]
    [[1, 2], [2, 3]]
    [[1, 2, 2], [3]]
    [[1, 2, 2, 3]]
    [[1, 2, 3], [2]]
    [[1, 3], [2], [2]]
    [[1, 3], [2, 2]]