Tags: python, permutation, combinatorics

Permutations in Python with different options for each position


I'm trying to write a solver for a sudoku-like puzzle, and at some point I need to find all valid fillings of a row. A row must contain each digit from 1 to 6 exactly once, so in general there are 6! = 720 possible rows. But since I have already narrowed down the candidates for each cell, the input looks something like this:

row = [[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 6], [1, 2, 4], [1, 3, 4, 5, 6], [2, 3, 4, 5]]

Or like this, with some numbers already found:

row = [[1, 2, 3], [1, 2, 3], 5, 6, [1, 3, 4], [2, 3, 4]]
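To make the size of the search space concrete, here is a small sketch that normalizes both kinds of input (lists of candidates, and plain ints for solved cells) into sets and multiplies the candidate counts. The helper names normalize_row and product_size are hypothetical, not part of the original code.

```python
from math import prod


def normalize_row(row):
    """Turn every cell into a set of candidates (hypothetical helper).

    A solved cell given as a plain int becomes a one-element set;
    lists, tuples and sets are copied into a new set.
    """
    return [set(c) if isinstance(c, (list, tuple, set)) else {c} for c in row]


def product_size(row):
    """Number of tuples itertools.product would generate for this row."""
    return prod(len(c) for c in normalize_row(row))


row_a = [[1, 2, 3, 4], [1, 2, 3], [1, 2, 3, 6], [1, 2, 4], [1, 3, 4, 5, 6], [2, 3, 4, 5]]
row_b = [[1, 2, 3], [1, 2, 3], 5, 6, [1, 3, 4], [2, 3, 4]]

print(product_size(row_a))  # 4 * 3 * 4 * 3 * 5 * 4 = 2880
print(product_size(row_b))  # 3 * 3 * 1 * 1 * 3 * 3 = 81
```

The first row's product (2880) is larger than 6! = 720 because the product also counts tuples with repeated digits.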

The solution I came up with uses itertools.product to generate every combination, including those with repeated digits, and then filters out every tuple that contains a duplicate:

from itertools import product

def get_permutations(row):
    # normalize every cell into a set: cells that are already solved (plain
    # ints) become one-element sets; lists, tuples and sets are copied
    row = [set(s) if isinstance(s, (set, list, tuple)) else {s} for s in row]

    # full Cartesian product, including tuples with repeated digits;
    # product() returns a lazy iterator, not a list
    prod = product(*row)
    # in this example that is 4 * 3 * 4 * 3 * 5 * 4 = 2880 tuples

    without_duplicates = [value for value in prod if len(value) == len(set(value))]
    # only 32 of those tuples contain no repeated digit
    return without_duplicates

r = [{1, 2, 3, 4}, {1, 2, 3}, {1, 2, 3, 6}, {1, 2, 4}, {1, 3, 4, 5, 6}, {2, 3, 4, 5}]
perm = get_permutations(r)
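The gap between generated and kept tuples can be measured directly. The sketch below (count_filtered is a hypothetical name) walks the same itertools.product iterator as get_permutations but only counts, so it never stores a list; it still has to visit every one of the 2880 tuples, which is exactly the cost in question.

```python
from itertools import product


def count_filtered(sets):
    """Return (generated, kept) for the filter-after-product approach."""
    generated = kept = 0
    for value in product(*sets):
        generated += 1
        # keep only tuples with no repeated digit
        if len(value) == len(set(value)):
            kept += 1
    return generated, kept


r = [{1, 2, 3, 4}, {1, 2, 3}, {1, 2, 3, 6}, {1, 2, 4}, {1, 3, 4, 5, 6}, {2, 3, 4, 5}]
print(count_filtered(r))  # (2880, 32)
```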

The problem with this method is that it generates thousands of tuples only to keep a few dozen. That is acceptable for a 6x6 puzzle, but it becomes far too slow for bigger ones, where the thousands turn into millions. Is there a way to avoid generating the duplicates in the first place?
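A back-of-the-envelope check of how quickly this grows: assume, purely hypothetically, a 9x9 variant in which every cell of a row still has 5 candidates. The full product then has 5**9 tuples, while a row can have at most 9! fillings without repeats, so at least four out of five generated tuples are thrown away.

```python
from math import factorial

# Hypothetical 9x9 row: 9 cells, each with 5 remaining candidates.
cells, candidates_per_cell = 9, 5

product_tuples = candidates_per_cell ** cells  # tuples itertools.product would yield
upper_bound = factorial(cells)                 # at most 9! rows without repeats

print(product_tuples)  # 1953125
print(upper_bound)     # 362880
```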


Solution

  • If you build up the Cartesian product incrementally, one cell at a time, you can discard every partial tuple that repeats a digit before the number of candidates grows exponentially. Here is a snippet that does exactly that:

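    from collections.abc import Iterable

    def uniq_prod(sets):
        # start from a single empty partial row
        result = [()]

        for selection in sets:
            # extend each partial tuple only with candidates it does not
            # already contain, so duplicates are pruned at every step
            result = [tup + (x,) for tup in result for x in selection.difference(tup)]

        return result

    def get_perms_fast(row):
        # guard: uniq_prod([]) would return [()] for an empty row
        if not row:
            return []
        # wrap solved cells (plain ints) in one-element sets
        return uniq_prod([set(s) if isinstance(s, Iterable) else {s} for s in row])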
    

    Running it on your example input:

    >>> row = [{1, 2, 3, 4}, {1, 2, 3}, {1, 2, 3, 6}, {1, 2, 4}, {1, 3, 4, 5, 6}, {2, 3, 4, 5}]
    >>> print(get_perms_fast(row))
    [(1, 2, 3, 4, 6, 5), (1, 2, 6, 4, 3, 5), (1, 2, 6, 4, 5, 3), (1, 3, 2, 4, 6, 5), (1, 3, 6, 2, 4, 5), (1, 3, 6, 2, 5, 4), (1, 3, 6, 4, 5, 2), (2, 1, 3, 4, 6, 5), (2, 1, 6, 4, 3, 5), (2, 1, 6, 4, 5, 3), (2, 3, 1, 4, 6, 5), (2, 3, 6, 1, 4, 5), (2, 3, 6, 1, 5, 4), (2, 3, 6, 4, 1, 5), (3, 1, 2, 4, 6, 5), (3, 1, 6, 2, 4, 5), (3, 1, 6, 2, 5, 4), (3, 1, 6, 4, 5, 2), (3, 2, 1, 4, 6, 5), (3, 2, 6, 1, 4, 5), (3, 2, 6, 1, 5, 4), (3, 2, 6, 4, 1, 5), (4, 1, 3, 2, 6, 5), (4, 1, 6, 2, 3, 5), (4, 1, 6, 2, 5, 3), (4, 2, 3, 1, 6, 5), (4, 2, 6, 1, 3, 5), (4, 2, 6, 1, 5, 3), (4, 3, 1, 2, 6, 5), (4, 3, 2, 1, 6, 5), (4, 3, 6, 1, 5, 2), (4, 3, 6, 2, 1, 5)]
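To see why the incremental version is cheaper, one can record how many partial tuples survive after each cell. This is a sketch: uniq_prod_sizes is a hypothetical instrumented copy of uniq_prod, not part of the answer's code.

```python
def uniq_prod_sizes(sets):
    """Same pruning as uniq_prod, but also records the frontier size per step."""
    result = [()]
    sizes = []
    for selection in sets:
        # extend each partial tuple only with digits it does not already use
        result = [tup + (x,) for tup in result for x in selection.difference(tup)]
        sizes.append(len(result))
    return result, sizes


row = [{1, 2, 3, 4}, {1, 2, 3}, {1, 2, 3, 6}, {1, 2, 4}, {1, 3, 4, 5, 6}, {2, 3, 4, 5}]
perms, sizes = uniq_prod_sizes(row)
print(sizes)       # [4, 9, 21, 24, 45, 32]
print(sum(sizes))  # 135 partial tuples built in total
```

For this row only 135 partial tuples are ever built, compared with the 2880 tuples that itertools.product enumerates before filtering.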