Tags: python, pandas, dataframe, rolling-computation, rolling-sum

Conditional mean and sum of previous N rows in pandas dataframe


Consider this example pandas dataframe:

      Measurement  Trigger  Valid
   0          2.0    False   True
   1          4.0    False   True
   2          3.0    False   True
   3          0.0     True  False
   4        100.0    False   True
   5          3.0    False   True
   6          2.0    False   True
   7          1.0     True   True

Whenever Trigger is True, I want to calculate the sum and mean of the last 3 valid measurements, counting back from (and including) the current row. A measurement is considered valid if the Valid column is True. To clarify, here are the two trigger rows in the dataframe above:

  1. Index 3: Indices 2,1,0 should be used. Expected Sum = 9.0, Mean = 3.0
  2. Index 7: Indices 7,6,5 should be used. Expected Sum = 6.0, Mean = 2.0

I have tried rolling windows (DataFrame.rolling) and creating new shifted columns, but without success. See the following excerpt from my tests, which should run as-is:

import unittest
import pandas as pd
import numpy as np
from pandas.testing import assert_series_equal  # pandas.util.testing was removed in pandas 2.0

def create_sample_dataframe_2():
    df = pd.DataFrame(
        {"Measurement" : [2.0,   4.0,   3.0,   0.0,   100.0, 3.0,   2.0,   1.0 ],
         "Valid"       : [True,  True,  True,  False, True,  True,  True,  True],
         "Trigger"     : [False, False, False, True,  False, False, False, True],
         })
    return df

def expected_result():
    return pd.DataFrame({"Sum" : [np.nan, np.nan, np.nan, 9.0, np.nan, np.nan, np.nan, 6.0],
                         "Mean" :[np.nan, np.nan, np.nan, 3.0, np.nan, np.nan, np.nan, 2.0]})

class Data_Preparation_Functions(unittest.TestCase):

    def test_backsummation(self):
        N_SUMMANDS = 3
        temp_vars = []

        df = create_sample_dataframe_2()
        for i in range(0,N_SUMMANDS):
            temp_var = "M_{0}".format(i)
            df[temp_var] = df["Measurement"].shift(i)
            temp_vars.append(temp_var)

        df["Sum"]  = df[temp_vars].sum(axis=1)
        df["Mean"] = df[temp_vars].mean(axis=1)
        df.loc[(df["Trigger"]==False), "Sum"] = np.nan
        df.loc[(df["Trigger"]==False), "Mean"] = np.nan

        assert_series_equal(expected_result()["Sum"],df["Sum"])
        assert_series_equal(expected_result()["Mean"],df["Mean"])

    def test_rolling(self):
        df = create_sample_dataframe_2()
        df["Sum"]  = df[(df["Valid"] == True)]["Measurement"].rolling(window=3).sum()
        df["Mean"] = df[(df["Valid"] == True)]["Measurement"].rolling(window=3).mean()

        df.loc[(df["Trigger"]==False), "Sum"] = np.nan
        df.loc[(df["Trigger"]==False), "Mean"] = np.nan
        assert_series_equal(expected_result()["Sum"],df["Sum"])
        assert_series_equal(expected_result()["Mean"],df["Mean"])


if __name__ == '__main__':
    suite = unittest.TestLoader().loadTestsFromTestCase(Data_Preparation_Functions)
    unittest.TextTestRunner(verbosity=2).run(suite)
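A minimal sketch, using the same sample data, that isolates the step where `test_rolling` goes wrong: the rolling sum is computed on the valid-only Series, whose index has no label 3, and assigning it back to `df["Sum"]` aligns on the index, so the invalid trigger row 3 ends up as NaN. Row 7, whose own measurement is valid, already comes out as 6.0.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame(
    {"Measurement": [2.0, 4.0, 3.0, 0.0, 100.0, 3.0, 2.0, 1.0],
     "Valid":       [True, True, True, False, True, True, True, True],
     "Trigger":     [False, False, False, True, False, False, False, True]})

# Rolling sum over the valid rows only; this Series has no label 3
valid_sum = df.loc[df["Valid"], "Measurement"].rolling(window=3).sum()

# Assignment aligns on the index: label 3 is missing, so it becomes NaN
df["Sum"] = valid_sum
print(df.loc[[3, 7], "Sum"])
```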

Any help or solution is greatly appreciated. Thanks and Cheers!

EDIT: Clarification: This is the resulting dataframe I expect:

      Measurement  Trigger  Valid   Sum   Mean
   0          2.0    False   True   NaN    NaN
   1          4.0    False   True   NaN    NaN
   2          3.0    False   True   NaN    NaN
   3          0.0     True  False   9.0    3.0
   4        100.0    False   True   NaN    NaN
   5          3.0    False   True   NaN    NaN
   6          2.0    False   True   NaN    NaN
   7          1.0     True   True   6.0    2.0

EDIT2: Another clarification:

I did not miscalculate; rather, I did not make my intentions as clear as I could have. Here is another attempt, using the same dataframe:

[Image: desired dataframe, relevant fields highlighted]

Let's first look at the Trigger column: the first True is at index 3 (green rectangle), so index 3 is where we start looking. There is no valid measurement at index 3 (Valid is False; red rectangle). So we go further back in time until we have accumulated three rows where Valid is True. This happens at indices 2, 1 and 0. For these three indices, we calculate the sum and mean of the Measurement column (blue rectangle):

  • SUM: 2.0 + 4.0 + 3.0 = 9.0
  • MEAN: (2.0 + 4.0 + 3.0) / 3 = 3.0

Now we start the next iteration of this little algorithm: look for the next True in the Trigger column. We find it at index 7 (green rectangle). The measurement at index 7 is valid, so this time we include it. For our calculation we use indices 7, 6 and 5 (green rectangle), and thus get:

  • SUM: 1.0 + 2.0 + 3.0 = 6.0
  • MEAN: (1.0 + 2.0 + 3.0) / 3 = 2.0
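The procedure above (start at each trigger row, walk back until three valid rows are collected) can be written directly as a loop. This is a minimal, unvectorized sketch for checking faster solutions against; the helper name `last_valid_stats` is hypothetical, and it assumes that a trigger with fewer than `n` valid rows before it should stay NaN, which the question does not specify.

```python
import numpy as np
import pandas as pd

def last_valid_stats(df, n=3):
    """Hypothetical helper: for each Trigger row, sum/mean of the last n valid
    Measurements up to and including that row; NaN on all other rows."""
    sums = pd.Series(np.nan, index=df.index)
    means = pd.Series(np.nan, index=df.index)
    for label in df.index[df["Trigger"].to_numpy()]:
        # All rows up to and including the trigger row, keeping valid ones only
        history = df.loc[:label]
        picked = history.loc[history["Valid"], "Measurement"].tail(n)
        # Assumption: fewer than n valid rows so far -> leave NaN
        if len(picked) == n:
            sums.loc[label] = picked.sum()
            means.loc[label] = picked.mean()
    return pd.DataFrame({"Sum": sums, "Mean": means})

df = pd.DataFrame(
    {"Measurement": [2.0, 4.0, 3.0, 0.0, 100.0, 3.0, 2.0, 1.0],
     "Valid":       [True, True, True, False, True, True, True, True],
     "Trigger":     [False, False, False, True, False, False, False, True]})
print(df.join(last_valid_stats(df)))
```

Each trigger re-slices the frame, so this is quadratic in the worst case; it is meant for clarity, not speed.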

I hope this sheds more light on this little problem.


Solution

  • Here's an option: take the 3-period rolling mean and sum

    df['RollM'] = df.Measurement.rolling(window=3,min_periods=0).mean()
    
    df['RollS'] = df.Measurement.rolling(window=3,min_periods=0).sum()
    

    Now set the rows where Trigger is False to NaN:

    df.loc[df.Trigger == False,['RollS','RollM']] = np.nan
    

    yields

       Measurement  Trigger  Valid     RollM  RollS
    0          2.0    False   True       NaN    NaN
    1          4.0    False   True       NaN    NaN
    2          3.0    False   True       NaN    NaN
    3          0.0     True  False  2.333333    7.0
    4        100.0    False   True       NaN    NaN
    5          3.0    False   True       NaN    NaN
    6          2.0    False   True       NaN    NaN
    7          1.0     True   True  2.000000    6.0
    

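    Edit: updated to take the Valid column into account. The rolling result is used as-is when the trigger row is valid, and shifted by one row when the trigger row itself is invalid: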

    df['mean'],df['sum'] = np.nan,np.nan
    
    roller = df.Measurement.rolling(window=3,min_periods=0).agg(['mean','sum'])
    
    df.loc[(df.Trigger == True) & (df.Valid == True),['mean','sum']] = roller
    
    df.loc[(df.Trigger == True) & (df.Valid == False),['mean','sum']] = roller.shift(1)
    

    Yields

       Measurement  Trigger  Valid  mean  sum
    0          2.0    False   True   NaN  NaN
    1          4.0    False   True   NaN  NaN
    2          3.0    False   True   NaN  NaN
    3          0.0     True  False   3.0  9.0
    4        100.0    False   True   NaN  NaN
    5          3.0    False   True   NaN  NaN
    6          2.0    False   True   NaN  NaN
    7          1.0     True   True   2.0  6.0
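The `shift(1)` step fits the sample data because the only invalid value sits exactly on a trigger row. If an invalid row falls inside the window of a valid trigger, or several invalid rows are adjacent, the window no longer covers exactly three valid measurements. A more general sketch (an alternative approach, not part of the answer above) drops invalid rows before rolling and then carries each result forward onto the following invalid rows with `reindex(..., method="ffill")`. The helper name `trigger_stats` is hypothetical, and it assumes a monotonically increasing index and a boolean Valid column.

```python
import numpy as np
import pandas as pd

def trigger_stats(df, n=3):
    """Hypothetical helper: Sum/Mean of the last n valid Measurements at each
    Trigger row; NaN on all other rows."""
    # Roll over valid rows only, so each window holds exactly n valid values
    # (min_periods defaults to n: fewer than n valid rows so far -> NaN)
    r = df.loc[df["Valid"], "Measurement"].rolling(window=n)
    stats = pd.DataFrame({"Sum": r.sum(), "Mean": r.mean()})
    # Invalid rows take the result of the closest valid row before them;
    # method="ffill" needs a monotonically increasing index
    stats = stats.reindex(df.index, method="ffill")
    # Keep values only on trigger rows
    stats.loc[~df["Trigger"]] = np.nan
    return df.join(stats)

df = pd.DataFrame(
    {"Measurement": [2.0, 4.0, 3.0, 0.0, 100.0, 3.0, 2.0, 1.0],
     "Valid":       [True, True, True, False, True, True, True, True],
     "Trigger":     [False, False, False, True, False, False, False, True]})
print(trigger_stats(df))
```

On the sample data this matches the table above. If Trigger were also True at index 5, this version would use rows 5, 4 and 2 (3.0 + 100.0 + 3.0 = 106.0), whereas `roller` would include the invalid 0.0 at row 3.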