
Python - NumPy or Pandas (also possible) broadcasting


I have a 2-D NumPy array (a pandas DataFrame can also be used) full of numbers, and I need to replace those numbers with the mean of the last n rows in each column. My real array is huge, with a shape like (10000, 10000).

Example (limited shape just for explanation):

Input array:

[[10, 30, 8, 1],
 [11, 5, 19, 12],
 [12, 18, 15, 6],
 [13, 10, 21, 9],
 [14, 67, 14, 2],
 [15, 13, 12, 6]]

Average over n = 3 rows

So at every step the code should take the last 3 numbers in each column and compute their average.

Expected output:

[[11, 17.66666667, 14, 6.33333333],
 [12, 11, 18.33333333, 9],
 [13, 31.66666667, 16.66666667, 5.66666667],
 [14, 30, 15.66666667, 5.66666667]]

Explanation:

  • 14 is the average of 15, 14, 13
  • 18.33333333 is the average of 21, 15, 19
  • 9 is the average of 9, 6, 12

In other words, for every column the function should take the last n values and average them.
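Written as a formula (a restatement of the definition above, with A the R x C input, rows and columns indexed from 0, and out the result), row i of the output averages input rows i through i + n - 1, so the result has R - n + 1 rows:

```latex
\text{out}_{i,j} = \frac{1}{n} \sum_{k=i}^{i+n-1} A_{k,j},
\qquad i = 0, \dots, R - n, \quad j = 0, \dots, C - 1
```

For the example, R = 6 and n = 3 give 4 output rows, and out_{1,2} = (19 + 15 + 21) / 3 = 55 / 3, which is the 18.33333333 from the second bullet.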

I was able to do it with two for loops in plain Python, but it is very slow.
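For reference, a plain-Python loop version might look like the sketch below. The question does not show its code, so this is a hypothetical reconstruction (the name `trailing_mean_loops` is made up). It performs on the order of rows x cols x n additions in interpreted Python, which is why it gets slow at 10000 x 10000, but it is handy for checking faster versions on small inputs.

```python
def trailing_mean_loops(data, n):
    """Hypothetical loop-based reference: mean of each window of n consecutive
    rows, computed column by column. Returns len(data) - n + 1 rows."""
    rows, cols = len(data), len(data[0])
    out = []
    for i in range(n - 1, rows):            # i is the last row of the window
        row_means = []
        for j in range(cols):
            # Sum rows i-n+1 .. i of column j, then divide by the window size
            window_sum = sum(data[k][j] for k in range(i - n + 1, i + 1))
            row_means.append(window_sum / n)
        out.append(row_means)
    return out


data = [[10, 30, 8, 1],
        [11, 5, 19, 12],
        [12, 18, 15, 6],
        [13, 10, 21, 9],
        [14, 67, 14, 2],
        [15, 13, 12, 6]]
result = trailing_mean_loops(data, 3)
print(result[3][0])   # 14.0
```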


Solution

  • You don't need to loop over your data. With pandas, you can compute a rolling mean with DataFrame.rolling:

    import pandas as pd
    import numpy as np
    
    arr = np.array([[10, 30,  8,  1],
                    [11,  5, 19, 12],
                    [12, 18, 15,  6],
                    [13, 10, 21,  9],
                    [14, 67, 14,  2],
                    [15, 13, 12,  6]])
    
    n = 3
    df = pd.DataFrame(arr)
    out = df.rolling(n).mean().iloc[n-1:]
    print(out)
    
    # Output
          0          1          2         3
    2  11.0  17.666667  14.000000  6.333333
    3  12.0  11.000000  18.333333  9.000000
    4  13.0  31.666667  16.666667  5.666667
    5  14.0  30.000000  15.666667  5.666667
    

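    With NumPy only, you can use cumulative sums: for i >= n, the sum of rows i-n+1 through i equals cumsum[i] - cumsum[i-n], so each window sum costs a single subtraction: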

    # Adapted from https://stackoverflow.com/q/14313510/15239951
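    # Cumulative sums down each column; a float accumulator avoids integer overflow
    out = np.cumsum(arr, axis=0, dtype=float)
    # Row i (i >= n-1) now holds the sum of rows i-n+1 .. i
    out[n:] -= out[:-n]
    # Keep only complete windows and divide by the window size
    out = out[n-1:] / n
    print(repr(out))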
    
    # Output
    array([[11.        , 17.66666667, 14.        ,  6.33333333],
           [12.        , 11.        , 18.33333333,  9.        ],
           [13.        , 31.66666667, 16.66666667,  5.66666667],
           [14.        , 30.        , 15.66666667,  5.66666667]])
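With NumPy 1.20 or later, another option (not part of the answer above) is `numpy.lib.stride_tricks.sliding_window_view`, which exposes every window of n rows as a read-only view that can be averaged directly. The sketch below, assuming the same `arr` and `n = 3` as above, also converts the pandas result back to a NumPy array and checks that the two agree.

```python
import numpy as np
import pandas as pd
from numpy.lib.stride_tricks import sliding_window_view  # requires NumPy >= 1.20

arr = np.array([[10, 30,  8,  1],
                [11,  5, 19, 12],
                [12, 18, 15,  6],
                [13, 10, 21,  9],
                [14, 67, 14,  2],
                [15, 13, 12,  6]])
n = 3

# View of shape (rows - n + 1, cols, n): each window of n consecutive rows
# along axis 0, with the window as the last axis. No data is copied here.
windows = sliding_window_view(arr, n, axis=0)
out_windows = windows.mean(axis=-1)

# The pandas answer, converted back to a plain NumPy array for comparison.
out_pandas = pd.DataFrame(arr).rolling(n).mean().iloc[n - 1:].to_numpy()

print(out_windows.shape)                      # (4, 4)
print(np.allclose(out_windows, out_pandas))   # True
```

As a design trade-off, averaging the view does n additions per output value, so the work grows with n, while the cumsum trick costs about the same for any n but accumulates rounding error down very long columns. If the output should keep the input's shape instead of dropping the first n - 1 rows, pandas' `rolling(n, min_periods=1)` averages over however many rows are available so far.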