
Choose a random element in each row of a 2D array, considering only the elements allowed by a given mask (Python)


I have a 2D array data and a boolean array mask, both of shape (M, N). I need to randomly pick one element from each row of data, but only from positions where mask is True. Is there a way to do this without looping over every row? Every row has at least 2 elements for which mask is True.

Minimum Working Example:

import numpy

data = numpy.arange(8).reshape((2,4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])
# Desired call (pseudocode: numpy.random.choice takes no mask or num_elements argument)
selected_data = numpy.random.choice(data, mask, num_elements=1, axis=1)

The last line above doesn't work, but I want something like it. Some valid results are listed below.

selected_data = [0,4]
selected_data = [1,5]
selected_data = [2,5]
selected_data = [3,4]

Solution

  • It is easier to work with the indices of the mask. First, get the (row, column) indices of the True values in the mask and stack them into a 2D coordinate array, indices2d; every coordinate in it is a valid candidate. Next, shuffle that array and, for each distinct row index, keep the first coordinate pair that appears. Because the array was shuffled, that first occurrence is a random choice among the row's candidates. Finally, use the selected 2D indices to look up the values in data. See below:

    import numpy
    
    data = numpy.arange(8).reshape((2,4))
    mask = numpy.array([[True, True, True, True], [True, True, False, False]])
    
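    for _ in range(20):
        # (row, column) coordinates of every True entry in the mask, shape (K, 2)
        indices2d = numpy.dstack(numpy.where(mask)).squeeze().astype(numpy.int32)
        # Shuffle the coordinate pairs in place along the first axis
        numpy.random.shuffle(indices2d)
        # return_index gives the first occurrence of each row index after the shuffle
        randomElements = indices2d[numpy.unique(indices2d[:, 0], return_index=True)[1]]
        # Look up one selected value per row
        print(data[randomElements[:,0],randomElements[:,1]])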
    

    Output (one sample run; the values vary between runs)

    [0 5]
    [1 4]
    [0 5]
    [1 4]
    [1 5]
    [0 5]
    [1 4]
    [1 5]
    [3 4]
    [2 5]
    [2 4]
    [3 5]
    [2 4]
    [0 4]
    [0 4]
    [0 4]
    [0 5]
    [3 5]
    [3 5]
    [1 4]
    

    12.7 ms ± 80.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
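
As an alternative to the shuffle-and-unique approach above, here is a minimal sketch of a different technique: draw a random key for every element, disqualify the masked-out positions, and take the per-row argmax. It is not the method from the answer, just another way to avoid the Python loop. The function name pick_masked_per_row and its rng parameter are made up for this illustration, and the sketch assumes every row has at least one True in mask (the question guarantees at least 2).

```python
import numpy

def pick_masked_per_row(data, mask, rng=None):
    """Pick one random element per row of data among positions where mask is True.

    Hypothetical helper for illustration; assumes each row of mask has
    at least one True entry.
    """
    rng = numpy.random.default_rng() if rng is None else rng
    # One independent random key in [0, 1) per element
    keys = rng.random(mask.shape)
    # Masked-out positions get a key below every valid key, so they never win
    keys[~mask] = -1.0
    # The column holding the largest key in each row is the selected column
    cols = keys.argmax(axis=1)
    # Pair each row index with its selected column
    return data[numpy.arange(data.shape[0]), cols]

data = numpy.arange(8).reshape((2, 4))
mask = numpy.array([[True, True, True, True], [True, True, False, False]])
selected_data = pick_masked_per_row(data, mask)
print(selected_data)
```

Since the keys are independent and identically distributed, each allowed position in a row should be equally likely to hold the maximum. Passing a seeded generator, e.g. numpy.random.default_rng(0), makes the picks reproducible.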