python pandas group-by grouping

How to compute values for each row (with a unique function) when grouping a data frame in Python?


I have this pandas data frame:


import numpy as np
import pandas as pd

df = pd.DataFrame(
         {"CALDT": ["1980-01-31", "1980-02-28", "1980-03-31",
                    "1980-01-31", "1980-02-28", "1980-03-31",
                    "1980-01-31"],
          "ID": [1, 1, 1, 
                 2, 2, 2,
                 3],
          "Return": [0.02, 0.05, 0.10,
                     0.05, -0.02, 0.03,
                     -0.03]
          })

df['Year'] = pd.to_datetime(df['CALDT']).dt.year

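My goal: for each ID and year (the grouping), compute the mean and the median of Return, multiply each by 12, and assign the result back to every row of that group. This should only happen if the ID was alive for at least 2 months in that year; otherwise the value should be NaN.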

The expected output should look as follows:

df_new = pd.DataFrame(
    {"CALDT": ["1980-01-31", "1980-02-28", "1980-03-31",
               "1980-01-31", "1980-02-28", "1980-03-31",
               "1980-01-31"],
     "Year": [1980, 1980, 1980,
              1980, 1980, 1980,
              1980],
     "ID": [1, 1, 1, 
            2, 2, 2,
            3],
     "Return": [0.02, 0.05, 0.10,
                0.05, -0.02, 0.03,
                -0.03],
     "Mean_Return": [0.68, 0.68, 0.68,
                     0.24, 0.24, 0.24,
                     np.nan],
     "Median_Return": [0.60, 0.60, 0.60,
                       0.36, 0.36, 0.36,
                       np.nan]
     })
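
As a quick check of the expected numbers, the values for ID 1 can be reproduced by hand. This is only a sketch with the three returns copied from the frame above; the names `returns_id1`, `mean_return`, and `median_return` are illustrative and not part of the real data.

```python
from statistics import mean, median

# Monthly returns of ID 1 in 1980, copied by hand from the frame above
returns_id1 = [0.02, 0.05, 0.10]

# Multiply the monthly statistic by 12, then round for display
mean_return = round(mean(returns_id1) * 12, 2)
median_return = round(median(returns_id1) * 12, 2)

print(mean_return, median_return)
```

ID 3 has only one month in 1980, so it fails the "alive >= 2 months" condition and gets NaN instead.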


In R, this is quite easy by using group_by from tidyverse:

df = df %>% 
  mutate(Year = year(CALDT)) %>% 
  group_by(ID, Year) %>% 
  mutate(months_alive = length(unique(CALDT))) %>% 
  mutate(mean = case_when(months_alive >= 2 ~ mean(Return)*12,
                          .default = NA)) %>% 
  mutate(median = case_when(months_alive >= 2 ~ median(Return)*12,
                            .default = NA))

Any help would be appreciated!


Solution

    # .dt.year below needs a datetime column, so convert the strings first
    df["CALDT"] = pd.to_datetime(df["CALDT"])

    g = df.groupby(["ID", df["CALDT"].dt.year])
    return_stats = pd.DataFrame({
        "Mean_Return": g["Return"].transform("mean").mul(12),
        "Median_Return": g["Return"].transform("median").mul(12),
    }).where(g["CALDT"].transform("nunique").ge(2))

    df.join(return_stats)
    
    • group by "ID" and the year of "CALDT" (we don't have to make a new column to group by)
    • calculate the desired statistics over "Return", i.e., mean and median, with transform instead of, e.g., agg; transform "broadcasts" the result, repeating each group's value on every row of that group, whereas agg returns a single scalar per group
    • only retain the stats where the number of unique "CALDT" entries per group is greater than or equal to 2; the other rows become NaN
    • join this with the original frame

    to get

           CALDT  ID  Return  Mean_Return  Median_Return
    0 1980-01-31   1    0.02         0.68           0.60
    1 1980-02-28   1    0.05         0.68           0.60
    2 1980-03-31   1    0.10         0.68           0.60
    3 1980-01-31   2    0.05         0.24           0.36
    4 1980-02-28   2   -0.02         0.24           0.36
    5 1980-03-31   2    0.03         0.24           0.36
    6 1980-01-31   3   -0.03          NaN            NaN
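
The difference between transform and agg mentioned in the steps above can be sketched as follows. This is a minimal illustration that rebuilds the example frame; the names `per_group` and `per_row` are illustrative and not part of the solution itself.

```python
import pandas as pd

# Small frame mirroring the example in the question
df = pd.DataFrame({
    "CALDT": pd.to_datetime(["1980-01-31", "1980-02-28", "1980-03-31",
                             "1980-01-31", "1980-02-28", "1980-03-31",
                             "1980-01-31"]),
    "ID": [1, 1, 1, 2, 2, 2, 3],
    "Return": [0.02, 0.05, 0.10, 0.05, -0.02, 0.03, -0.03],
})

g = df.groupby(["ID", df["CALDT"].dt.year])

# agg collapses each group to one value, indexed by the group keys
per_group = g["Return"].agg("mean")

# transform repeats the group's value on every member row, keeping df's index
per_row = g["Return"].transform("mean")

print(len(per_group), len(per_row))
```

Because `per_row` shares its index with `df`, it can be masked and joined back row by row, which is what the solution relies on; `per_group` would first need a merge on the group keys.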