Tags: python, python-3.x, numpy, numpy-ndarray, numpy-ufunc

Numpy apply along axis based on row index


I'm trying to use NumPy's built-in function apply_along_axis so that what it does depends on the row index position.

import numpy as np
sa = np.array(np.arange(4))
sa_changed = (np.repeat(sa.reshape(1,len(sa)),repeats=2,axis=0))
print (sa_changed)

Output:

[[0 1 2 3]
 [0 1 2 3]]
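As an aside on the setup above, here is a minimal sketch (assuming only NumPy) showing that np.tile builds the same (2, 4) array as the reshape-and-repeat step. The variable names repeated and tiled are illustrative:

```python
import numpy as np

sa = np.arange(4)  # np.arange already returns an ndarray, so np.array(...) around it is redundant

# Approach from the question: reshape to a 1x4 row, then repeat that row twice along axis 0
repeated = np.repeat(sa.reshape(1, len(sa)), repeats=2, axis=0)

# np.tile stacks copies of sa as rows, producing the same (2, 4) array
tiled = np.tile(sa, (2, 1))

print(repeated.shape)                   # (2, 4)
print(np.array_equal(repeated, tiled))  # True
```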

The function:

np.apply_along_axis(lambda x: x+10,0,sa_changed)

Output:

array([[10, 11, 12, 13],
       [10, 11, 12, 13]])
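A small sketch of why apply_along_axis cannot see a row index. The function only ever receives plain 1-D slices: with axis=0 those slices are the columns, and with axis=1 they are the rows. The helper slices_seen is hypothetical, written only for this illustration:

```python
import numpy as np

sa_changed = np.tile(np.arange(4), (2, 1))  # [[0 1 2 3], [0 1 2 3]]


def slices_seen(arr, axis):
    """Collect the 1-D slices that apply_along_axis passes to the function (illustrative helper)."""
    seen = []

    def record(x):
        seen.append(x.copy())  # x is a plain ndarray with no row/column index attached
        return x

    np.apply_along_axis(record, axis, arr)
    return seen


cols = slices_seen(sa_changed, 0)  # axis=0: one call per column
rows = slices_seen(sa_changed, 1)  # axis=1: one call per row
print(len(cols), cols[0])  # 4 [0 0]
print(len(rows), rows[0])  # 2 [0 1 2 3]

# For an elementwise function like x + 10, plain broadcasting gives the same result
out = np.apply_along_axis(lambda x: x + 10, 0, sa_changed)
print(np.array_equal(out, sa_changed + 10))  # True
```

So the lambda in the question is really working column by column, and since x + 10 is elementwise, sa_changed + 10 does the same job without apply_along_axis.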

But is there a way to use this function based on the row index position? For example, if it's an even row index, add 10, and if it's an odd row index, add 50.

Sample (pseudocode of the intended logic):

def func(x):
   if x.index//2==0:
      x = x+10
   else:
      x = x+50
   return x

Solution

  • When iterating over an array, either directly or with apply_along_axis, the subarray does not have a .index attribute, so we have to pass an explicit index value to your function:

    In [248]: def func(i,x):
         ...:    if i//2==0:
         ...:       x = x+10
         ...:    else:
         ...:       x = x+50
         ...:    return x
         ...: 
    In [249]: arr = np.arange(10).reshape(5,2)
    

    apply_along_axis has no way to pass this index, so instead we have to use an explicit iteration with enumerate:

    In [250]: np.array([func(i,v) for i,v in enumerate(arr)])
    Out[250]: 
    array([[10, 11],
           [12, 13],
           [54, 55],
           [56, 57],
           [58, 59]])
    

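    Note that i//2==0 is true only for rows 0 and 1, not for even rows. Replacing // with % gives the intended even/odd test: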

    In [251]: def func(i,x):
         ...:    if i%2==0:
         ...:       x = x+10
         ...:    else:
         ...:       x = x+50
         ...:    return x
         ...: 
    In [252]: np.array([func(i,v) for i,v in enumerate(arr)])
    Out[252]: 
    array([[10, 11],
           [52, 53],
           [14, 15],
           [56, 57],
           [18, 19]])
    

    But a better way is to skip the iteration entirely:

    Make an array holding the amount to add to each row (np.arange(5)%2 is 1 for odd rows, so they get 50):

    In [253]: np.where(np.arange(5)%2,50,10)
    Out[253]: array([10, 50, 10, 50, 10])
    

    Apply it via broadcasting; [:,None] reshapes the (5,) array to (5,1) so each value is added across its whole row:

    In [256]: arr+np.where(np.arange(5)%2,50,10)[:,None]
    Out[256]: 
    array([[10, 11],
           [52, 53],
           [14, 15],
           [56, 57],
           [18, 19]])
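The vectorized approach can be wrapped in a reusable function. This is a sketch assuming a 2-D input; add_by_row_parity and add_by_row_parity_slices (and their parameter names) are hypothetical helpers. The second one swaps in strided slice assignment (out[0::2], out[1::2]) as an alternative to np.where. Both are checked against the explicit enumerate loop:

```python
import numpy as np


def add_by_row_parity(arr, even_add=10, odd_add=50):
    """Hypothetical helper: add even_add to even-indexed rows, odd_add to odd-indexed rows.

    Assumes a 2-D array.
    """
    arr = np.asarray(arr)
    row_idx = np.arange(arr.shape[0])
    # One addend per row: odd rows (row_idx % 2 == 1) get odd_add, even rows get even_add
    addends = np.where(row_idx % 2, odd_add, even_add)
    # [:, None] turns shape (n,) into (n, 1) so it broadcasts across each row
    return arr + addends[:, None]


def add_by_row_parity_slices(arr, even_add=10, odd_add=50):
    """Hypothetical helper: same result via strided slice assignment on a copy."""
    out = np.array(arr, copy=True)  # copy so the input array is left untouched
    out[0::2] += even_add  # rows 0, 2, 4, ...
    out[1::2] += odd_add   # rows 1, 3, 5, ...
    return out


arr = np.arange(10).reshape(5, 2)
# Reference result from the explicit loop in the answer (using % for the parity test)
loop = np.array([v + (10 if i % 2 == 0 else 50) for i, v in enumerate(arr)])

print(add_by_row_parity(arr))
print(np.array_equal(add_by_row_parity(arr), loop))         # True
print(np.array_equal(add_by_row_parity_slices(arr), loop))  # True
```

The slicing version avoids building an addend array but needs a copy to keep the input unchanged; the np.where version extends more easily to other per-row rules.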