Tags: python, pandas, lambda, apply, rolling-computation

Calculate an array of rolling correlations with different shifts of the input


I will illustrate the challenge with a small set of dummy data. Suppose the data is in a dataframe with two columns, like this:

data = pd.DataFrame(data=[[0.1, 0.4], [0.6, 0.9], [0.5, 0.3], [0.3, 0.2], [0.1, 0.2], [0.9, 0.5], [0.5, 0.6]], index=['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04', '2021-01-05', '2021-01-06', '2021-01-07'], columns=['x', 'y'])

I am trying to build a dataframe that shows the 3-day rolling correlation between two series with different shifts of 'x' (up to two shifts in this example). The result dataframe has the shift across columns and the time across rows, like this:

result = pd.DataFrame(data=[], index=data.index, columns=[0, 1, 2])

For example, the element of the result dataframe at index '2021-01-06' and column 1 is the correlation between 'x' shifted by one row, [0.5, 0.3, 0.1], and 'y', [0.2, 0.2, 0.5], over the three rows ending on 2021-01-06, which is -0.86603.
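
That single element can be checked by hand. The sketch below rebuilds the two three-value windows from the dummy data and correlates them directly; the names `x_window`, `y_window` and `single_value` are illustrative only and are not part of the original question.

```python
import pandas as pd

data = pd.DataFrame(
    data=[[0.1, 0.4], [0.6, 0.9], [0.5, 0.3], [0.3, 0.2],
          [0.1, 0.2], [0.9, 0.5], [0.5, 0.6]],
    index=['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04',
           '2021-01-05', '2021-01-06', '2021-01-07'],
    columns=['x', 'y'],
)

# 'x' shifted down by one row, restricted to the window ending on 2021-01-06
x_window = data['x'].shift(1).loc['2021-01-04':'2021-01-06']  # 0.5, 0.3, 0.1
# 'y' over the same three rows, unshifted
y_window = data['y'].loc['2021-01-04':'2021-01-06']           # 0.2, 0.2, 0.5

# Pearson correlation of the two windows
single_value = x_window.corr(y_window)
print(round(single_value, 5))  # -0.86603
```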

The desired output with this data would be:

result = pd.DataFrame(data=[[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan], [0.5291, np.nan, np.nan], [0.8358, -0.9484, np.nan], [0.8660, 0.7559, -0.9820], [0.9707, -0.8660, -0.9449], [0.7206, 0.5, -0.9608]], index=['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04', '2021-01-05', '2021-01-06', '2021-01-07'], columns=[0, 1, 2])

This was my initial attempt, which failed. I was trying to combine rolling and apply with a custom function. There might be a better way.

corr_func = lambda x: x.iloc[:,1].corr(x.shift(1).iloc[:,0])
result = data.rolling(3).apply(corr_func)
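
A likely reason this attempt fails: `rolling(...).apply` works through the dataframe one column at a time and passes the function a one-dimensional window (a Series by default, or an ndarray with `raw=True`), so there is no second column for `.iloc[:, 1]` to select. The sketch below only records what the function receives; `seen` and `record_window` are illustrative names, not part of the attempt above.

```python
import pandas as pd

data = pd.DataFrame(
    data=[[0.1, 0.4], [0.6, 0.9], [0.5, 0.3], [0.3, 0.2],
          [0.1, 0.2], [0.9, 0.5], [0.5, 0.6]],
    index=['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04',
           '2021-01-05', '2021-01-06', '2021-01-07'],
    columns=['x', 'y'],
)

seen = []

def record_window(window):
    # Record the type and dimensionality of each window passed in by apply
    seen.append((type(window).__name__, window.ndim))
    return 0.0  # apply expects a scalar result per window

data.rolling(3).apply(record_window)

# Every call receives a 1-D Series from a single column, never a two-column frame
print(set(seen))
```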

Can anyone solve the challenge, please? In my real-life case I will be using a larger dataframe, a larger rolling window and many more shifts.


Solution

  • This should work. For each shift, compute the rolling correlation of 'y' against the shifted 'x', then concatenate the resulting series column-wise.

    import pandas as pd

    window = 3    # rolling window length, in rows
    n_shifts = 3  # number of shifts to evaluate: 0, 1 and 2

    data = pd.DataFrame(data=[[0.1, 0.4], [0.6, 0.9], [0.5, 0.3], [0.3, 0.2], [0.1, 0.2], [0.9, 0.5], [0.5, 0.6]], index=['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04', '2021-01-05', '2021-01-06', '2021-01-07'], columns=['x', 'y'])

    # One rolling-correlation Series per shift of 'x', joined column-wise
    results = pd.concat([data.x.shift(n).rolling(window).corr(data.y) for n in range(n_shifts)], axis=1)

    print(results)
    

    Output

                       0         1         2
    2021-01-01       NaN       NaN       NaN
    2021-01-02       NaN       NaN       NaN
    2021-01-03  0.529107       NaN       NaN
    2021-01-04  0.835766 -0.948421       NaN
    2021-01-05  0.866025  0.755929 -0.981981
    2021-01-06  0.970725 -0.866025 -0.944911
    2021-01-07  0.720577  0.500000 -0.960769
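
For a larger dataframe with more shifts, the same idea can be wrapped in a small helper. This is only a sketch: the function name `rolling_shifted_corr` and the use of a dict (so that each column is labelled explicitly by its shift) are assumptions added here, not part of the answer above. The last lines compare the result with the desired output from the question, to four decimals.

```python
import numpy as np
import pandas as pd

def rolling_shifted_corr(x, y, window, n_shifts):
    """Rolling correlation of y with x shifted by 0..n_shifts-1 rows."""
    # Dict keys become the column labels, so each column is named by its shift
    return pd.concat(
        {n: x.shift(n).rolling(window).corr(y) for n in range(n_shifts)},
        axis=1,
    )

idx = ['2021-01-01', '2021-01-02', '2021-01-03', '2021-01-04',
       '2021-01-05', '2021-01-06', '2021-01-07']
data = pd.DataFrame(
    data=[[0.1, 0.4], [0.6, 0.9], [0.5, 0.3], [0.3, 0.2],
          [0.1, 0.2], [0.9, 0.5], [0.5, 0.6]],
    index=idx, columns=['x', 'y'],
)

results = rolling_shifted_corr(data['x'], data['y'], window=3, n_shifts=3)

# Desired output as stated in the question (rounded to four decimals)
expected = pd.DataFrame(
    data=[[np.nan, np.nan, np.nan], [np.nan, np.nan, np.nan],
          [0.5291, np.nan, np.nan], [0.8358, -0.9484, np.nan],
          [0.8660, 0.7559, -0.9820], [0.9707, -0.8660, -0.9449],
          [0.7206, 0.5, -0.9608]],
    index=idx, columns=[0, 1, 2],
)

# Compare within rounding tolerance, treating matching NaN positions as equal
matches = np.allclose(results.to_numpy(dtype=float),
                      expected.to_numpy(dtype=float),
                      atol=1e-4, equal_nan=True)
print(matches)
```

Each shift costs one vectorised rolling pass over the data, so the total work grows linearly with the number of shifts.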