
Finding combinations that meet a threshold relation


Given values for phi, theta, n_1, and n_2, I need to find all possible integer pairs (N_1, N_2) that meet the following criteria:

0 <= N_1 <= n_1
0 <= N_2 <= n_2
N_1 - phi * N_2 >= theta

What is the most efficient way to do this in Python? Obviously I could use two for loops -- iterating over all possible values for N_1 and N_2 (from the first two criteria), and saving only those pairs that meet the last criterion -- but this would be fairly inefficient.
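
For reference, the two-loop baseline described above can be sketched as a single comprehension. This is only a minimal illustration for checking faster approaches against; `brute_force_pairs` is a hypothetical name, and N_1 and N_2 are assumed to be integers.

```python
def brute_force_pairs(n1, n2, theta, phi):
    """Test every (N1, N2) in the grid and keep the pairs meeting the threshold."""
    return [(N1, N2)
            for N1 in range(n1 + 1)
            for N2 in range(n2 + 1)
            if N1 - phi * N2 >= theta]


print(brute_force_pairs(5, 1, theta=2, phi=2))
# [(2, 0), (3, 0), (4, 0), (4, 1), (5, 0), (5, 1)]
```

This checks all (n_1 + 1) * (n_2 + 1) combinations, which is the cost the question hopes to avoid.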


Solution

  • You could use numpy and vectorization, something like the code below:

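    import numpy as np

    phi = 0.5
    theta = 1
    n1 = 10
    n2 = 20

    # Random candidate values for demonstration; use np.arange(n1 + 1) and
    # np.arange(n2 + 1) instead to enumerate every possible pair.
    N1 = np.random.randint(-100, 100, size=100)
    N2 = np.random.randint(-100, 100, size=100)

    # Keep only candidates inside the allowed ranges.
    N1 = N1[(N1 >= 0) & (N1 <= n1)]
    N2 = N2[(N2 >= 0) & (N2 <= n2)]

    # N1 - phi * N2 >= theta  <=>  N1 - (phi * N2 + theta) >= 0
    a = N2 * phi + theta
    # Broadcast a column of N1 against a row of a: res[i, j] = N1[i] - a[j]
    res = N1.reshape(-1, 1) - a.reshape(1, -1)

    indices = np.argwhere(res >= 0)
    # list() is needed because zip returns a lazy iterator in Python 3
    pairs = list(zip(N1[indices[:, 0]].tolist(), N2[indices[:, 1]].tolist()))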
    

    Example output of pairs (the candidates are drawn at random, so each run gives a different list):

    [(8, 3),
     (8, 6),
     (8, 5),
     (8, 1),
     (3, 1),
     (9, 3),
     (9, 8),
     (9, 8),
     (9, 6),
     (9, 5),
     (9, 6),
     (9, 6),
     (9, 5),
     (9, 8),
     (9, 1)]
    

    Per @dbliss's request, here is the modularized version and its test:

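    import numpy as np


    def calc_combination(N1, N2, n1, n2, theta, phi):
        """Return the (N1, N2) pairs with N1 - phi * N2 >= theta."""
        # Keep only candidates inside the allowed ranges.
        N1 = N1[(N1 >= 0) & (N1 <= n1)]
        N2 = N2[(N2 >= 0) & (N2 <= n2)]

        # N1 - phi * N2 >= theta  <=>  N1 - (phi * N2 + theta) >= 0
        a = N2 * phi + theta
        res = N1.reshape(-1, 1) - a.reshape(1, -1)

        indices = np.argwhere(res >= 0)
        # Return a list of plain-int tuples; zip alone is lazy in Python 3.
        return list(zip(N1[indices[:, 0]].tolist(), N2[indices[:, 1]].tolist()))


    def test_case():
        n1 = 5
        n2 = 1
        theta = 2
        phi = 2

        # All candidate values, so every possible pair is considered.
        N1 = np.arange(n1 + 1)
        N2 = np.arange(n2 + 1)

        assert (calc_combination(N1, N2, n1, n2, theta, phi) ==
                [(2, 0), (3, 0), (4, 0), (4, 1), (5, 0), (5, 1)])

    test_case()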
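
Because the third criterion is linear, for a fixed N_2 the valid N_1 values form a contiguous range starting at ceil(theta + phi * N_2). The pairs can therefore also be listed directly, without testing every combination. The following is a minimal sketch of that idea, not part of the original answer; it assumes N_1 and N_2 are integers, and `pairs_above_threshold` is a hypothetical name.

```python
import math


def pairs_above_threshold(n1, n2, theta, phi):
    """List every integer pair (N1, N2) with 0 <= N1 <= n1, 0 <= N2 <= n2
    and N1 - phi * N2 >= theta."""
    pairs = []
    for N2 in range(n2 + 1):
        # Smallest integer N1 satisfying N1 >= theta + phi * N2, clamped at 0.
        lowest = max(0, math.ceil(theta + phi * N2))
        # range() is empty when lowest > n1, so no explicit check is needed.
        pairs.extend((N1, N2) for N1 in range(lowest, n1 + 1))
    return pairs


print(pairs_above_threshold(5, 1, theta=2, phi=2))
# [(2, 0), (3, 0), (4, 0), (5, 0), (4, 1), (5, 1)]
```

The work done is proportional to n_2 plus the number of pairs returned, and the result is grouped by N_2 rather than by N_1, so sort it if the order matters. With non-integer phi or theta, floating-point rounding can affect pairs that sit exactly on the boundary.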