
Select N random rows from a 3D numpy array


I have a 3D array from which I want to take random 'sets' (note: not a Python set) along axis 1, N times. I can do this with nested for loops, but I will need to do it at least 10000 times, so I am looking for a vectorised solution if possible.

I will try to explain this with an example. For each of the N sets, I want to select one random index along axis 1 for every position on axis 0. E.g. if in the first of my N sets I randomly select the indices [0, 2, 1], these correspond to the three array positions [0, 0, :], [1, 2, :] and [2, 1, :], respectively (i.e. the axis-0 index increments by one each time, and the axis-1 index is the randomly selected one).
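In NumPy terms, one such 'set' is integer-array (advanced) indexing: pair `np.arange(x)` on axis 0 with the chosen indices on axis 1. A minimal sketch, assuming the example array is rebuilt with a formula (value = 4*k + j + 0.1*i) rather than typed out; the variable names are illustrative only:

```python
import numpy as np

# Rebuild the example array: a[i, j, k] = 4*k + j + 0.1*i, shape (3, 4, 7)
i = np.arange(3)[:, None, None]
j = np.arange(4)[None, :, None]
k = np.arange(7)[None, None, :]
a = 4 * k + j + 0.1 * i

choice = np.array([0, 2, 1])   # one chosen index on axis 1 per position on axis 0
rows = np.arange(a.shape[0])   # axis-0 positions 0, 1, 2

# a[rows, choice] picks a[0, 0, :], a[1, 2, :] and a[2, 1, :]
one_set = a[rows, choice]
print(one_set.shape)  # (3, 7)
```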

Below is a numerical example in pseudo-code:

import numpy as np

# Create some arbitrary data (EDIT: based on mozway's answer)
a = np.array([[[ 0. ,  4. ,  8. , 12. , 16. , 20. , 24. ],
            [ 1. ,  5. ,  9. , 13. , 17. , 21. , 25. ],
            [ 2. ,  6. , 10. , 14. , 18. , 22. , 26. ],
            [ 3. ,  7. , 11. , 15. , 19. , 23. , 27. ]],

           [[ 0.1,  4.1,  8.1, 12.1, 16.1, 20.1, 24.1],
            [ 1.1,  5.1,  9.1, 13.1, 17.1, 21.1, 25.1],
            [ 2.1,  6.1, 10.1, 14.1, 18.1, 22.1, 26.1],
            [ 3.1,  7.1, 11.1, 15.1, 19.1, 23.1, 27.1]],

           [[ 0.2,  4.2,  8.2, 12.2, 16.2, 20.2, 24.2],
            [ 1.2,  5.2,  9.2, 13.2, 17.2, 21.2, 25.2],
            [ 2.2,  6.2, 10.2, 14.2, 18.2, 22.2, 26.2],
            [ 3.2,  7.2, 11.2, 15.2, 19.2, 23.2, 27.2]]])


# Define the number of requested sets
N = 2

# Chosen axis-1 indices for each 'set' (normally these would be random)
idx = [[0, 2, 1], [1, 3, 3]]

# First set would give (with choices [0, 2, 1]):
arr = [[ 0. ,  4. , 8.  , 12. , 16. , 20. , 24. ],
       [ 2.1,  6.1, 10.1, 14.1, 18.1, 22.1, 26.1],
       [ 1.2,  5.2, 9.2 , 13.2, 17.2, 21.2, 25.2]]

# Second set would give (with choices [1, 3, 3]):
arr = [[ 1. ,  5. ,  9. , 13. , 17. , 21. , 25. ],
       [ 3.1,  7.1, 11.1, 15.1, 19.1, 23.1, 27.1],
       [ 3.2,  7.2, 11.2, 15.2, 19.2, 23.2, 27.2]]

# So the final output stacks all sets, giving shape (N, 3, 7):
arr = [[[ 0. ,  4. , 8.  , 12. , 16. , 20. , 24. ],
        [ 2.1,  6.1, 10.1, 14.1, 18.1, 22.1, 26.1],
        [ 1.2,  5.2, 9.2 , 13.2, 17.2, 21.2, 25.2]],

       [[ 1. ,  5. ,  9. , 13. , 17. , 21. , 25. ],
        [ 3.1,  7.1, 11.1, 15.1, 19.1, 23.1, 27.1],
        [ 3.2,  7.2, 11.2, 15.2, 19.2, 23.2, 27.2]]]
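The expected output above can be reproduced with the nested-loop approach the question mentions, which is handy as a slow reference to test a vectorised version against. A minimal sketch, assuming `a` is rebuilt by formula and using a hypothetical helper name `select_sets_loop`:

```python
import numpy as np

# Example data: a[i, j, k] = 4*k + j + 0.1*i, shape (3, 4, 7)
a = (4 * np.arange(7)[None, None, :]
     + np.arange(4)[None, :, None]
     + 0.1 * np.arange(3)[:, None, None])


def select_sets_loop(a, idx):
    """Hypothetical helper: for each set n, take a[i, idx[n][i], :] for every i on axis 0."""
    idx = np.asarray(idx)
    n_sets, x = idx.shape
    out = np.empty((n_sets, x, a.shape[2]), dtype=a.dtype)
    for n in range(n_sets):      # loop over the N sets
        for i in range(x):       # loop over axis 0
            out[n, i] = a[i, idx[n, i]]
    return out


idx = [[0, 2, 1], [1, 3, 3]]
result = select_sets_loop(a, idx)
print(result.shape)  # (2, 3, 7)
```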

Solution

  • Given the clarifications to your question, you want to select N random rows along axis 1 (the second dimension) of a 3D array, independently for each position on axis 0:

    Let's call the array a and its three dimensions x, y, z.

    An easy way is to draw N*x random indices, N for each position on axis 0. Then flatten the first two dimensions into one (shape (x*y, z)), so that row j of block i becomes flat row i*y + j, and index into it.

    Example input (the decimal part .0/.1/.2 tracks the originating position on axis 0):

    array([[[ 0. ,  4. ,  8. , 12. , 16. , 20. , 24. ],
            [ 1. ,  5. ,  9. , 13. , 17. , 21. , 25. ],
            [ 2. ,  6. , 10. , 14. , 18. , 22. , 26. ],
            [ 3. ,  7. , 11. , 15. , 19. , 23. , 27. ]],
    
           [[ 0.1,  4.1,  8.1, 12.1, 16.1, 20.1, 24.1],
            [ 1.1,  5.1,  9.1, 13.1, 17.1, 21.1, 25.1],
            [ 2.1,  6.1, 10.1, 14.1, 18.1, 22.1, 26.1],
            [ 3.1,  7.1, 11.1, 15.1, 19.1, 23.1, 27.1]],
    
           [[ 0.2,  4.2,  8.2, 12.2, 16.2, 20.2, 24.2],
            [ 1.2,  5.2,  9.2, 13.2, 17.2, 21.2, 25.2],
            [ 2.2,  6.2, 10.2, 14.2, 18.2, 22.2, 26.2],
            [ 3.2,  7.2, 11.2, 15.2, 19.2, 23.2, 27.2]]])
    

    Processing:

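    import numpy as np

    x, y, z = a.shape   # a is the example array above, shape (3, 4, 7)
    N = 2
    # Option 1: sample with repeats (a row may be picked in more than one set)
    idx = np.random.randint(y, size=N*x)        # N row indices per position on axis 0
    corr = np.repeat(np.arange(0, x*y, y), N)   # offset of each axis-0 block once flattened
    idx += corr
    # Option 2: sample without repeats within each axis-0 block (requires N <= y)
    idx = np.concatenate([np.random.choice(y, size=N, replace=False) + i*y for i in range(x)])
    # slice the flattened (x*y, z) array, then reorder to shape (N, x, z)
    a.reshape(x*y, z)[idx].reshape(x, N, z).swapaxes(0, 1)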
    

    Possible output, of shape (N, x, z):

    array([[[ 0. ,  4. ,  8. , 12. , 16. , 20. , 24. ],
            [ 1.1,  5.1,  9.1, 13.1, 17.1, 21.1, 25.1],
            [ 0.2,  4.2,  8.2, 12.2, 16.2, 20.2, 24.2]],
    
           [[ 3. ,  7. , 11. , 15. , 19. , 23. , 27. ],
            [ 3.1,  7.1, 11.1, 15.1, 19.1, 23.1, 27.1],
            [ 1.2,  5.2,  9.2, 13.2, 17.2, 21.2, 25.2]]])
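The with-repeats variant can be wrapped into a reusable function. A sketch under stated assumptions: the helper name `sample_sets_flat` is hypothetical, it uses the `np.random.default_rng` Generator API instead of `np.random.randint`, and it compares the flatten-and-slice result with plain broadcasted integer indexing `a[np.arange(x), idx]`. That second form is an alternative technique, not the one used in the answer, but it yields the same (N, x, z) array for the same indices:

```python
import numpy as np


def sample_sets_flat(a, N, rng):
    """Hypothetical helper: N random sets, one row from axis 1 per position on axis 0 (with repeats)."""
    x, y, z = a.shape
    idx = rng.integers(y, size=(N, x))          # idx[n, i] = chosen row on axis 1
    flat = (idx + np.arange(x) * y).T.ravel()   # flat row numbers, grouped by axis-0 block
    out = a.reshape(x * y, z)[flat].reshape(x, N, z).swapaxes(0, 1)
    return out, idx


# Example data: a[i, j, k] = 4*k + j + 0.1*i, shape (3, 4, 7)
a = (4 * np.arange(7)[None, None, :]
     + np.arange(4)[None, :, None]
     + 0.1 * np.arange(3)[:, None, None])

rng = np.random.default_rng(0)
out, idx = sample_sets_flat(a, N=10000, rng=rng)

# Same selection via broadcasted advanced indexing: (1, x) with (N, x) -> (N, x, z)
direct = a[np.arange(a.shape[0])[None, :], idx]
print(out.shape)                    # (10000, 3, 7)
print(np.array_equal(out, direct))  # True
```

The direct indexing form avoids the offset and reshape bookkeeping; both versions are vectorised, so drawing 10000 sets needs no Python-level loop.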