Tags: python, numpy, array-broadcasting

Efficiently getting values from numpy matrix


I have a matrix a = np.random.randn(10, 3) and an array of locations locs = np.array([[6, 6, 0], [7, 0, 5], [0, 9, 2]]). I need to replace the values in the first column of a, at the row indices given by the first column of locs, with a placeholder (say 0 or np.inf). Likewise, the rows of a's second column listed in the second column of locs must be replaced by the placeholder, and so on for every column.

Currently I have written it as a loop:

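import numpy as np

a = np.random.randn(10, 3)
locs = np.array([[6, 6, 0], [7, 0, 5], [0, 9, 2]])  # must be an array for .shape

for i in range(locs.shape[1]):
    l = locs[:, i]      # row indices to overwrite in column i
    current = a[:, i]   # a view into column i of a
    current[l] = 0      # writes through the view into a
    a[:, i] = current   # redundant, since current already shares memory with a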

Is it possible to replace this loop with vectorized NumPy operations?

Example input:

a = 
[[ 1.43734578  0.09736638 -2.25746086]
 [-0.76353825  0.29902121 -0.47547664]
 [-0.6702289  -1.03620696  1.29729398]
 [-0.01606927  0.03169479  0.32413694]
 [-0.87992136 -0.13887237  0.76943651]
 [-0.99176294  1.2174871  -0.04219437]
 [-1.28379798 -2.05605769  0.30146702]
 [-0.7249709   0.12472804  0.43728411]
 [ 0.04843567  0.85251779 -0.12717516]
 [ 0.13927597  2.06447447 -0.74675081]]
locs = 
[[6 6 0]
 [7 0 5]
 [0 9 2]]

The expected output is:

output = 
array([[ 0.        ,  0.        ,  0.        ],
       [-0.76353825,  0.29902121, -0.47547664],
       [-0.6702289 , -1.03620696,  0.        ],
       [-0.01606927,  0.03169479,  0.32413694],
       [-0.87992136, -0.13887237,  0.76943651],
       [-0.99176294,  1.2174871 ,  0.        ],
       [ 0.        ,  0.        ,  0.30146702],
       [ 0.        ,  0.12472804,  0.43728411],
       [ 0.04843567,  0.85251779, -0.12717516],
       [ 0.13927597,  0.        , -0.74675081]])
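The zero pattern in the expected output can be checked with a condensed form of the loop above. This sketch swaps the random a for an all-ones array (an assumption made purely so the overwritten cells stand out) and lists the (row, column) pairs that end up as 0:

```python
import numpy as np

# Stand-in for the question's data: all ones, so overwritten cells are obvious.
a = np.ones((10, 3))
locs = np.array([[6, 6, 0], [7, 0, 5], [0, 9, 2]])

# Condensed version of the question's loop: column i gets zeros at rows locs[:, i].
for i in range(locs.shape[1]):
    a[locs[:, i], i] = 0

# (row, column) pairs set to the placeholder, in row-major order
zeros = np.argwhere(a == 0).tolist()
print(zeros)
# [[0, 0], [0, 1], [0, 2], [2, 2], [5, 2], [6, 0], [6, 1], [7, 0], [9, 1]]
```

These nine positions are exactly the zeros in the expected output.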

Solution

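  • Fancy indexing should work here. The row-index array locs (shape (3, 3)) broadcasts against the column-index array np.arange(3) (shape (3,)), so entry [r, c] of locs is paired with column c, which is what the loop does one column at a time: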

    a[locs, np.arange(locs.shape[1])] = placeholder
    

    Example:

    >>> import numpy as np
    >>> rng = np.random.default_rng()
    >>> a = np.arange(30).reshape(10,3)
    >>> b = rng.integers(0,10,(3,3))
    >>> b
    array([[1, 3, 4],
           [5, 7, 1],
           [1, 7, 9]])
    >>> a[b,np.arange(3)] = 100
    >>> a
    array([[  0,   1,   2],
           [100,   4, 100],
           [  6,   7,   8],
           [  9, 100,  11],
           [ 12,  13, 100],
           [100,  16,  17],
           [ 18,  19,  20],
           [ 21, 100,  23],
           [ 24,  25,  26],
           [ 27,  28, 100]])
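A sketch comparing the question's loop with the fancy-indexing one-liner on random data (the seed and the names looped, fancy and along are illustrative only). It also shows the broadcast column-index grid and, as a separate alternative not taken from the answer above, np.put_along_axis, which writes values into an array along a given axis using an index array of the same dimensionality (available since NumPy 1.15):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal((10, 3))
locs = rng.integers(0, 10, size=(3, 3))  # row indices, one column per column of a
cols = np.arange(a.shape[1])             # shape (3,), broadcasts against locs

# The two index arrays broadcast to a common (3, 3) shape:
# element [r, c] pairs row locs[r, c] with column c.
_, cols_b = np.broadcast_arrays(locs, cols)
print(cols_b)
# [[0 1 2]
#  [0 1 2]
#  [0 1 2]]

# Reference result: the loop from the question.
looped = a.copy()
for i in range(locs.shape[1]):
    looped[locs[:, i], i] = 0

# Fancy indexing, as in the answer.
fancy = a.copy()
fancy[locs, cols] = 0

# Alternative: put_along_axis with axis=0 sets along[locs[r, c], c] = 0.
along = a.copy()
np.put_along_axis(along, locs, 0, axis=0)

print(np.array_equal(looped, fancy), np.array_equal(looped, along))
# True True
```

Repeated row indices within a column (such as the two 1s in the first column of b above) are harmless here, because every selected cell receives the same placeholder.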