I'm trying to come up with an efficient means to generate all subsets of a set of integers, such that the subsets exceed some value N but do not contain any superfluous members.
That is to say, once a set exceeds N, no additional members should be added. Once you've identified a subset that exceeds N, that subset should not be included in any subsequent subsets.
Take for example the following set:
[1, 2, 5, 1, 3]
For a value of N = 6, the solution is:
[5, 2] [5, 1, 1] [5, 3] [3, 2, 1, 1]
These are all the sets that can be constructed such that no set contains a member that is unneeded to exceed N.
[5, 2, 1] is not included, for example, because [5, 2] already exceeds N, and thus any new set which also includes [5, 2] as a subset is redundant.
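To pin the requirement down, here is a brute-force sketch of the definition. It is only meant as a reference for small inputs, not as the efficient solution being asked for. The function name `minimal_subsets_bruteforce` is made up for illustration, and duplicate values are assumed to be interchangeable (so the two 1s give a single [5, 1, 1]). A subset qualifies when its sum exceeds N and removing its smallest member brings the sum back to N or below.

```python
from itertools import combinations

def minimal_subsets_bruteforce(nums, n):
    """Return every distinct sub-multiset of nums whose sum exceeds n
    and which has no removable member (hypothetical reference helper)."""
    found = set()
    for size in range(1, len(nums) + 1):
        for combo in combinations(nums, size):
            total = sum(combo)
            # If dropping the smallest member still leaves a sum above n,
            # that member is superfluous and the subset is rejected.
            if total > n and total - min(combo) <= n:
                # Sort so that duplicates collapse to a single entry.
                found.add(tuple(sorted(combo, reverse=True)))
    return sorted(found)

print(minimal_subsets_bruteforce([1, 2, 5, 1, 3], 6))
# [(3, 2, 1, 1), (5, 1, 1), (5, 2), (5, 3)]
```

This inspects all 2^n combinations, so it is only useful for cross-checking a faster approach on short lists.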
I have the following code, which can be used to generate all subsets which exceed some value N:

    from collections import Counter

    def solve(nums, target):
        counts = sorted(Counter(nums).items())
        reserve = sum(nums) - target
        if reserve <= 0:
            return []
        return list(_solve(counts, reserve, []))

    def _solve(counts, reserve, prefix):
        if not counts:
            yield tuple(prefix)
            return
        val, max_count = last = counts.pop()
        prefix.extend([val] * max_count)
        yield from _solve(counts, reserve, prefix)
        for count in range(1, max_count + 1):
            prefix.pop()
            if reserve - count * val > 0:
                yield from _solve(counts, reserve - count * val, prefix)
        counts.append(last)

However, I'm not sure how to modify this in a way that efficiently eliminates or avoids calculating the unnecessary subsets (the ones that exceed N and include a member not required to do so).
Here is a different way to solve this problem.

    def find_subset_sums_greater_than_n(lst, n):
        lst.sort(reverse=True)
        all_subsets, subset = [], []
        for i in range(len(lst)):
            subset_sum = lst[i]
            subset.append(lst[i])
            for j in range(i + 1, len(lst)):
                if subset_sum > n:
                    if len(all_subsets) == 0 or subset != all_subsets[-1]:
                        all_subsets.append(subset[:])
                    if len(subset) > 1:
                        subset_sum -= subset[-1]
                        subset.pop()
                    else:
                        break  # stop inner loop if lst[i] > N
                subset.append(lst[j])
                subset_sum += lst[j]
            if subset_sum > n and (len(all_subsets) == 0 or subset != all_subsets[-1]):
                all_subsets.append(subset)
            if sum(lst[i + 1:]) <= n:  # Stop early if sum of remaining elements doesn't exceed n
                break
            subset = []
        return all_subsets

Runtime Analysis:
The sorting takes O(n log n), and the outer for loop runs n times in the worst case. For each iteration of the outer loop, the inner loop runs at most n times and the sum function takes at most n steps, so the total for the algorithm is O(n^2), where n is the length of the list.
The only two improvements are as follows:
First, when N is larger than the sum of several elements at the end of the list, the outer loop breaks early. Using your example, for N = 8 the algorithm stops once lst[i] = 3, since 3+2+1+1 < 8, so we save checking [3], [3,2], [3,2,1], ..., [2], [2,1], ... and so on.
Second, when N is smaller than several elements at the beginning of the list, the inner loop breaks early: such an element already exceeds N on its own, so we can skip all subsets that contain more than just that number. Using your example, for N = 2 we get [5] and [3], and then skip all subsets of length greater than 1 that contain 5 or 3.
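The two early exits can also be located up front, which makes the examples easy to check. The sketch below is illustrative only and not part of the function above: `pruning_bounds` is a hypothetical helper that returns the leading elements that exceed N on their own, plus the first start index whose remaining elements cannot exceed N. It precomputes suffix sums once rather than calling sum in every iteration.

```python
from itertools import accumulate

def pruning_bounds(lst, n):
    """Return (singles, cutoff) for lst sorted in descending order.

    singles: elements that exceed n by themselves.
    cutoff:  first index i with sum(desc[i:]) <= n, or len(lst) if none.
    """
    desc = sorted(lst, reverse=True)
    # suffix[i] == sum(desc[i:]), built in a single O(n) pass
    suffix = list(accumulate(reversed(desc)))[::-1]
    singles = [x for x in desc if x > n]
    cutoff = next((i for i, s in enumerate(suffix) if s <= n), len(desc))
    return singles, cutoff

print(pruning_bounds([1, 2, 5, 1, 3], 8))  # ([], 1)
print(pruning_bounds([1, 2, 5, 1, 3], 2))  # ([5, 3], 3)
```

For N = 8 the cutoff is index 1, the position of 3 in the sorted list, which matches 3+2+1+1 < 8. For N = 2 the singles are 5 and 3.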
Despite these improvements, the runtime is still O(n^2), since we can choose an N that makes this algorithm approach the worst-case scenario.
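For comparison, the same two pruning ideas (stop extending once N is exceeded, and abandon a start when the remaining elements cannot exceed N) can be written as a backtracking search. This is a rough sketch under stated assumptions, not the function above: it assumes all values are positive integers, treats equal values as interchangeable, and the name `minimal_subsets` is hypothetical. Its running time depends on how many subsets it has to output.

```python
from collections import Counter

def minimal_subsets(nums, n):
    """Yield each distinct sub-multiset of positive integers whose sum
    exceeds n and from which no member can be removed (sketch)."""
    items = sorted(Counter(nums).items(), reverse=True)  # largest value first
    # remaining[i] == total of every copy in items[i:], used for pruning
    remaining = [0] * (len(items) + 1)
    for i in range(len(items) - 1, -1, -1):
        val, cnt = items[i]
        remaining[i] = remaining[i + 1] + val * cnt

    def extend(i, total, chosen):
        # Invariant: total <= n. Abandon the branch if even taking every
        # remaining copy cannot push the sum past n.
        if total + remaining[i] <= n:
            return
        val, cnt = items[i]
        # Option 1: use no copies of this value.
        yield from extend(i + 1, total, chosen)
        # Option 2: use 1..cnt copies, stopping at the first that crosses n.
        added = 0
        for _ in range(cnt):
            chosen.append(val)
            added += 1
            total += val
            if total > n:
                # Every chosen member is >= val, so removing any one of
                # them brings the sum back to n or below.
                yield tuple(chosen)
                break
            yield from extend(i + 1, total, chosen)
        # Undo this level's choices before returning to the caller.
        del chosen[len(chosen) - added:]

    yield from extend(0, 0, [])

print(list(minimal_subsets([1, 2, 5, 1, 3], 6)))
# [(3, 2, 1, 1), (5, 1, 1), (5, 2), (5, 3)]
```

Processing values from largest to smallest means the most recently added value is always the smallest in the subset, so only that one removal needs to be checked.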