
C++ Remove duplicates in a set of lists


I'm trying to remove duplicates from the list returned by my solution to this problem:

Given a collection of candidate numbers (C) and a target number (T), find all unique combinations in C where the candidate numbers sum to T.

Each number in C may only be used once in the combination.

Note:

  1. All numbers (including target) will be positive integers.

  2. Elements in a combination (a1, a2, … , ak) must be in non-descending order (i.e., a1 ≤ a2 ≤ … ≤ ak).

  3. The solution set must not contain duplicate combinations.

For example, given the candidate set 10, 1, 2, 7, 6, 1, 5 and target 8, a solution set is:

[1, 7] 
[1, 2, 5] 
[2, 6] 
[1, 1, 6] 

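My question is: how can I remove the duplicates efficiently? Here is my code:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;

public class Solution {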
    public static void main(String[] args) {
        int[] input = { 10, 1, 2, 7, 6, 1, 5 };
        // int[] input = { 2, 1, 1, 4, 4, 2 };
        System.out.println(combinationSum2(input, 8));
    }

    private static class Entry {
        List<Integer> list;
        int target;
        int index; // the previous index

        public Entry(int target) {
            list = new LinkedList<Integer>();
            this.target = target;
        }

        public int add(int num, int index) {
            this.list.add(num);
            this.index = index;
            this.target -= num;
            return target;
        }

        public Entry copy() {
            Entry copy = new Entry(this.target);
            copy.list = new ArrayList<>();
            copy.list.addAll(list);
            copy.target = target;
            copy.index = index;
            return copy;
        }

    }

    public static List<List<Integer>> combinationSum2(int[] input, int target) {
        List<List<Integer>> ret = new LinkedList<List<Integer>>();

        if (null == input || input.length <= 0)
            return ret;

        Arrays.sort(input);

        int N = input.length;
        Queue<Entry> pool = new LinkedList<Entry>();
        for (int i = 0; i < N; i++) {
            if (input[i] <= target) {
                Entry entry = new Entry(target);
                entry.add(input[i], i);
                pool.add(entry);
            }
        }

        while (!pool.isEmpty()) {
            Entry cur = pool.poll();
            if (cur.target == 0) {
                ret.add(cur.list);
            } else if (cur.target > 0) {
                for (int i = cur.index + 1; i < N; i++) {
                    if (cur.target - input[i] >= 0) {
                        Entry copy = cur.copy();
                        copy.add(input[i], i);
                        pool.offer(copy);
                    } else {
                        break;
                    }
                }
            }
        }

        return ret;
    }
}

My first idea is to sort the lists in the returned list, then compare them one by one to remove duplicates. Is there a faster way, or any other suggestion?


Solution

  • My suggestion is to use a HashSet so that an entry that already exists is not added again. The first step is to override the equals and hashCode methods of your Entry class.

    private static class Entry {
        List<Integer> list;
        int target;
        int index;
        int hash; // <---- add this
    
        public Entry(int target) {
            list = new LinkedList<Integer>();
            this.target = target;
            hash = target;
        }
    
        public int add(int num, int index) {
            this.list.add(num);
            this.index = index;
            this.target -= num;
            hash = hash * 17 + num;
            return target;
        }
    
        public Entry copy() {
            Entry copy = new Entry(this.target);
            copy.list = new ArrayList<>();
            copy.list.addAll(list);
            copy.target = target;
            copy.index = index;
            copy.hash = hash;
            return copy;
        }
    
        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            if (!(obj instanceof Entry))
                return false; // also covers null
            Entry e = (Entry) obj;
            // List.equals compares the elements one by one, in order
            return this.target == e.target && this.list.equals(e.list);
        }
    
        @Override
        public int hashCode() {
            return hash;
        }
    }
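
    To see why both methods matter: a HashSet uses hashCode to pick a bucket and then equals to confirm a match, so two entries holding the same numbers must agree on both. Below is a minimal sketch of that behaviour. Combo and countDistinct are hypothetical names made up for illustration; Combo is a stripped-down stand-in for Entry that keeps only the numbers and the running hash, not a replacement for the class above.

```java
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class HashContractSketch {
    // Hypothetical stand-in for Entry: keeps only the numbers and a running hash.
    static final class Combo {
        final List<Integer> nums = new ArrayList<>();
        int hash = 0;

        Combo add(int num) {
            nums.add(num);
            hash = hash * 17 + num; // same update rule as in Entry.add
            return this;
        }

        @Override
        public boolean equals(Object obj) {
            return obj instanceof Combo && nums.equals(((Combo) obj).nums);
        }

        @Override
        public int hashCode() {
            return hash;
        }
    }

    // Counts distinct combos by letting a HashSet apply hashCode and equals.
    static int countDistinct(List<Combo> combos) {
        Set<Combo> set = new HashSet<>(combos);
        return set.size();
    }

    public static void main(String[] args) {
        List<Combo> combos = new ArrayList<>();
        combos.add(new Combo().add(1).add(7));
        combos.add(new Combo().add(1).add(7)); // duplicate by value
        combos.add(new Combo().add(2).add(6));
        System.out.println(countDistinct(combos)); // prints 2
    }
}
```

    If equals were overridden without hashCode, the two [1, 7] objects would most likely land in different buckets and both would be kept.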
    

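    The next step is to use a HashSet to filter the results. Note that a HashSet does not preserve insertion order, so the combinations may come out in a different order than they were found.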

    Set<Entry> nodup = new HashSet<Entry>();
    
    while (!pool.isEmpty()) {
        Entry cur = pool.poll();
        if (cur.target == 0) {
            nodup.add(cur);
        } else if (cur.target > 0) {
            // ... your code
        }
    }
    
    for (Entry entry : nodup) {
        ret.add(entry.list);
    }
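
    A possibly simpler variation on the same idea, offered as a sketch rather than a drop-in fix: java.util.List already defines equals and hashCode element by element, so the finished lists can be put into a set directly, without overriding anything on Entry. This assumes every combination is already in non-descending order, which holds in the question's code because the input is sorted first. DedupSketch and dedupe are hypothetical names, and LinkedHashSet is used here only to keep first-seen order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

public class DedupSketch {
    // Hypothetical helper: drops repeated combinations, keeping first-seen order.
    static List<List<Integer>> dedupe(List<List<Integer>> combos) {
        // List.equals and List.hashCode work element by element,
        // so equal combinations collapse into a single set member.
        Set<List<Integer>> seen = new LinkedHashSet<>(combos);
        return new ArrayList<>(seen);
    }

    public static void main(String[] args) {
        // Sample data with repeats, as can arise when the input holds two 1s.
        List<List<Integer>> raw = Arrays.asList(
                Arrays.asList(1, 7),
                Arrays.asList(1, 7),    // same values, built from the other 1
                Arrays.asList(2, 6),
                Arrays.asList(1, 1, 6),
                Arrays.asList(1, 2, 5),
                Arrays.asList(1, 2, 5));
        System.out.println(dedupe(raw)); // prints [[1, 7], [2, 6], [1, 1, 6], [1, 2, 5]]
    }
}
```

    In combinationSum2 this would amount to collecting cur.list into such a set when cur.target == 0 and copying the set into ret at the end.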