
Using itertools in Python to generate sums while also keeping track of each element of the sum?


I have 2 lists of numbers. For each list, I need to calculate all possible sums, then compare the sums generated to find matches. I need to output the sum and the elements from each list that comprise that sum. I want to retain all combinations if there was more than one possible combination to reach a given sum.

Example inputs:

a = [5, 12.5, 20]
b = [4, 13.5, 20]

Desired Output:

x = [17.5, 20, 37.5]                        # sums that matched, sorted ascending
a1 = [(5, 12.5), (20,), (5, 12.5, 20)]      # elements of a in each matching sum
b1 = [(4, 13.5), (20,), (4, 13.5, 20)]      # elements of b in each matching sum

Here is what I've tried so far. It finds all possible combinations, then all possible sums, and compares them using a NumPy array. This seems overcomplicated, and it only works if a and b are the same length, which they may not be.

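import numpy as np
import itertools

a = [5, 12.5, 20]
b = [4, 13.5, 20]

# every combination of a, longest first
acombos = [seq for i in range(len(a), 0, -1)
           for seq in itertools.combinations(a, i)]

asums = list(map(sum, acombos))  # list() is needed in Python 3, where map is lazy

bcombos = [seq for i in range(len(b), 0, -1)
           for seq in itertools.combinations(b, i)]

bsums = list(map(sum, bcombos))

# dtype=object because the combination tuples have different lengths
comboarray = np.array([acombos, asums, bcombos, bsums], dtype=object)

# keep the columns whose a-sum equals the b-sum, then sort them by sum
solutionarray = comboarray[:, comboarray[1] == comboarray[3]]
solutionarray = solutionarray[:, np.argsort(solutionarray[1])]

print(solutionarray)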
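Both the array construction and the row comparison above assume the two lists of combinations line up column by column: `comboarray[1] == comboarray[3]` only compares the i-th sum of `a` with the i-th sum of `b`. A minimal sketch of one way to keep NumPy but compare every sum of `a` against every sum of `b` is to use broadcasting instead. The helper names `all_combos` and `matching_sums_numpy` are hypothetical, and `np.isclose` is swapped in for `==` as an assumption, to tolerate floating-point rounding in the sums.

```python
import itertools

import numpy as np


def all_combos(nums):
    """Every non-empty combination of nums, shortest first, as tuples."""
    return [c for k in range(1, len(nums) + 1)
            for c in itertools.combinations(nums, k)]


def matching_sums_numpy(a, b):
    """Return (sum, a_combo, b_combo) for every pair of combos with equal sums."""
    acombos, bcombos = all_combos(a), all_combos(b)
    asums = np.array([sum(c) for c in acombos], dtype=float)
    bsums = np.array([sum(c) for c in bcombos], dtype=float)
    # Broadcasting gives a (len(acombos), len(bcombos)) boolean matrix,
    # so a and b no longer need to be the same length.
    hits = np.isclose(asums[:, None], bsums[None, :])
    ai, bi = np.nonzero(hits)
    # Sort the matched pairs by their sum; "stable" keeps ties in row-major order.
    order = np.argsort(asums[ai], kind="stable")
    return [(float(asums[i]), acombos[i], bcombos[j])
            for i, j in zip(ai[order], bi[order])]


print(matching_sums_numpy([5, 12.5, 20], [4, 13.5, 20]))
```

The comparison matrix has (2**len(a) - 1) * (2**len(b) - 1) entries, so this only suits short input lists; a dictionary keyed by sum avoids building the full matrix.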

Solution

  • For each input list, make a dictionary keyed by sum. Each key maps to a list of the combinations (tuples) that add up to that sum.

    Then find the intersection of the two sets of sums, and for each common sum print the combinations from each list that add up to it.

    import itertools
    import collections
    
    
    def sums_dict(nums):
        d = collections.defaultdict(list)
        for k in range(1, len(nums) + 1):
            for c in itertools.combinations(nums, k):
                d[sum(c)].append(c)
        return d
    
    def print_matching_sums(a, b):
        ad = sums_dict(a)
        bd = sums_dict(b)
        sums = sorted(set(ad).intersection(bd))
        # The tuples are only to make the output a little easier to read
        print('sums = {}'.format(tuple(sums)))
        print('a = {}'.format(tuple(ad[s] for s in sums)))
        print('b = {}'.format(tuple(bd[s] for s in sums)))
    
    
    a = [5, 12.5, 20]
    b = [4, 13.5, 20]
    print_matching_sums(a, b)
    

    Output:

    sums = (17.5, 20, 37.5)
    a = ([(5, 12.5)], [(20,)], [(5, 12.5, 20)])
    b = ([(4, 13.5)], [(20,)], [(4, 13.5, 20)])
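To get the result in the `x`, `a1`, `b1` shape from the question, the same dictionary idea can return lists instead of printing. The sketch below is an adaptation, not the answer's original code: the names `rounded_sums_dict` and `matching_sums` and the `ndigits` parameter are hypothetical. It also rounds each sum before using it as a key, so that sums such as `0.1 + 0.2` and `0.3`, which differ only by floating-point error, still match. Rounding is a heuristic; two sums that straddle a rounding boundary can still land under different keys.

```python
import collections
import itertools


def rounded_sums_dict(nums, ndigits=9):
    """Map each rounded subset sum to every combination that produces it."""
    d = collections.defaultdict(list)
    for k in range(1, len(nums) + 1):
        for c in itertools.combinations(nums, k):
            # Round so that sums differing only by float error share a key.
            d[round(sum(c), ndigits)].append(c)
    return d


def matching_sums(a, b, ndigits=9):
    """Return (x, a1, b1): common sums ascending, plus the combos behind each."""
    ad = rounded_sums_dict(a, ndigits)
    bd = rounded_sums_dict(b, ndigits)
    x = sorted(ad.keys() & bd.keys())  # dict key views support set operations
    a1 = [ad[s] for s in x]  # each entry is a list: several combos may share a sum
    b1 = [bd[s] for s in x]
    return x, a1, b1


x, a1, b1 = matching_sums([5, 12.5, 20], [4, 13.5, 20])
print(x)   # [17.5, 20, 37.5]
print(a1)  # [[(5, 12.5)], [(20,)], [(5, 12.5, 20)]]
print(b1)  # [[(4, 13.5)], [(20,)], [(4, 13.5, 20)]]
```

Each element of `a1` and `b1` is a list rather than a single tuple, which is what keeps every combination when more than one reaches the same sum.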