Tags: python, algorithm, performance, tuples, combinatorics

Efficient computation of the set of surjective functions


A function f : X -> Y is surjective when every element of Y has at least one preimage in X. When X = {0,...,m-1} and Y = {0,...,n-1} are two finite sets, then f corresponds to an m-tuple of numbers < n, and it is surjective precisely when every number < n appears at least once. (When we require that every number appears exactly once, we have n=m and are talking about permutations.)
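
The condition on tuples can be checked directly. Here is a minimal sketch; the helper name `is_surjective_tuple` is made up for illustration and is not used elsewhere in the question:

```python
def is_surjective_tuple(t: tuple, n: int) -> bool:
    """True if the set of entries of t is exactly {0, ..., n-1}."""
    # Comparing sets checks two things at once: every number < n occurs,
    # and no entry lies outside range(n).
    return set(t) == set(range(n))


print(is_surjective_tuple((0, 2, 1, 2), 3))  # every number < 3 occurs
print(is_surjective_tuple((0, 2, 2, 2), 3))  # 1 is missing
```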

I would like to know an efficient algorithm for computing the set of all surjective tuples for two given numbers n and m. The number of these tuples can be computed very efficiently with the inclusion-exclusion principle (see for example here), but I don't think that this helps with computing the tuples themselves: we would first have to compute all n^m tuples and then remove the non-surjective ones step by step, and I assume that computing all tuples already takes longer*. A different approach goes as follows:

Consider for example the tuple

(1,6,4,2,1,6,0,2,5,1,3,2,3)

in which every number < 7 appears at least once. Look at the largest number and erase it:

(1,*,4,2,1,*,0,2,5,1,3,2,3)

It appears at the indices 1 and 5, which gives the set {1,5}, a subset of the indices. The remaining entries form the tuple

(1,4,2,1,0,2,5,1,3,2,3)

with the property that every number < 6 appears at least once.

We see that the surjective m-tuples of numbers < n correspond to the pairs (T,a), where T is a non-empty subset of {0,...,m-1} and a is a surjective (m-k)-tuple of numbers < n-1, where T has k elements.
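
The correspondence can be made concrete with a small sketch. The names `split_largest` and `merge_largest` are hypothetical helpers introduced only for this illustration, and the input is assumed to be a surjective tuple of numbers < n:

```python
def split_largest(t: tuple, n: int) -> tuple[tuple, tuple]:
    """Split a surjective tuple of numbers < n into (T, a).

    T holds the indices at which the largest number n-1 occurs,
    a is the tuple that remains after erasing those entries.
    """
    largest = n - 1
    T = tuple(i for i, x in enumerate(t) if x == largest)
    a = tuple(x for x in t if x != largest)
    return T, a


def merge_largest(T: tuple, a: tuple, n: int) -> tuple:
    """Inverse of split_largest: put n-1 back at the indices in T."""
    result = list(a)
    # T must be in ascending order, so that each insertion lands at its
    # final index and later insertions do not shift it.
    for i in T:
        result.insert(i, n - 1)
    return tuple(result)


t = (1, 6, 4, 2, 1, 6, 0, 2, 5, 1, 3, 2, 3)
T, a = split_largest(t, 7)
print(T)  # (1, 5)
print(a)  # (1, 4, 2, 1, 0, 2, 5, 1, 3, 2, 3)
print(merge_largest(T, a, 7) == t)  # True
```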

This leads to the following recursive implementation (written in Python):

import itertools


def surjective_tuples(m: int, n: int) -> set[tuple]:
    """Set of all m-tuples of numbers < n where every number < n appears at least once.

    Arguments:
        m: length of the tuple
        n: number of distinct values
    """
    if n == 0:
        return set() if m > 0 else {()}
    if n > m:
        return set()
    result = set()
    for k in range(1, m + 1):
        smaller_tuples = surjective_tuples(m - k, n - 1)
        subsets = itertools.combinations(range(m), k)
        for subset in subsets:
            for smaller_tuple in smaller_tuples:
                my_tuple = []
                count = 0
                for i in range(m):
                    if i in subset:
                        my_tuple.append(n - 1)
                        count += 1
                    else:
                        my_tuple.append(smaller_tuple[i - count])
                result.add(tuple(my_tuple))
    return result

I noticed that this is quite slow, though, when the input numbers are large. For example, for (m,n)=(10,6) the computation takes 32 seconds on my (old) PC; the resulting set has 16435440 elements. I suspect that there is a faster algorithm.
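
The element count quoted above can be cross-checked with the inclusion-exclusion formula mentioned earlier, which gives sum over j = 0..n of (-1)^j * C(n, j) * (n - j)^m. A minimal sketch follows; the function name `count_surjective_tuples` is an assumption made for this example, and `math.comb` requires Python 3.8 or later:

```python
from math import comb


def count_surjective_tuples(m: int, n: int) -> int:
    """Number of m-tuples of numbers < n in which every number < n occurs."""
    # Inclusion-exclusion: j counts how many values are forced to be missing.
    return sum((-1) ** j * comb(n, j) * (n - j) ** m for j in range(n + 1))


print(count_surjective_tuples(10, 6))  # 16435440
```

This only counts the tuples and does not produce them, but it is a cheap way to validate the size of the result of any enumeration.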

*In fact, the following implementation is very slow.

def surjective_tuples_stupid(m: int, n: int) -> list[tuple]:
    # Build all n**m tuples first, then keep those in which every number < n occurs.
    all_tuples = list(itertools.product(range(n), repeat=m))
    surjective_tuples = filter(lambda t: all(i in t for i in range(n)), all_tuples)
    return list(surjective_tuples)

Solution

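  • I just optimized yours a little, mainly by using insert to build each tuple. I also return a list instead of a set and stop the loop over k early, since at least n-1 positions must remain for the smaller tuple. About 5x faster than yours for m=9, n=7.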

    import itertools


    def surjective_tuples(m: int, n: int) -> list[tuple]:
        """List of all m-tuples of numbers < n where every number < n appears at least once.
    
        Arguments:
            m: length of the tuple
            n: number of distinct values
        """
        if not n:
            return [] if m else [()]
        if n > m:
            return []
        n -= 1
        result = []
        for k in range(1, m - n + 1):
            smaller_tuples = surjective_tuples(m - k, n)
            subsets = itertools.combinations(range(m), k)
            for subset in subsets:
                for smaller_tuple in smaller_tuples:
                    my_tuple = [*smaller_tuple]
                    for i in subset:
                        my_tuple.insert(i, n)
                    result.append(tuple(my_tuple))
        return result