
groupby with multiple fillna strategies at once (pandas)


How can I group by one column and then, inside each group, apply several fillna strategies at once to the other columns? Specifically:

  1. if the first value in a group is NaN, replace it with zero, then ffill until the first datapoint is reached (i.e. all leading NaNs become 0)
  2. trailing NaNs are ffilled
  3. all NaNs between datapoints are bfilled
  4. if a column is all-NaN within a group, leave it alone
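The rules can be checked on a single column before any grouping is involved. The sketch below classifies each NaN by position (leading, interior, trailing) using plain `ffill`/`bfill` and fills it according to the matching rule. `apply_rules` and the sample Series are hypothetical; the Series deliberately contains a gap between two datapoints, which the sample data below never has, so rule 3 (bfill) is visibly different from an ffill.

```python
import numpy as np
import pandas as pd


def apply_rules(s: pd.Series) -> pd.Series:
    """Hypothetical helper: apply the four fill rules to one column of one group."""
    if s.isna().all():          # rule 4: leave an all-NaN column untouched
        return s
    ffilled = s.ffill()
    bfilled = s.bfill()
    leading = ffilled.isna()    # still NaN after ffill -> before the first datapoint
    trailing = bfilled.isna()   # still NaN after bfill -> after the last datapoint
    interior = s.isna() & ~leading & ~trailing  # gaps between two datapoints

    out = s.copy()
    out[leading] = 0.0                      # rule 1: leading NaNs become 0
    out[trailing] = ffilled[trailing]       # rule 2: trailing NaNs take the last datapoint
    out[interior] = bfilled[interior]       # rule 3: interior NaNs take the next datapoint
    return out


s = pd.Series([np.nan, np.nan, 1.0, np.nan, 3.0, np.nan])
print(apply_rules(s).tolist())  # [0.0, 0.0, 1.0, 3.0, 3.0, 3.0]
```

The interior NaN at position 3 becomes 3.0 (the next datapoint) rather than 1.0, which is what distinguishes rule 3 from a plain forward fill.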

Basically, I have the following dataframe:

    A    B     C
0   A  NaN   NaN
1   A  NaN   NaN
2   A  1.0  10.0
3   A  NaN   NaN
4   B  NaN   NaN
5   B  2.0  20.0
6   B  NaN  20.0
7   B  NaN   NaN
8   C  NaN   NaN
9   C  NaN   NaN
10  C  NaN   NaN
11  C  NaN  30.0

And I'd like it to turn into:

    A    B     C
0   A    0     0
1   A    0     0
2   A  1.0  10.0
3   A  1.0  10.0
4   B    0     0
5   B  2.0  20.0
6   B  2.0  20.0
7   B  2.0  20.0
8   C  NaN     0
9   C  NaN     0
10  C  NaN     0
11  C  NaN  30.0

I've tried getting the first element with df.groupby('A').nth(1) and continuing conditionally, but the index of the groupby result is not the original one (i.e. 0, 4, 8), regardless of whether I call .reset_index() or not.
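To locate the first row of each group while keeping the original row labels, `GroupBy.cumcount` and `GroupBy.head` both return results indexed like the input frame. Note also that `nth` is zero-based, so `nth(1)` picks the second row of each group, and in pandas versions before 2.0 `nth` indexed its result by the group key by default. A minimal sketch using only column `A` of the question's frame:

```python
import pandas as pd

df = pd.DataFrame({'A': list("AAAABBBBCCCC")})

# cumcount numbers the rows inside each group (0, 1, 2, ...) and keeps df's index
is_first = df.groupby('A').cumcount() == 0
print(df[is_first].index.tolist())  # [0, 4, 8]

# head(1) also returns the first row of each group with its original label
print(df.groupby('A').head(1).index.tolist())  # [0, 4, 8]
```

Because `is_first` is aligned with `df`, it can be combined directly with other boolean masks on the original frame.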

Code to recreate the dataframe:

    import numpy as np
    import pandas as pd

    df = pd.DataFrame({'A' : ["A", "A", "A", "A",
                              "B", "B", "B", "B", "C", "C", "C", "C"],
                       'B' : [np.nan, np.nan, 1, np.nan,
                              np.nan, 2, np.nan, np.nan,
                              np.nan, np.nan, np.nan, np.nan],
                       'C' : [np.nan, np.nan, 10, np.nan,
                              np.nan, 20, 20, np.nan,
                              np.nan, np.nan, np.nan, 30]})

Solution

  • One possible approach is to group by A with DataFrame.groupby and then apply a custom function column by column with transform:

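    def fx(s):
        # Rule 4: leave a column that is all-NaN within the group untouched
        if s.isna().all():
            return s
        # bfill fills every NaN that has a later datapoint (rule 3),
        # the following ffill fills the trailing NaNs (rule 2);
        # cells still NaN after a plain ffill precede the first datapoint -> 0 (rule 1)
        return s.bfill().ffill().where(s.ffill().notna(), 0)
    
    df[['B', 'C']] = df.groupby('A')[['B', 'C']].transform(fx)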
    

    # print(df)
        A    B     C
    0   A  0.0   0.0
    1   A  0.0   0.0
    2   A  1.0  10.0
    3   A  1.0  10.0
    4   B  0.0   0.0
    5   B  2.0  20.0
    6   B  2.0  20.0
    7   B  2.0  20.0
    8   C  NaN   0.0
    9   C  NaN   0.0
    10  C  NaN   0.0
    11  C  NaN  30.0
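For larger frames, a vectorized variant avoids calling a Python function for every group and column. This is a swapped-in technique (grouped `bfill`/`ffill` combined with boolean masks), not the answer's method; the frame is rebuilt so the sketch runs on its own, and `cols`, `filled`, `after_first` and `has_data` are illustrative names.

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({'A': ["A", "A", "A", "A", "B", "B", "B", "B", "C", "C", "C", "C"],
                   'B': [np.nan, np.nan, 1, np.nan, np.nan, 2, np.nan, np.nan,
                         np.nan, np.nan, np.nan, np.nan],
                   'C': [np.nan, np.nan, 10, np.nan, np.nan, 20, 20, np.nan,
                         np.nan, np.nan, np.nan, 30]})
cols = ['B', 'C']
g = df.groupby('A')[cols]

# Rules 2 and 3: bfill inside each group, then ffill what is left (the trailing NaNs)
filled = g.bfill().groupby(df['A']).ffill()

# Rule 1: cells that are NaN even after a grouped ffill lie before the first datapoint
after_first = g.ffill().notna()

# Rule 4: True for every cell of a column that has at least one value in its group
has_data = df[cols].notna().groupby(df['A']).transform('any')

# Keep the filled value, except for leading cells of columns that have data -> 0
df[cols] = filled.where(after_first | ~has_data, 0)
print(df)
```

On the question's data this yields the same table as printed above, including the untouched all-NaN `B` values of group `C`.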