python arrays numpy array-broadcasting

np.where for 2d array, manipulate whole rows


I want to reproduce the following logic with a vectorized NumPy function such as np.where: for each row of a 2D array, check whether the first element satisfies a condition. If it does, return the first three elements of that row; otherwise, return the last three.

A short MWE as a for-loop, which I want to avoid:

import numpy as np
array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6]
])

new_array = np.zeros((array.shape[0], array.shape[1]-1))
for i, row in enumerate(array):
    if row[0] == 1:
        new_array[i] = row[:3]
    else:
        new_array[i] = row[-3:]

Solution

  • If you want to use np.where:

    import numpy as np
    array = np.array([
        [1, 2, 3, 4],
        [1, 2, 4, 2],
        [2, 3, 4, 6]
    ])
    
    cond = array[:, 0] == 1
    np.where(cond[:, None], array[:,:3], array[:,-3:])
    

    output:

    array([[1, 2, 3],
           [1, 2, 4],
           [3, 4, 6]])
    

    EDIT: a slightly more concise version:

    np.where(array[:, [0]] == 1, array[:,:3], array[:,-3:])
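
To see why the condition is reshaped with `[:, None]`, the sketch below prints the shapes involved and compares the np.where result against the loop from the question. The names `col_cond`, `result`, `expected` and `wrong` are only illustrative, and the loop here uses `dtype=array.dtype` (an assumption, not in the original MWE) so that both results are integer arrays.

```python
import numpy as np

array = np.array([
    [1, 2, 3, 4],
    [1, 2, 4, 2],
    [2, 3, 4, 6],
])

cond = array[:, 0] == 1      # shape (3,): one boolean per row
col_cond = cond[:, None]     # shape (3, 1): broadcasts across the columns

# Per-row selection: each row takes all three values from one slice.
result = np.where(col_cond, array[:, :3], array[:, -3:])

# Reference result from the question's loop, kept as integers.
expected = np.zeros((array.shape[0], array.shape[1] - 1), dtype=array.dtype)
for i, row in enumerate(array):
    expected[i] = row[:3] if row[0] == 1 else row[-3:]

# Without the extra axis, the 1-D condition lines up with the columns.
wrong = np.where(cond, array[:, :3], array[:, -3:])

print(cond.shape, col_cond.shape, result.shape)  # (3,) (3, 1) (3, 3)
print(np.array_equal(result, expected))          # True
print(np.array_equal(wrong, expected))           # False
```

Passing the 1-D `cond` directly raises no error in this example only because the number of rows happens to equal the number of selected columns; the condition is then applied per column instead of per row, which gives a different result.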