Tags: python, algorithm, hashtable, big-o

4SUM variation in quadratic complexity (Python 3.5)


I've been having a bit of a problem with a variation of the 4SUM problem. Essentially, we are required to choose one number from each of 4 tuples a, b, c, d of 500 integers each (say i from a, j from b, k from c and l from d) such that i+j+k+l == 0. These integers range from -20000 to 20000. Currently, my code has a time complexity of O(n³), I believe. According to my instructor, this time complexity can be reduced further, to either O(n² log n) or O(n²) (by using hash tables or something).
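For reference, the O(n² log n) bound the instructor mentions can be reached without a hash table: sort the n² sums of c and d once, then binary-search for the complement of every sum of a and b. The sketch below is one way to do that under those assumptions; the function name `find_zero_quadruple_sorted` is hypothetical, and `itertools.product` replaces explicit nested loops.

```python
from bisect import bisect_left
from itertools import product


def find_zero_quadruple_sorted(a, b, c, d):
    """Return (i, j, k, l) with i + j + k + l == 0, or None.

    O(n^2 log n): sorting the n^2 sums of c and d costs O(n^2 log n),
    and each of the n^2 binary searches costs O(log n).
    """
    # Each entry is (k + l, k, l); tuples sort by the sum first.
    cd = sorted((k + l, k, l) for k, l in product(c, d))
    cd_sums = [s for s, _, _ in cd]
    for i, j in product(a, b):
        target = -(i + j)
        pos = bisect_left(cd_sums, target)
        # bisect_left gives an insertion point, so confirm an exact match.
        if pos < len(cd_sums) and cd_sums[pos] == target:
            _, k, l = cd[pos]
            return i, j, k, l
    return None


print(find_zero_quadruple_sorted((1, 2, 3, 4), (5, 6, 7, 8),
                                 (9, 10, 11, 12), (13, 14, 15, -24)))
# prints (4, 8, 12, -24)
```

The hash-table version drops the log factor because a dict lookup replaces the binary search.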

Unfortunately, if the hash table method is the way to go, I am not sure how to implement it. Hence, if anyone could show me how to code this program in Python 3.5, I would be most grateful. (PS. Please do try to use as few nested loops as possible; none would be ideal.)

I have attached my code below for reference. If it is possible to modify my current code to reduce its time complexity, please do inform me as well.

import collections.abc

import itertools


def selection(a, b, c, d):
    """program has 4 main parts,
        firstly, triple_sum finds possible 3sums in the list abc
        secondly, is_iterable ensures that inputs of length 1 (i.e. a bare int, which is not iterable) are made iterable
        thirdly, main function determines if there exist possible a, b, c in abc that correspond to each d
        fourthly, main function checks that each of the 3 integers from triple_sum comes from a different array"""

    # use sort O(n log n) to sort input array, then find possible 3sums in O(n^2)
    def triple_sum(a, res):
        a.sort()
        positions = collections.defaultdict(set)
        for i, n in enumerate(a):
            positions[n].add(i)
        for (i, ai), (j, aj) in itertools.combinations(enumerate(a), 2):
            n = res - ai - aj
            # some index other than i and j must hold the value n
            if positions[n].difference((i, j)):
                return n, ai, aj
        return None  # no triple sums to res

    # Ensure that all inputs are iterable
    def is_iterable(x):
        # collections.Iterable was removed in Python 3.10; collections.abc works from 3.3 on
        if isinstance(x, collections.abc.Iterable):
            return x
        else:
            return x,

    a, b, c, d = is_iterable(a), is_iterable(b), is_iterable(c), is_iterable(d)
    # convert each part to a list so that tuples and lists can be concatenated
    abc = list(a) + list(b) + list(c)

    # find value of d which has corresponding a, b, c
    # and return appropriate value if conditions are met
    ans_a, ans_b, ans_c, ans_d = 0, 0, 0, 0
    for i in d:
        x = 0 - i
        j = triple_sum(abc, x)
        if j is None:
            # no triple in abc sums to -i, so try the next value of d
            continue
        if j[0] in a and j[1] in b and j[2] in c:
            ans_a, ans_b, ans_c, ans_d = j[0], j[1], j[2], i
            break
        elif j[0] in a and j[2] in b and j[1] in c:
            ans_a, ans_b, ans_c, ans_d = j[0], j[2], j[1], i
            break
        elif j[1] in a and j[0] in b and j[2] in c:
            ans_a, ans_b, ans_c, ans_d = j[1], j[0], j[2], i
            break
        elif j[1] in a and j[2] in b and j[0] in c:
            ans_a, ans_b, ans_c, ans_d = j[1], j[2], j[0], i
            break
        elif j[2] in a and j[0] in b and j[1] in c:
            ans_a, ans_b, ans_c, ans_d = j[2], j[0], j[1], i
            break
        elif j[2] in a and j[1] in b and j[0] in c:
            ans_a, ans_b, ans_c, ans_d = j[2], j[1], j[0], i
            break

    return ans_a, ans_b, ans_c, ans_d
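The six `elif` branches above test every ordering of the triple returned by `triple_sum`. `itertools.permutations` yields those same six orderings in the same order, so the chain could be collapsed into a loop like the one below (`assign_triple` is a hypothetical helper name). This is only a tidy-up: calling an O(n²) `triple_sum` once per element of `d` still leaves the whole function at O(n³).

```python
from itertools import permutations


def assign_triple(triple, a, b, c):
    """Return (x, y, z) with x in a, y in b, z in c, taken from some
    ordering of triple, or None if no ordering fits."""
    # permutations((j0, j1, j2)) yields (j0, j1, j2), (j0, j2, j1),
    # (j1, j0, j2), (j1, j2, j0), (j2, j0, j1), (j2, j1, j0):
    # the same order as the elif chain.
    for x, y, z in permutations(triple):
        if x in a and y in b and z in c:
            return x, y, z
    return None


print(assign_triple((9, 1, 5), (1, 2), (5, 6), (9, 10)))
# prints (1, 5, 9)
```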

Thanks in advance :)

PS. If anyone needs more clarification or information do let me know.


Solution

  • Theory

    • Iterate over all the (i,j) pairs and create a dict with i+j as key and (i,j) as value

    • Iterate over all the (k,l) pairs and create a dict with -(k+l) as key and (k,l) as value.

    • Iterate over the keys of your first dict, and check if the second dict also has this key. In that case, sum((i,j,k,l)) will be 0.

    Each step is O(n²) on average (there are n² pairs, and dict insertion and lookup take O(1) on average), so the whole algorithm is O(n²), where n is the size of the tuples.
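    The three steps can also be written almost literally, using -(k+l) as the key of the second dict and intersecting the two key views. This is a sketch assuming one solution is enough; the function name `find_zero_quadruple` is hypothetical, and `itertools.product` stands in for the nested loops.

```python
from itertools import product


def find_zero_quadruple(a, b, c, d):
    """Return one (i, j, k, l) with i + j + k + l == 0, or None.

    Expected O(n^2) time and memory, since dict insertion and lookup
    are O(1) on average.
    """
    # Step 1: i + j -> (i, j). A later pair with the same sum overwrites
    # an earlier one, which is fine because one witness is enough.
    ab = {i + j: (i, j) for i, j in product(a, b)}
    # Step 2: -(k + l) -> (k, l), so a shared key means the sums cancel.
    cd = {-(k + l): (k, l) for k, l in product(c, d)}
    # Step 3: keys views support set operations; any common key works.
    common = ab.keys() & cd.keys()
    if not common:
        return None
    key = next(iter(common))
    return ab[key] + cd[key]


print(find_zero_quadruple([1, 2, 3, 4], [5, 6, 7, 8],
                          [9, 10, 11, 12], [13, 14, 15, -24]))
# prints (4, 8, 12, -24)
```

    When several sums match, which solution comes back depends on set iteration order, so it is effectively arbitrary.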

    Code

    a = [1, 2, 3, 4]
    b = [5, 6, 7, 8]
    c = [9, 10, 11, 12]
    d = [13, 14, 15, -24]
    
    
    def pairs_and_sums(list1, list2):
        return {(x + y): (x, y) for x in list1 for y in list2}
    
    
    first_half = pairs_and_sums(a, b)
    second_half = pairs_and_sums(c, d)
    
    for i_plus_j in first_half:
        if -i_plus_j in second_half:
            i, j = first_half[i_plus_j]
            k, l = second_half[-i_plus_j]
            print("Solution found!")
            print("sum((%d,%d,%d,%d)) == 0" % (i, j, k, l))
            break
    else:
        # runs only if the loop finished without hitting break
        print("No solution found")
    

    It outputs:

    Solution found!
    sum((4,8,12,-24)) == 0
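A single example does not show that the dict approach never misses a solution. One way to gain some confidence is to compare it against an O(n⁴) brute force on many small random inputs. The sketch below wraps the answer's logic as a function; `hash_join` and `brute_force` are hypothetical names, and the narrow value range of -50..50 is an assumption chosen so that zero-sum quadruples occur often enough to exercise both branches.

```python
import random
from itertools import product


def hash_join(a, b, c, d):
    """The answer's pair-sum dict search, wrapped as a function."""
    first_half = {x + y: (x, y) for x, y in product(a, b)}
    second_half = {x + y: (x, y) for x, y in product(c, d)}
    for s, pair in first_half.items():
        if -s in second_half:
            return pair + second_half[-s]
    return None


def brute_force(a, b, c, d):
    """O(n^4) reference: try every quadruple."""
    for q in product(a, b, c, d):
        if sum(q) == 0:
            return q
    return None


random.seed(0)
for _ in range(200):
    # Small tuples keep the O(n^4) reference fast.
    n = random.randint(1, 6)
    a, b, c, d = (tuple(random.randint(-50, 50) for _ in range(n))
                  for _ in range(4))
    found = hash_join(a, b, c, d)
    # Both searches must agree on whether any solution exists ...
    assert (found is None) == (brute_force(a, b, c, d) is None)
    # ... and a returned quadruple must come from the right tuples.
    if found is not None:
        i, j, k, l = found
        assert i in a and j in b and k in c and l in d
        assert sum(found) == 0
print("all checks passed")
```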