A function $f : X \to Y$ is surjective when every element of $Y$ has at least one preimage in $X$. When $X = \{0,\dots,m-1\}$ and $Y = \{0,\dots,n-1\}$ are two finite sets, $f$ corresponds to an $m$-tuple of numbers $< n$, and it is surjective precisely when every number $< n$ appears at least once. (When we require that every number appears exactly once, we have $n = m$ and are talking about permutations.)
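As a quick illustration of the definition, here is a small predicate (`is_surjective` is a hypothetical helper name, not from the post), along with a brute-force check that for $n = m$ the surjective tuples are exactly the permutations:

```python
import itertools
import math

# Hypothetical helper (not from the post): checks the definition directly.
def is_surjective(t: tuple[int, ...], n: int) -> bool:
    """True if t uses only numbers < n and every number < n appears at least once."""
    return set(t) == set(range(n))

# For n == m, the surjective tuples are exactly the permutations of range(n).
n = 4
surj = [t for t in itertools.product(range(n), repeat=n) if is_surjective(t, n)]
print(len(surj), math.factorial(n))  # 24 24
print(set(surj) == set(itertools.permutations(range(n))))  # True
```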
I would like to know an efficient algorithm for computing the set of all surjective tuples for two given numbers $n$ and $m$. The number of these tuples can be computed very efficiently with the inclusion-exclusion principle (see for example here). However, I don't think this approach helps with listing the tuples, because we would first have to compute all tuples and then remove the non-surjective ones step by step, and I assume that computing all tuples already takes too long.* A different approach goes as follows:
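For reference, here is a minimal sketch of that count. The function name is hypothetical. It only yields the number of tuples, not the tuples themselves, using $\sum_{j=0}^{n} (-1)^j \binom{n}{j} (n-j)^m$:

```python
from math import comb

# Hypothetical name (not from the post).
def count_surjective(m: int, n: int) -> int:
    """Number of surjective m-tuples of numbers < n, by inclusion-exclusion.

    For each j, comb(n, j) * (n - j)**m counts the tuples that avoid some fixed
    set of j values, summed over all such sets; the signs correct the overcounting.
    """
    return sum((-1) ** j * comb(n, j) * (n - j) ** m for j in range(n + 1))

print(count_surjective(10, 6))  # 16435440
```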
Consider for example the tuple

$$(1,6,4,2,1,6,0,2,5,1,3,2,3)$$

in which every number $< 7$ appears at least once. Look at the largest number and erase it:

$$(1,\ast,4,2,1,\ast,0,2,5,1,3,2,3)$$

It appears at the indices 1 and 5 (counting from 0), so this corresponds to the set $\{1,5\}$, a subset of the indices. The rest corresponds to the tuple

$$(1,4,2,1,0,2,5,1,3,2,3)$$

with the property that every number $< 6$ appears at least once.
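The splitting step on this example can be sketched as follows. The helper name `split_off_largest` is hypothetical, and it assumes a non-empty input tuple:

```python
# Hypothetical helper (not from the post); assumes t is non-empty.
def split_off_largest(t: tuple[int, ...]) -> tuple[set[int], tuple[int, ...]]:
    """Split a surjective tuple into (T, a): T = positions of the largest value,
    a = the remaining entries in their original order."""
    top = max(t)  # equals n - 1 for a surjective tuple over range(n)
    positions = {i for i, x in enumerate(t) if x == top}
    rest = tuple(x for x in t if x != top)
    return positions, rest

T, a = split_off_largest((1, 6, 4, 2, 1, 6, 0, 2, 5, 1, 3, 2, 3))
print(T)  # {1, 5}
print(a)  # (1, 4, 2, 1, 0, 2, 5, 1, 3, 2, 3)
```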
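We see that the surjective $m$-tuples of numbers $< n$ correspond to the pairs $(T,a)$, where $T$ is a non-empty subset of $\{0,\dots,m-1\}$ with $k$ elements and $a$ is a surjective $(m-k)$-tuple of numbers $< n-1$. Conversely, a pair $(T,a)$ gives back the tuple by writing $n-1$ at the positions in $T$ and filling the remaining positions with the entries of $a$ in order.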
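Counting both sides of this correspondence gives a recurrence. Let $s(m,n)$ denote the number of surjective $m$-tuples of numbers $< n$ (notation introduced here, not from the post). Then $s(m,n) = \sum_{k=1}^{m} \binom{m}{k}\, s(m-k, n-1)$, with $s(0,0) = 1$ and $s(m,0) = 0$ for $m > 0$. The following is a minimal sketch with a hypothetical function name, which can be used to check the size of the output before generating it:

```python
from functools import lru_cache
from math import comb

# Hypothetical name (not from the post).
@lru_cache(maxsize=None)
def count_by_recursion(m: int, n: int) -> int:
    """Count surjective m-tuples over range(n) via the pairs (T, a):
    comb(m, k) choices for the k positions T of the value n - 1, times the
    number of surjective (m - k)-tuples over range(n - 1)."""
    if n == 0:
        return 1 if m == 0 else 0
    return sum(comb(m, k) * count_by_recursion(m - k, n - 1) for k in range(1, m + 1))

print(count_by_recursion(10, 6))  # 16435440
```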
This leads to the following recursive implementation (written in Python):
```python
import itertools

def surjective_tuples(m: int, n: int) -> set[tuple]:
    """Set of all m-tuples of numbers < n where every number < n appears at least once.

    Arguments:
        m: length of the tuple
        n: number of distinct values
    """
    if n == 0:
        return set() if m > 0 else {()}
    if n > m:
        return set()
    result = set()
    # k = number of positions that hold the largest value n - 1
    for k in range(1, m + 1):
        smaller_tuples = surjective_tuples(m - k, n - 1)
        subsets = itertools.combinations(range(m), k)
        for subset in subsets:
            for smaller_tuple in smaller_tuples:
                # Put n - 1 at the positions in subset and fill the
                # remaining positions with smaller_tuple in order.
                my_tuple = []
                count = 0  # positions of subset passed so far
                for i in range(m):
                    if i in subset:
                        my_tuple.append(n - 1)
                        count += 1
                    else:
                        my_tuple.append(smaller_tuple[i - count])
                result.add(tuple(my_tuple))
    return result
```
I noticed that this is quite slow, though, when the input numbers are large. For example, for $(m,n) = (10,6)$ the computation takes 32 seconds on my (old) PC, and the set has 16435440 elements. I suspect that there is a faster algorithm.
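The size 16435440 can be double-checked independently. The fibres of a surjective $f$ partition the $m$ positions into $n$ non-empty blocks, and each assignment of the $n$ values to the blocks gives a different tuple. There are therefore $n!\,S(m,n)$ surjective tuples, where $S$ is the Stirling number of the second kind. The sketch below uses a hypothetical function name. It also makes clear that the output alone has about 16 million tuples, so any method that materializes all of them must do at least that much work:

```python
from math import factorial

# Hypothetical name (not from the post).
def stirling2(m: int, n: int) -> int:
    """Stirling number of the second kind S(m, n), via S(m, n) = n*S(m-1, n) + S(m-1, n-1)."""
    row = [1] + [0] * n  # row[j] = S(0, j)
    for _ in range(m):
        # row[j] = S(i, j) becomes S(i + 1, j)
        row = [0] + [j * row[j] + row[j - 1] for j in range(1, n + 1)]
    return row[n]

print(stirling2(10, 6))                 # 22827
print(factorial(6) * stirling2(10, 6))  # 16435440
```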
*In fact, the following implementation is very slow:

```python
import itertools

def surjective_tuples_stupid(m: int, n: int) -> list[tuple]:
    # Enumerate all n**m tuples, then keep those that contain every number < n.
    all_tuples = list(itertools.product(range(n), repeat=m))
    surjective = filter(lambda t: all(i in t for i in range(n)), all_tuples)
    return list(surjective)
```
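If memory rather than time is the concern, the same brute force can run lazily without first building the list of all $n^m$ tuples. It still inspects every candidate (for $(10,6)$ that is $6^{10} = 60466176$ tuples), so it is no faster. This is a sketch with a hypothetical name:

```python
import itertools
from typing import Iterator

# Hypothetical name (not from the post).
def surjective_tuples_lazy(m: int, n: int) -> Iterator[tuple[int, ...]]:
    """Yield the surjective m-tuples over range(n) one at a time (brute force)."""
    for t in itertools.product(range(n), repeat=m):
        # t only contains numbers < n, so it is surjective iff it has n distinct entries
        if len(set(t)) == n:
            yield t

print(sum(1 for _ in surjective_tuples_lazy(4, 3)))  # 36
```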
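I just optimized yours a little, mainly by using `insert` to build each tuple. It is about 5x faster than yours for $m = 9$, $n = 7$. There are two smaller changes. First, it returns a list instead of a set, which is safe because the decomposition never produces the same tuple twice. Second, the loop over $k$ stops at $m - n + 1$ (with $n$ already decremented), because a larger $k$ leaves too few positions for the remaining values and the recursive call would return an empty list anyway.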
```python
import itertools

def surjective_tuples(m: int, n: int) -> list[tuple]:
    """List of all m-tuples of numbers < n where every number < n appears at least once.

    Arguments:
        m: length of the tuple
        n: number of distinct values
    """
    if not n:
        return [] if m else [()]
    if n > m:
        return []
    n -= 1  # from here on, n is the largest value
    result = []
    # For k > m - n the smaller problem has no solutions, so stop there.
    for k in range(1, m - n + 1):
        smaller_tuples = surjective_tuples(m - k, n)
        subsets = itertools.combinations(range(m), k)
        for subset in subsets:
            for smaller_tuple in smaller_tuples:
                my_tuple = [*smaller_tuple]
                # subset is increasing, so each insert lands at its final position
                for i in subset:
                    my_tuple.insert(i, n)
                result.append(tuple(my_tuple))
    return result
```
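If the caller only needs to iterate over the tuples once (for example, to filter them), the same recursion can be written as a generator so that the top-level result is never stored. This is a sketch with a hypothetical name. The smaller subproblems are still materialized as lists, because each is reused for every choice of positions. It does not reduce the running time, which is bounded below by the number of tuples produced:

```python
import itertools
from typing import Iterator

# Hypothetical name (not from the post); same decomposition as surjective_tuples.
def iter_surjective_tuples(m: int, n: int) -> Iterator[tuple[int, ...]]:
    """Yield every m-tuple over range(n) in which each number < n appears."""
    if n == 0:
        if m == 0:
            yield ()
        return
    top = n - 1  # the largest value
    # k = number of positions holding top; larger k leave too few positions
    for k in range(1, m - top + 1):
        # Materialized once per k, because it is reused for every subset.
        smaller = list(iter_surjective_tuples(m - k, top))
        for positions in itertools.combinations(range(m), k):
            for rest in smaller:
                out = list(rest)
                for i in positions:  # increasing, so earlier inserts do not shift later targets
                    out.insert(i, top)
                yield tuple(out)

for t in iter_surjective_tuples(3, 2):
    print(t)
```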