
Applying a condition to rows of a 2d array


Consider a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

In the multiset a, there are exactly 2 instances each of 0, 1, 2, and 3.

I want to find all permutations of a that meet a condition as we move through each row from left to right:

condition: the 1st instances of 0, 1, 2, and 3 must appear in that order, though they do not need to be consecutive.

[0, 1, _, 2, _, 3, _, _] is ok, [0, 1, _, 3, _, 2, _, _] is not ok (each _ stands for any remaining value)

The 2nd instance of each number may appear anywhere in the row as long as it is after (to the right of) the 1st instance.

[0, 1, 0, 2, 2, 3, 1, 3] is ok
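Since the 2nd instance of a value is, by definition, to the right of its 1st instance, the second rule always holds; the only real constraint is that the first occurrences appear in increasing order of value. A minimal pure-Python sketch of that check for a single row (the function name first_occurrences_in_order is hypothetical, not from the question):

```python
def first_occurrences_in_order(row):
    """Return True if the distinct values of `row`, taken in the order they
    first appear, are in increasing order (first 0, then first 1, ...)."""
    first_seen = []  # distinct values in order of first appearance
    for value in row:
        if value not in first_seen:
            first_seen.append(value)
    # Valid when the discovery order equals the sorted order.
    return first_seen == sorted(first_seen)


print(first_occurrences_in_order([0, 1, 0, 2, 2, 3, 1, 3]))  # True
print(first_occurrences_in_order([0, 1, 1, 3, 0, 2, 2, 3]))  # False (3 before 2)
```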

I've started by finding all 8!/2**4 = 2520 permutations of the multiset a:
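That count is the multinomial formula: (2n)! orderings of the 2n slots, divided by 2! for each of the n values whose two copies are interchangeable. A quick arithmetic check (the helper name count_pair_permutations is illustrative only):

```python
from math import factorial


def count_pair_permutations(n_pairs):
    """Distinct permutations of a multiset holding exactly two copies
    of each of n_pairs values: (2n)! / (2!)**n."""
    return factorial(2 * n_pairs) // factorial(2) ** n_pairs


print(count_pair_permutations(4))  # 2520  (40320 / 16)
print(count_pair_permutations(5))  # 113400
```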

from sympy.utilities.iterables import multiset_permutations
import numpy as np

a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

resultList = []
for p in multiset_permutations(a):
    resultList.append(p)
    
out = np.array(resultList)

My difficulty is that I'm drowning in the details when I try to set the condition. To compound the problem, the actual array a could have up to 5 pairs of values. QUESTION: How can the condition be written so that I can eliminate, from array out, all permutation rows that do not satisfy the condition?


Solution

  • Since you know your array consists of exactly two copies of each element of np.arange(4), you can use np.argmax to find the first occurrence of each value and check that those positions increase:

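    import numpy as np
    from sympy.utilities.iterables import multiset_permutations

    a = np.array([0, 0, 1, 1, 2, 2, 3, 3])

    max_values = np.max(a)
    uniques = np.arange(max_values + 1)
    # or you can just do
    # uniques = np.unique(a)

    resultList = []
    for p in multiset_permutations(a):
        # Row i of (p == uniques[:, None]) marks where uniques[i] occurs in p;
        # argmax returns the index of the first True, i.e. the first occurrence.
        idx = np.argmax(p == uniques[:, None], axis=1)
        # Keep p only if the first occurrences are strictly increasing.
        if (idx[:-1] < idx[1:]).all():
            resultList.append(p)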
    

    Then resultList contains 105 permutations for 4 pairs (2520 / 4!, since each of the 4! orders of first occurrences is equally common) and 945 for 5 pairs (113400 / 5!).
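The same argmax idea can also be applied to the whole 2D array out at once, which is what the question literally asks for (removing rows from out). The sketch below is a vectorized variant, not the answerer's code; to stay self-contained it builds out with itertools.permutations plus deduplication instead of sympy, which is fine for 4 pairs but wasteful for larger inputs. The function name filter_rows is hypothetical.

```python
import itertools

import numpy as np


def filter_rows(out, uniques):
    """Keep only the rows of the 2D array `out` in which the first
    occurrences of uniques[0], uniques[1], ... appear left to right."""
    # matches[r, j, k] is True where row r holds uniques[k] at column j.
    matches = out[:, :, None] == uniques[None, None, :]
    # argmax over the column axis gives the first column holding each value.
    first = matches.argmax(axis=1)  # shape (n_rows, n_uniques)
    # Valid rows have strictly increasing first-occurrence columns.
    keep = np.all(np.diff(first, axis=1) > 0, axis=1)
    return out[keep]


a = np.array([0, 0, 1, 1, 2, 2, 3, 3])
uniques = np.unique(a)

# All distinct permutations of a, one per row.
out = np.array(sorted(set(itertools.permutations(a.tolist()))))

result = filter_rows(out, uniques)
print(out.shape, result.shape)  # (2520, 8) (105, 8)
```

The 105 also follows from the structure of the problem: once the first occurrences are forced into order, each valid row corresponds to one way of pairing up the 2n positions, giving (2n - 1)!! rows, i.e. 7 * 5 * 3 * 1 = 105 for 4 pairs and 9 * 7 * 5 * 3 * 1 = 945 for 5 pairs.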