Tags: python, merge, tree, set-intersection, equivalence-classes

Python: simple list merging based on intersections


Suppose there are some lists of integers, such as:

#--------------------------------------
0 [0,1,3]
1 [1,0,3,4,5,10,...]
2 [2,8]
3 [3,1,0,...]
...
n []
#--------------------------------------

The task is to merge all lists that have at least one element in common. For the part shown above, the result would be:

#--------------------------------------
0 [0,1,3,4,5,10,...]
2 [2,8]
#--------------------------------------
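Note that merging is transitive: [1, 2] and [3, 4] share nothing, but they must still end up together if [2, 3] is also present. Seen this way, the merged lists are the connected components of a graph whose nodes are the lists and whose edges join lists that share an element. A minimal breadth-first-search sketch of that reading follows; `merge_bfs` is a hypothetical name, and it assumes hashable elements, skips empty lists, and returns each group sorted, in order of first appearance.

```python
from collections import defaultdict, deque

def merge_bfs(lists):
    """Group lists that are linked, directly or through other lists, by shared elements."""
    # index: element -> positions of the lists that contain it
    owners = defaultdict(list)
    for idx, lst in enumerate(lists):
        for x in lst:
            owners[x].append(idx)

    seen = [False] * len(lists)
    result = []
    for start, lst in enumerate(lists):
        if seen[start] or not lst:
            continue  # already grouped, or empty
        seen[start] = True
        component = set()
        queue = deque([start])
        while queue:
            idx = queue.popleft()
            for x in lists[idx]:
                if x in component:
                    continue  # the lists holding x were already queued
                component.add(x)
                # every list containing x belongs to this same group
                for other in owners[x]:
                    if not seen[other]:
                        seen[other] = True
                        queue.append(other)
        result.append(sorted(component))
    return result
```

Each element is expanded once per list that contains it, so the work grows roughly linearly with the total number of elements instead of with the number of list pairs.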

What is the most efficient way to do this on large data (the elements are just numbers)? Is a tree structure worth considering? I currently do the job by converting the lists to sets and iterating over their intersections, but it is slow! I also have a feeling that the problem is very elementary. In addition, my implementation is missing something (I don't know what), because some lists sometimes remain unmerged! If you propose implementing it yourself, please be generous and provide simple sample code (Python is apparently my favorite :)) or pseudo-code.
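A disjoint-set (union-find) forest is one tree structure that fits this problem: every number points toward a representative root, and numbers that appear in the same list are linked under a common root. The following is a minimal sketch, not a tuned implementation: `merge_union_find` is a hypothetical name, it assumes hashable elements and sequence-like input lists, and it uses path halving but omits union by rank for brevity.

```python
def merge_union_find(lists):
    """Merge lists sharing an element using a disjoint-set (union-find) forest."""
    parent = {}

    def find(x):
        # walk up to the root, halving the path on the way
        while parent[x] != x:
            parent[x] = parent[parent[x]]
            x = parent[x]
        return x

    def union(a, b):
        ra, rb = find(a), find(b)
        if ra != rb:
            parent[rb] = ra

    for lst in lists:
        for x in lst:
            parent.setdefault(x, x)  # every new element starts as its own root
        # linking each element to the list's first element joins the whole list
        for x in lst[1:]:
            union(lst[0], x)

    # collect the elements under their final roots
    groups = {}
    for x in parent:
        groups.setdefault(find(x), []).append(x)
    return [sorted(g) for g in groups.values()]
```

Empty lists contribute nothing, and the order of the returned groups follows the first appearance of their root, so compare results in sorted form.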
Update 1: Here is the code I was using:

#--------------------------------------
lsts = [[0,1,3],
        [1,0,3,4,5,10,11],
        [2,8],
        [3,1,0,16]];
#--------------------------------------

The function is (buggy!!):

#--------------------------------------
def merge(lsts):
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i+1
        while j < len(sts):
            if len(sts[i].intersection(sts[j])) > 0:
                sts[i] = sts[i].union(sts[j])
                sts.pop(j)
            else: j += 1                        #---corrected
        i += 1
    lst = [list(s) for s in sts]
    return lst
#--------------------------------------

The result is:

#--------------------------------------
>>> merge(lsts)
[[0, 1, 3, 4, 5, 10, 11, 16], [8, 2]]
#--------------------------------------
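A minimal input that triggers the bug in the function above is [[1], [2], [1, 2]]. With i = 0, {1} is first compared with {2} (disjoint, so j moves on) and then merged with {1, 2}; the grown set {1, 2} is never compared with {2} again, so the function returns [[1, 2], [2]]. The sketch below shows one possible fix that keeps the original structure and only restarts the j scan after every merge. `merge_fixed` is a hypothetical name, and this targets correctness, not speed.

```python
def merge_fixed(lsts):
    """The question's merge with one change: restart the j scan after every merge."""
    sts = [set(l) for l in lsts]
    i = 0
    while i < len(sts):
        j = i + 1
        while j < len(sts):
            if not sts[i].isdisjoint(sts[j]):
                sts[i] |= sts[j]
                sts.pop(j)
                # sts[i] just grew, so sets skipped earlier in this scan
                # may overlap it now: rescan from i + 1
                j = i + 1
            else:
                j += 1
        i += 1
    return [list(s) for s in sts]
```

After the inner loop finishes for a given i, sts[i] is disjoint from every later set, and later merges only combine sets that are already disjoint from it, so the final sets are pairwise disjoint.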

Update 2: In my experience, the code given by Niklas Baumstark below was a bit faster for the simple cases. I have not tested the method given by "Hooked" yet, since it is a completely different approach (which looks interesting, by the way). Testing all of these methods thoroughly enough to be sure of the results could be very hard or even impossible. The real data set I will use is so large and complex that it is impossible to trace any error just by rerunning. So I need to be 100% confident in the reliability of the method before putting it in place as a module within a large code base. For now, Niklas's method is faster, and of course its answer for simple sets is correct.
However, how can I be sure that it works well on a real, large data set, given that I will not be able to check for errors visually?
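One common way to gain confidence without inspecting results by eye is randomized cross-checking: run the fast function and a slow but obviously correct reference on many small random inputs and compare the outputs in an order-independent form. Small value ranges force many overlaps and long merge chains, which is where merge bugs tend to show up. This does not prove correctness, but it catches bugs such as the one in the question's original function. A minimal sketch, where `reference_merge`, `canonical` and `cross_check` are hypothetical names:

```python
import random

def reference_merge(lists):
    """Slow but easy to trust: keep merging any two overlapping groups until none overlap."""
    groups = [set(l) for l in lists if l]
    changed = True
    while changed:
        changed = False
        for i in range(len(groups)):
            for j in range(i + 1, len(groups)):
                if not groups[i].isdisjoint(groups[j]):
                    groups[i] |= groups[j]
                    del groups[j]
                    changed = True
                    break
            if changed:
                break  # indices shifted: restart the scan
    return groups

def canonical(groups):
    """Order-independent form: sorted list of sorted groups, empty groups dropped."""
    return sorted(sorted(set(g)) for g in groups if g)

def cross_check(merge_fn, trials=500, seed=0):
    """Return the first random input on which merge_fn disagrees with reference_merge, or None."""
    rng = random.Random(seed)
    for _ in range(trials):
        # small value range => many overlaps and long merge chains
        lists = [[rng.randrange(20) for _ in range(rng.randrange(5))]
                 for _ in range(rng.randrange(1, 8))]
        expected = canonical(reference_merge(lists))
        got = canonical(merge_fn([list(l) for l in lists]))  # pass a copy
        if got != expected:
            return lists
    return None
```

Any candidate can then be checked with, for example, `cross_check(merge)`; a non-`None` result is a small counterexample that can be traced by hand.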

Update 3: Note that, for this problem, the reliability of the method is much more important than speed. Eventually, I hope to translate the Python code to Fortran for maximum performance.
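If the numbers are non-negative integers with a known upper bound, union-find can be written with nothing but flat integer arrays and plain loops, which should map fairly directly onto Fortran arrays. The sketch below rests on that assumption (every element in 0..max_value); `merge_labels` is a hypothetical name. It avoids hash tables and recursion, always links the larger root under the smaller one, and returns a label array in which each used value maps to the smallest value of its group (and unused values map to -1).

```python
def merge_labels(lists, max_value):
    """Array-only union-find. Assumes every element is an integer in 0..max_value.

    Returns labels where labels[x] is the smallest value in x's group, or -1 if x is unused.
    """
    parent = list(range(max_value + 1))  # parent[x] == x means x is a root
    used = [False] * (max_value + 1)

    def find(x):
        root = x
        while parent[root] != root:
            root = parent[root]
        # path compression: point every node on the path straight at the root
        while parent[x] != root:
            parent[x], x = root, parent[x]
        return root

    for lst in lists:
        for x in lst:
            used[x] = True
        for x in lst[1:]:
            a, b = find(lst[0]), find(x)
            if a != b:
                # the smaller root wins, so a root is always its group's minimum
                if a < b:
                    parent[b] = a
                else:
                    parent[a] = b

    return [find(x) if used[x] else -1 for x in range(max_value + 1)]
```

The label array itself is often enough downstream; grouping values by label recovers the merged lists.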

Update 4:
There are many interesting points in this post, in the generously given answers, and in the constructive comments. I recommend reading all of them thoroughly. Please accept my appreciation for the development of the question, the amazing answers, and the constructive comments and discussion.


Solution

  • My attempt:

    def merge(lsts):
        sets = [set(lst) for lst in lsts if lst]
        merged = True
        while merged:
            merged = False
            results = []
            while sets:
                common, rest = sets[0], sets[1:]
                sets = []
                for x in rest:
                    if x.isdisjoint(common):
                        sets.append(x)
                    else:
                        merged = True
                        common |= x
                results.append(common)
            sets = results
        return sets
    
    lst = [[65, 17, 5, 30, 79, 56, 48, 62],
           [6, 97, 32, 93, 55, 14, 70, 32],
           [75, 37, 83, 34, 9, 19, 14, 64],
           [43, 71],
           [],
           [89, 49, 1, 30, 28, 3, 63],
           [35, 21, 68, 94, 57, 94, 9, 3],
           [16],
           [29, 9, 97, 43],
           [17, 63, 24]]
    print(merge(lst))
    

    Benchmark:

    import random
    
    # adapt parameters to your own usage scenario
    class_count = 50
    class_size = 1000
    list_count_per_class = 100
    large_list_sizes = list(range(100, 1000))
    small_list_sizes = list(range(0, 100))
    large_list_probability = 0.5
    
    if False:  # change to True to generate the test data file (takes a while)
        with open("/tmp/test.txt", "w") as f:
            lists = []
            classes = [
                range(class_size * i, class_size * (i + 1)) for i in range(class_count)
            ]
            for c in classes:
                # distribute each class across list_count_per_class lists
                for i in range(list_count_per_class):
                    lst = []
                    if random.random() < large_list_probability:
                        size = random.choice(large_list_sizes)
                    else:
                        size = random.choice(small_list_sizes)
                    nums = set(c)
                    for j in range(size):
                        x = random.choice(list(nums))
                        lst.append(x)
                        nums.remove(x)
                    random.shuffle(lst)
                    lists.append(lst)
            random.shuffle(lists)
            for lst in lists:
                f.write(" ".join(str(x) for x in lst) + "\n")
    
    setup = """
    # Niklas'
    def merge_niklas(lsts):
        sets = [set(lst) for lst in lsts if lst]
        merged = 1
        while merged:
            merged = 0
            results = []
            while sets:
                common, rest = sets[0], sets[1:]
                sets = []
                for x in rest:
                    if x.isdisjoint(common):
                        sets.append(x)
                    else:
                        merged = 1
                        common |= x
                results.append(common)
            sets = results
        return sets
    
    # Rik's
    def merge_rik(data):
        sets = (set(e) for e in data if e)
        results = [next(sets)]
        for e_set in sets:
            to_update = []
            for i, res in enumerate(results):
                if not e_set.isdisjoint(res):
                    to_update.insert(0, i)
    
            if not to_update:
                results.append(e_set)
            else:
                last = results[to_update.pop(-1)]
                for i in to_update:
                    last |= results[i]
                    del results[i]
                last |= e_set
        return results
    
    # katrielalex's
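    def pairs(lst):
        i = iter(lst)
        try:
            first = prev = item = next(i)
        except StopIteration:
            return  # empty list: no edges (a leaked StopIteration is an error since PEP 479)
        for item in i:
            yield prev, item
            prev = item
        yield item, first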
    
    import networkx
    
    def merge_katrielalex(lsts):
        g = networkx.Graph()
        for lst in lsts:
            for edge in pairs(lst):
                g.add_edge(*edge)
        return list(networkx.connected_components(g))  # newer networkx returns a lazy generator
    
    # agf's (optimized)
    from collections import deque
    
    def merge_agf_optimized(lists):
        sets = deque(set(lst) for lst in lists if lst)
        results = []
        disjoint = 0
        current = sets.pop()
        while True:
            merged = False
            newsets = deque()
            for _ in range(disjoint, len(sets)):
                this = sets.pop()
                if not current.isdisjoint(this):
                    current.update(this)
                    merged = True
                    disjoint = 0
                else:
                    newsets.append(this)
                    disjoint += 1
            if sets:
                newsets.extendleft(sets)
            if not merged:
                results.append(current)
                try:
                    current = newsets.pop()
                except IndexError:
                    break
                disjoint = 0
            sets = newsets
        return results
    
    # agf's (simple)
    def merge_agf_simple(lists):
        newsets, sets = [set(lst) for lst in lists if lst], []
        while len(sets) != len(newsets):
            sets, newsets = newsets, []
            for aset in sets:
                for eachset in newsets:
                    if not aset.isdisjoint(eachset):
                        eachset.update(aset)
                        break
                else:
                    newsets.append(aset)
        return newsets
    
    # alexis'
    def merge_alexis(data):
        bins = list(range(len(data)))  # Initialize each bin[n] == n (must be a mutable list)
        nums = dict()
    
        data = [set(m) for m in data]  # Convert to sets
        for r, row in enumerate(data):
            for num in row:
                if num not in nums:
                    # New number: tag it with a pointer to this row's bin
                    nums[num] = r
                    continue
                else:
                    dest = locatebin(bins, nums[num])
                    if dest == r:
                        continue  # already in the same bin
    
                    if dest > r:
                        dest, r = r, dest  # always merge into the smallest bin
    
                    data[dest].update(data[r])
                    data[r] = None
                    # Update our indices to reflect the move
                    bins[r] = dest
                    r = dest
    
        # Filter out the empty bins
        have = [m for m in data if m]
        return have
    
    def locatebin(bins, n):
        while bins[n] != n:
            n = bins[n]
        return n
    
    lsts = []
    size = 0
    num = 0
    max = 0
    for line in open("/tmp/test.txt", "r"):
        lst = [int(x) for x in line.split()]
        size += len(lst)
        if len(lst) > max:
            max = len(lst)
        num += 1
        lsts.append(lst)
    """
    
    setup += """
    print("%i lists, {class_count} equally distributed classes, average size %i, max size %i" % (num, size // num, max))
    """.format(class_count=class_count)
    
    import timeit
    print("niklas")
    print(timeit.timeit("merge_niklas(lsts)", setup=setup, number=3))
    print("rik")
    print(timeit.timeit("merge_rik(lsts)", setup=setup, number=3))
    print("katrielalex")
    print(timeit.timeit("merge_katrielalex(lsts)", setup=setup, number=3))
    print("agf (1)")
    print(timeit.timeit("merge_agf_optimized(lsts)", setup=setup, number=3))
    print("agf (2)")
    print(timeit.timeit("merge_agf_simple(lsts)", setup=setup, number=3))
    print("alexis")
    print(timeit.timeit("merge_alexis(lsts)", setup=setup, number=3))
    

    These timings obviously depend on the specific benchmark parameters, such as the number of classes, the number of lists and the list sizes. Adapt those parameters to your needs to get more useful results.

    Below are some example outputs on my machine for different parameters. They show that every algorithm has its strengths and weaknesses, depending on the kind of input it gets. The fastest algorithm for each parameter set is marked with =>:

    =====================
    # many disjoint classes, large lists
    class_count = 50
    class_size = 1000
    list_count_per_class = 100
    large_list_sizes = list(range(100, 1000))
    small_list_sizes = list(range(0, 100))
    large_list_probability = 0.5
    =====================
    
    niklas
    5000 lists, 50 equally distributed classes, average size 298, max size 999
    4.80084705353
    rik
    5000 lists, 50 equally distributed classes, average size 298, max size 999
    9.49251699448
    katrielalex
    5000 lists, 50 equally distributed classes, average size 298, max size 999
    21.5317108631
    agf (1)
    5000 lists, 50 equally distributed classes, average size 298, max size 999
    8.61671280861
    agf (2)
    5000 lists, 50 equally distributed classes, average size 298, max size 999
    5.18117713928
    => alexis
    => 5000 lists, 50 equally distributed classes, average size 298, max size 999
    => 3.73504281044
    
    ===================
    # fewer classes, large lists
    class_count = 15
    class_size = 1000
    list_count_per_class = 300
    large_list_sizes = list(range(100, 1000))
    small_list_sizes = list(range(0, 100))
    large_list_probability = 0.5
    ===================
    
    niklas
    4500 lists, 15 equally distributed classes, average size 296, max size 999
    1.79993700981
    rik
    4500 lists, 15 equally distributed classes, average size 296, max size 999
    2.58237695694
    katrielalex
    4500 lists, 15 equally distributed classes, average size 296, max size 999
    19.5465381145
    agf (1)
    4500 lists, 15 equally distributed classes, average size 296, max size 999
    2.75445604324
    => agf (2)
    => 4500 lists, 15 equally distributed classes, average size 296, max size 999
    => 1.77850699425
    alexis
    4500 lists, 15 equally distributed classes, average size 296, max size 999
    3.23530197144
    
    ===================
    # fewer classes, smaller lists
    class_count = 15
    class_size = 1000
    list_count_per_class = 300
    large_list_sizes = list(range(100, 1000))
    small_list_sizes = list(range(0, 100))
    large_list_probability = 0.1
    ===================
    
    niklas
    4500 lists, 15 equally distributed classes, average size 95, max size 997
    0.773697137833
    rik
    4500 lists, 15 equally distributed classes, average size 95, max size 997
    1.0523750782
    katrielalex
    4500 lists, 15 equally distributed classes, average size 95, max size 997
    6.04466891289
    agf (1)
    4500 lists, 15 equally distributed classes, average size 95, max size 997
    1.20285701752
    => agf (2)
    => 4500 lists, 15 equally distributed classes, average size 95, max size 997
    => 0.714507102966
    alexis
    4500 lists, 15 equally distributed classes, average size 95, max size 997
    1.1286110878
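Because the results depend so much on the parameters, it can help to generate the test data in memory and time plain functions directly, instead of going through /tmp/test.txt and a setup string. The sketch below is a Python 3 variant of the benchmark's data generator under the same parameter names; `make_lists` and `time_merges` are hypothetical names, the `seed` argument is an addition for reproducibility, and it assumes no list size exceeds class_size (otherwise `random.sample` raises ValueError). `random.sample` draws distinct values in random order, which matches what the original loop of choose-and-remove followed by a shuffle produces.

```python
import random
import timeit

def make_lists(class_count=50, class_size=1000, list_count_per_class=100,
               large_list_sizes=range(100, 1000), small_list_sizes=range(0, 100),
               large_list_probability=0.5, seed=0):
    """Build the benchmark's random lists in memory.

    Class i owns the values class_size*i .. class_size*(i+1)-1, and every list
    draws distinct values from a single class. Sizes must not exceed class_size.
    """
    rng = random.Random(seed)
    lists = []
    for i in range(class_count):
        members = range(class_size * i, class_size * (i + 1))
        for _ in range(list_count_per_class):
            if rng.random() < large_list_probability:
                size = rng.choice(large_list_sizes)
            else:
                size = rng.choice(small_list_sizes)
            lists.append(rng.sample(members, size))  # distinct values, random order
    rng.shuffle(lists)
    return lists

def time_merges(functions, lists, number=3):
    """Return {name: total seconds for `number` runs} for each merge function."""
    return {name: timeit.timeit(lambda: fn(lists), number=number)
            for name, fn in functions.items()}
```

After copying the merge functions out of the setup string into a module, a run might look like `time_merges({"niklas": merge_niklas, "alexis": merge_alexis}, make_lists(class_count=15, list_count_per_class=300))`. Like `timeit.timeit`, each value is the total over `number` runs.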