Tags: python, algorithm, combinations

Get the number of all possible combinations that give a certain product


I'm looking for an algorithm that counts all possible combinations that give a certain product. I have a list of perfect squares [1, 4, 9, 16, ..., n] and two values, a and b, where a is the number of elements multiplied together to get the perfect square and b is the maximum value of each element. For example, if a = 3 and b = 6, then for the perfect square 36 we can have combinations such as [1,6,6], [2,3,6], [4,3,3] and so on (order matters: [1,6,6] and [6,6,1] are different). NOTE: we cannot use the combination [1,9,4] because 9 > b.

I've tried using itertools to generate all combinations of the divisors of each perfect square, and then checked the product of each combination: if x1 * x2 * x3 == 36, I added 1 to my count for the perfect square 36. This algorithm works, but it takes a significant amount of time for long multiplications.

Can we make it faster than looking at each combination for each perfect square?

import itertools
import numpy

def get_divisors(n):
    result = []
    for i in range(1, n//2 + 1):
        if n % i == 0:
            result.append(i)
    result.append(n)
    return result

a = 2
b = 3
count = 0
all_squares = [1,4,9]
for i in all_squares:
    # Build a filtered list instead of calling remove() while
    # iterating, which silently skips elements.
    divisors = [d for d in get_divisors(i) if d <= b]
    for j in itertools.product(divisors, repeat=a):
        if numpy.prod(j) == i:
            count += 1
print(count)

Solution

  • Here's a mostly straightforward solution using recursive generators. No element larger than b shows up because no element so large is ever considered. Duplicates never show up because, by construction, this yields lists in non-increasing order of elements (i.e., in lexicographically reverse order).

    I don't know what squares have to do with this. The function couldn't care less whether target is a square, and you didn't appear to make any use of that property in what you wrote.

    def solve(target, a, b):
        for largest in range(min(target, b), 1, -1):
            result = []
            rem = target
            # Try one or more factors of `largest`, followed by
            # solutions for what remains using factors strictly less
            # than `largest`.
            while not rem % largest and len(result) < a:
                rem //= largest
                assert rem >= 1
                result.append(largest)
                if rem == 1:
                    yield result + [1] * (a - len(result))
                else:
                    for sub in solve(rem,
                                     a - len(result),
                                     largest - 1):
                        yield result + sub
    

    and then, e.g.,

    >>> for x in solve(36, 3, 6):
    ...     print(x)
    [6, 3, 2]
    [6, 6, 1]
    [4, 3, 3]
    >>> for x in solve(720, 4, 10):
    ...     print(x)
    [10, 9, 8, 1]
    [10, 9, 4, 2]
    [10, 8, 3, 3]
    [10, 6, 4, 3]
    [10, 6, 6, 2]
    [9, 8, 5, 2]
    [9, 5, 4, 4]
    [8, 6, 5, 3]
    [6, 6, 5, 4]
    

    If target can grow large, the easiest "optimization" to this would be to precompute all its non-unit factors <= b, and restrict the "for largest in ..." loop to look only at those.

    To just count the number, use the above like so:

    >>> sum(1 for x in solve(36, 3, 6))
    3
    >>> sum(1 for x in solve(720, 4, 10))
    9
    

    EDIT

    BTW, as things go on, mounds of futile search can be saved by adding this at the start:

        if b ** a < target:
            return
    

    However, whether that speeds or slows things overall depends on the expected characteristics of the inputs.
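
The two tweaks mentioned above (restricting the loop to precomputed factors, and the b ** a < target early exit) can be combined. The following is only a sketch, not code from the answer: the names solve_pre and rec are made up for illustration. It relies on the fact that every remainder divides the original target, so any usable factor of a remainder is also a factor of the target. The factor list itself is built by plain trial division here, which is an assumption made for brevity.

```python
def solve_pre(target, a, b):
    """Yield the same canonical (non-increasing) solutions as solve(),
    but only ever try divisors of `target`. Hypothetical helper."""
    # Non-unit divisors of target that are <= b, largest first.
    factors = [d for d in range(min(target, b), 1, -1) if target % d == 0]

    def rec(rem_target, slots, start):
        # Only factors[start:] may be used here. factors[start] is the
        # largest of them, so factors[start] ** slots bounds the product.
        if start >= len(factors) or factors[start] ** slots < rem_target:
            return
        for idx in range(start, len(factors)):
            largest = factors[idx]
            result = []
            rem = rem_target
            while rem % largest == 0 and len(result) < slots:
                rem //= largest
                result.append(largest)
                if rem == 1:
                    # Pad the unused slots with 1s.
                    yield result + [1] * (slots - len(result))
                else:
                    # Continue with strictly smaller factors only.
                    for sub in rec(rem, slots - len(result), idx + 1):
                        yield result + sub

    yield from rec(target, a, 0)
```

Passing an index into the shared factor list plays the role of the `largest - 1` argument in the original: it keeps each solution in non-increasing order without rescanning non-divisors.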

    OPTIMIZING

    Running the above:

    >>> sum(1 for x in solve(720000000000, 20, 10000))
    4602398
    

    That took close to 30 seconds on my box. What if we changed the function to just return the count, but not the solutions?

    def solve(target, a, b):
        count = 0
        for largest in range(min(target, b), 1, -1):
            nres = 0
            rem = target
            while not rem % largest and nres < a:
                rem //= largest
                assert rem >= 1
                nres += 1
                if rem == 1:
                    count += 1
                else:
                    sub = solve(rem,
                                a - nres,
                                largest - 1)
                    count += sub
        return count
    

    Returns the same total, but took closer to 20 seconds.

    Now what if we memoized the function? Just add two lines before the def:

    import functools
    @functools.cache
    

    Doesn't change the result, but cuts the time to under half a second.

    >>> solve.cache_info()
    CacheInfo(hits=459534, misses=33755, maxsize=None, currsize=33755)
    

    So the bulk of the recursive calls were duplicates of calls already resolved, and were satisfied by the hidden cache without executing the body of the function. OTOH, we're burning memory for a hidden dict holding about 32K entries.
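
For reference, here is the counting function with the decorator in place, written out in full. This is a sketch: the name solve_count is used only to keep it apart from the generator version, and functools.cache needs Python 3.9 or later (functools.lru_cache(maxsize=None) is the equivalent on older versions). The memory held by the hidden dict can be released with cache_clear() once the counts are no longer needed.

```python
import functools

@functools.cache
def solve_count(target, a, b):
    """Count canonical (non-increasing) solutions; memoized on all
    three arguments. Hypothetical name for the counting variant."""
    count = 0
    for largest in range(min(target, b), 1, -1):
        nres = 0
        rem = target
        # Use `largest` one or more times, then recurse on what is
        # left with strictly smaller factors.
        while not rem % largest and nres < a:
            rem //= largest
            nres += 1
            if rem == 1:
                count += 1
            else:
                count += solve_count(rem, a - nres, largest - 1)
    return count

print(solve_count(36, 3, 6))      # same count as the generator version
print(solve_count.cache_info())   # hits/misses accumulated so far
solve_count.cache_clear()         # drop the cached entries to free memory
```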

    Season to taste. As a comment noted, the usual alternative to memoizing a recursive function is to do heavier thinking to come up with a "dynamic programming" approach that builds possible results "from the bottom up", typically using explicit lists instead of hidden dicts.

    I won't do that here, though. It's already thousands of times faster than I have any actual use for ;-)

    Computing # of permutations from a canonical solution

    Here's code for that:

    from math import factorial, prod
    from collections import Counter
    def numperms(canonical):
        c = Counter(canonical)
        return (factorial(c.total())
                // prod(map(factorial, c.values())))
    
    >>> numperms([1, 6, 6])
    3
    >>> numperms("mississippi")
    34650
    

    Counting distinct permutations

    This may be what you really want. The code is simpler because it's not being "clever" at all, treating all positions exactly the same (so, e.g., all 3 distinct ways of permuting [6, 6, 1] are considered to be different).

    You absolutely need to memoize this if you want it to complete a run in your lifetime ;-)

    import functools
    @functools.cache
    def dimple(target, a, b):
        if target == 1:
            return 1
        if a <= 0:
            return 0
        count = 0
        for i in range(1, min(target, b) + 1):
            if not target % i:
                count += dimple(target // i, a - 1, b)
        return count
    
    >>> dimple(720000000000, 20, 10000)
    1435774778817558060
    >>> dimple.cache_info()
    CacheInfo(hits=409566, misses=8788, maxsize=None, currsize=8788)
    >>> count = 0
    >>> for x in solve(720000000000, 20, 10000):
    ...     count += numperms(x)
    >>> count
    1435774778817558060
    

    The second way uses the original function, and applies the earlier numperms() code to deduce how many distinct permutations each canonical solution corresponds to. It doesn't benefit from caching, though, so is very much slower. The first way took several seconds under CPython, but about 1 second under PyPy.

    Of course the universe will end before any approach using itertools to consider all possible arrangements of divisors could compute a result so large.
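
For tiny inputs, though, the itertools approach is still handy as a cross-check. The sketch below repeats dimple() from above and compares it with a brute-force count. The name brute_force is made up for illustration, and the loop is only feasible when b ** a is small.

```python
import functools
import itertools
from math import prod

@functools.cache
def dimple(target, a, b):
    # Same function as above: ordered ways to write `target` as a
    # product of exactly `a` factors, each in 1..b.
    if target == 1:
        return 1
    if a <= 0:
        return 0
    count = 0
    for i in range(1, min(target, b) + 1):
        if not target % i:
            count += dimple(target // i, a - 1, b)
    return count

def brute_force(target, a, b):
    # Try every ordered a-tuple drawn from 1..b (b ** a tuples).
    return sum(1 for combo in itertools.product(range(1, b + 1), repeat=a)
               if prod(combo) == target)

# Compare both counts on the small perfect squares from the question.
for square in (1, 4, 9, 16, 25, 36):
    assert dimple(square, 3, 6) == brute_force(square, 3, 6)
print(dimple(36, 3, 6))
```

For 36 with a = 3 and b = 6, both give 12: six orderings of [6, 3, 2] plus three each for [6, 6, 1] and [4, 3, 3].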

    Moving to DP

    For pedagogical purposes, I think it's worth the effort to show a dynamic programming approach too. The result here is easily the fastest of the bunch, which is typical of a DP approach. It's also typical - alas - that it takes more up-front analysis.

    This is much like recursive memoization, but "bottom up", starting with the simplest cases and then iteratively building them up to fancier cases. But we're not waiting for runtime recursion to figure out which simpler cases are needed to achieve a fancier result - that's all analyzed in advance.

    In the dimple() code, b is invariant. What remains can be viewed as entries in a large array, M[i, a], which gives the number of ways to express integer i as a product of a factors (all <= b).

    There is only one non-zero entry in "row" 0: M[1, 0] = 1. The empty product (1) is always achievable.

    For row 1, what can we obtain in one step? Well, every divisor of target can be reduced to the row 0 case by dividing it by itself. So row 1 has a non-zero entry for every divisor of target <= b, with value 1 (there's only one way to get it).

    Row 2? Consider, e.g., M[6, 2]. 6 can be gotten from row 1 via multiplying 1 by 6, or 6 by 1, or 2 by 3, or 3 by 2, so M[6, 2] = M[1, 1] + M[2, 1] + M[3, 1] + M[6, 1] = 4.

    And so on. With the right preparation, the body of the main loop is very simple.
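
To make the rows concrete, here is a deliberately naive sketch that keeps every row as a plain dict so they can be inspected. The name build_rows is made up for illustration, and divisors are found by trial division, so this is only meant for small targets.

```python
def build_rows(target, a, b):
    """Return [row 0, row 1, ..., row a], where row k maps a divisor
    of `target` to the number of ordered ways to write it as a
    product of k factors, each <= b. Illustrative helper only."""
    small = [d for d in range(1, min(target, b) + 1) if target % d == 0]
    row = {1: 1}          # row 0: only the empty product
    rows = [row]
    for _ in range(a):
        newrow = {}
        for value, ways in row.items():
            for s in small:
                mult = value * s
                # Only divisors of target can ever lead to target.
                if target % mult == 0:
                    newrow[mult] = newrow.get(mult, 0) + ways
        row = newrow
        rows.append(row)
    return rows

rows = build_rows(6, 2, 6)
print(rows[1])      # every divisor of 6 reachable in exactly one way
print(rows[2][6])   # M[6, 2]
```

With target 6 and b = 6, row 1 holds the value 1 for each of the divisors 1, 2, 3 and 6, and rows[2][6] is 4, matching the sum worked out above.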

    Note that each row i's values depend only on row i-1, so saving the whole array isn't needed. The code only saves the most recent row, and builds the next row from it. In fact, it's often possible to update to the next row "in place", so that memory for only one row is needed. But offhand I didn't see an efficient way to do that in this particular case.

    In addition, since only divisors of target are achievable, most of a dense M would consist of 0 entries. So, instead, a row is represented as a defaultdict(int), so storage is only needed for the non-zero entries.

    Most of the code here is in utility routines to precompute the possible divisors.

    EDIT: simplified the inner loop by precomputing invariants even before the outer loop starts.

    from itertools import chain, count, islice
    from collections import Counter, defaultdict
    from math import isqrt
    
    # Return a multiset with n's prime factorization, mapping a prime
    # to its multiplicity; e.g., factor(12) is {2: 2, 3: 1}.
    def factor(n):
        result = Counter()
        s = isqrt(n)
        for p in chain([2], count(3, 2)):
            if p > s:
                break
            if not n % p:
                num = 1
                n //= p
                while not n % p:
                    num += 1
                    n //= p
                result[p] = num
                s = isqrt(n)
        if n > 1:
            result[n] = 1
        return result
    
    # Return a list of all the positive integer divisors of an int. `ms`
    # is the prime factorization of the int, as returned by factor().
    # Nothing is guaranteed about the order of the list. For example,
    # alldivs(factor(12)) is a permutation of [1, 2, 3, 4, 6, 12].
    # There are prod(e + 1 for e in ms.values()) entries in the list.
    def alldivs(ms):
        result = [1]
        for p, e in ms.items():
            # NB: advanced trickery here. `p*v` is applied to entries
            # appended to `result` _while_ the extend() is being
            # executed. This is well defined, but unusual. It saves
            # layers of otherwise-needed explicit indexing and/or loop
            # nesting. For example, if `result` starts as [2, 3] and
            # e=2, this leaves result as [2, 3, p*2, p*3, p*p*2, p*p*3].
            result.extend(p * v
                          for v in islice(result,
                                          len(result) * e))
        return result
    
    def dimple_dp(target, a, b):
        target_ms = factor(target)
        if max(target_ms) > b:
            return 0
        divs = alldivs(target_ms)
        smalldivs = sorted(d for d in divs if d <= b)
    
        # div2mults maps a divisor `d` to a list of all divisors
        # of the form `s*d` where `s` in smalldivs.
        divs = set(divs)
        div2mults = {}
        for div in divs:
            mults = div2mults[div] = []
            for s in smalldivs:
                mult = s * div
                if mult in divs:
                    mults.append(mult)
                elif mult > target:
                    break
        del divs, div, mults, s, mult
    
        # row 0 has 1 entry: 1 way to get the empty product
        row = defaultdict(int)
        row[1] = 1
        # Compute rows 1 through a-1. Since the only entry we want from
        # row `a` is row[target], we save a bit of time by stopping
        # here at row a-1 instead.
        for _ in range(1, a):
            newrow = defaultdict(int)
            for div, count in row.items():
                for mult in div2mults[div]:
                    newrow[mult] += count
            row = newrow
        return sum(row[target // d] for d in smalldivs)
    
    # under a second even in CPython
    >>> dimple_dp(720000000000, 20, 10000)
    1435774778817558060
    

    Note that the amount of memory needed is proportional to the number of divisors of target, and is independent of a. By contrast, when memoizing a recursive function, the cache never forgets anything (unless you implement and manage it yourself "by hand").
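
The extend()-while-iterating idiom used in alldivs() can be seen in isolation in the sketch below, next to a more explicit version that produces the same list. Both function names are made up for illustration. The key point is that len(result) * e is computed once, before extend() starts, so islice() stops after exactly that many items even though the list keeps growing.

```python
from itertools import islice

def extend_by_prime(result, p, e):
    # The generator reads entries that extend() itself has just
    # appended, so each round multiplies the previous round by p.
    result.extend(p * v for v in islice(result, len(result) * e))
    return result

def extend_by_prime_explicit(result, p, e):
    # Same effect with explicit bookkeeping: multiply the block that
    # was added in the previous round (initially the original entries).
    n = len(result)
    for k in range(e):
        block = result[k * n:(k + 1) * n]   # a copy, safe to read
        result.extend(p * v for v in block)
    return result

print(extend_by_prime([2, 3], 5, 2))
print(extend_by_prime_explicit([2, 3], 5, 2))
```

Both calls print [2, 3, 10, 15, 50, 75], which is the [2, 3, p*2, p*3, p*p*2, p*p*3] pattern from the comment with p = 5.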