Tags: pandas, numpy, vectorization

Is there a faster way in pandas to find the indices of all values in a row that equal that row's maximum?


I have a pandas dataframe, Cellvoltage, with multiple columns of float values. For every row, I'm trying to get the column indices of all the values that equal that row's maximum.

So to achieve that, I am using the following code:

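import numpy as np
import pandas as pd

# (row, column) positions of every value equal to its row's maximum
req_indices = np.argwhere(Cellvoltage.values == np.amax(Cellvoltage.values, axis=1).reshape(-1, 1))
# one list per row (req_indices has more entries than there are rows whenever a row has ties)
max_voltage_idx = [[] for _ in range(len(Cellvoltage))]
for x, y in req_indices:
    max_voltage_idx[x].append(y)
# align on Cellvoltage's own index so a non-default index also works
Cellvoltage['max_voltage_idx'] = pd.Series(max_voltage_idx, index=Cellvoltage.index).apply(np.array)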

Is there a better or faster way to achieve the same goal?


Solution

  • Use df.eq(df.max(axis=1), axis=0) to flag each row's maximum values (df here is your Cellvoltage), where to turn the non-max values into NaN, then stack to drop the NaNs; the index of the result holds the (row, column) pairs.

    out = df.where(df.eq(df.max(axis=1), axis=0)).stack().index.tolist()
    

    Example input:

    
       A  B  C  D
    0  1  2  3  3
    1  4  1  1  4
    

    Output:

    [(0, 'C'), (0, 'D'), (1, 'A'), (1, 'D')]
    

    If you only want the column names, as a list in a new column, add a groupby.agg step:

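    df['max_voltage_idx'] = (df
       .where(df.eq(df.max(axis=1), axis=0))  # non-max values become NaN
       .stack()                               # long format; NaNs dropped (default in pandas 2.x and earlier)
       .reset_index(-1).iloc[:, 0]            # column labels of the remaining values
       .groupby(level=0).agg(list)            # collect them into one list per row
     )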
    

    Output:

       A  B  C  D max_voltage_idx
    0  1  2  3  3          [C, D]
    1  4  1  1  4          [A, D]
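The where/stack approach above relies on stack dropping the NaN cells. pandas has been changing that behaviour (see the future_stack option added in pandas 2.1), so a variant that does not depend on it may be handy. The sketch below, assuming the same example frame, stacks the boolean mask itself and keeps only the True entries; the column name max_voltage_idx is taken from the question.

```python
import pandas as pd

# Example frame from above; in the question this would be Cellvoltage.
df = pd.DataFrame({"A": [1, 4], "B": [2, 1], "C": [3, 1], "D": [3, 4]})

# True wherever a value equals its row's maximum (row max broadcast across the columns).
is_max = df.eq(df.max(axis=1), axis=0)

# Long format: one boolean per (row label, column label); the mask has no NaNs,
# so the result does not depend on how stack treats missing values.
flags = is_max.stack()
hits = flags[flags]

# (row, column) pairs of the maxima.
pairs = hits.index.tolist()

# Column labels collected per row, aligned back on the row index.
df["max_voltage_idx"] = hits.reset_index(level=1).iloc[:, 0].groupby(level=0).agg(list)
```

This should give the same pairs and lists as the where/stack version on the example input; the only change is filtering with the boolean mask instead of relying on NaNs being dropped.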
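If you want the integer column positions the question's loop produces, the grouping loop can also be replaced with NumPy calls. This is a sketch under the assumption that all columns are numeric: np.nonzero scans the mask row by row, so the column positions come out grouped by row, and np.split cuts them at the per-row hit counts. The names mask, counts and per_row are illustrative only.

```python
import numpy as np
import pandas as pd

# Example frame; in the question this would be Cellvoltage.
df = pd.DataFrame({"A": [1, 4], "B": [2, 1], "C": [3, 1], "D": [3, 4]})

vals = df.to_numpy()
# Boolean mask of row maxima; keepdims keeps the row max as a column so it broadcasts.
mask = vals == vals.max(axis=1, keepdims=True)

# Column positions of all True entries, in row-major order (grouped by row).
_, cols = np.nonzero(mask)

# Number of maxima per row, turned into split points for the flat column array.
counts = mask.sum(axis=1)
per_row = [c.tolist() for c in np.split(cols, np.cumsum(counts)[:-1])]

# Store one list of column positions per row, aligned on the frame's index.
df["max_voltage_idx"] = pd.Series(per_row, index=df.index)
```

Note that vals.max propagates NaN, so a row containing NaN would get an empty list here; np.nanmax would be needed if your data can contain missing values.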