
Highlight max/min on MultiIndex DataFrame - Pandas


Assume we have a DataFrame with a two-level MultiIndex:

import pandas as pd

df = pd.DataFrame([['one', 'A', 100, 3], ['two', 'A', 101, 4],
                   ['three', 'A', 102, 6], ['one', 'B', 103, 6],
                   ['two', 'B', 104, 0], ['three', 'B', 105, 3]],
                  columns=['c1', 'c2', 'c3', 'c4']).set_index(['c1', 'c2']).sort_index()
print(df)

which looks like this:

           c3  c4
c1    c2         
one   A   100   3
      B   103   6
three A   102   6
      B   105   3
two   A   101   4
      B   104   0

My goal is to use pandas styling to highlight, for each value of 'c1', the minimum (or, equivalently, the maximum) over its 'c2' rows, separately for column 'c3' and column 'c4':

             c3      c4
c1    c2         
one   A   **100**   **3**
      B     103       6
three A   **102**     6
      B     105     **3**
two   A   **101**     4
      B     104     **0**
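Put differently, a cell should be highlighted when it equals the minimum of its column within its 'c1' group. As a minimal sketch (not part of the original question), that target can be written as a plain boolean mask without any styling; the example frame is rebuilt so the snippet runs on its own:

```python
import pandas as pd

# Rebuild the example frame from the question
df = pd.DataFrame([['one', 'A', 100, 3], ['two', 'A', 101, 4],
                   ['three', 'A', 102, 6], ['one', 'B', 103, 6],
                   ['two', 'B', 104, 0], ['three', 'B', 105, 3]],
                  columns=['c1', 'c2', 'c3', 'c4']).set_index(['c1', 'c2']).sort_index()

# Minimum of each column inside each 'c1' group, broadcast back to every row
group_min = df.groupby(level='c1').transform('min')

# True exactly where a cell equals its group's minimum for that column
mask = df == group_min
print(mask)
```

The True cells of mask are the cells marked in bold above.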

Do you have any suggestions?

I already tried the following, but it finds the minimum of each whole column instead of the minimum within each 'c1' group.

import numpy as np
import pandas as pd

def highlight_min(data, color='red'):
    attr = 'background-color: {}'.format(color)

    if data.ndim == 1:  # Series from .apply(axis=0) or axis=1
        is_min = data == data.min()
        return [attr if v else '' for v in is_min]
    else:  # DataFrame from .apply(axis=None)
        is_min = data == data.min().min()
        return pd.DataFrame(np.where(is_min, attr, ''),
                            index=data.index, columns=data.columns)

styled = df.style.apply(highlight_min, axis=0)

The result is the following:

             c3      c4
c1    c2         
one   A   **100**     3
      B     103       6
three A     102       6
      B     105       3
two   A     101       4
      B     104     **0**
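This happens because Styler.apply with axis=0 calls the function once per column and passes the whole column as a Series, so data.min() is taken across all 'c1' groups at once. A small sketch that reproduces the same selection without styling, emulating that per-column call with DataFrame.apply (an illustration, not the asker's code):

```python
import pandas as pd

df = pd.DataFrame([['one', 'A', 100, 3], ['two', 'A', 101, 4],
                   ['three', 'A', 102, 6], ['one', 'B', 103, 6],
                   ['two', 'B', 104, 0], ['three', 'B', 105, 3]],
                  columns=['c1', 'c2', 'c3', 'c4']).set_index(['c1', 'c2']).sort_index()

# Each call sees one full column, so col.min() ignores the 'c1' grouping
column_wise = df.apply(lambda col: col == col.min(), axis=0)
print(column_wise)
```

Only one cell per column comes out True, matching the output above.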

Solution

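  • Use GroupBy.transform with 'min' to broadcast each group's minimum back to every row, compare it with the original values, and call Styler.apply with axis=None so the function receives the whole DataFrame: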

    import numpy as np
    import pandas as pd

    def highlight_min(data):
        color = 'red'
        attr = 'background-color: {}'.format(color)
    
        if data.ndim == 1:  # Series from .apply(axis=0) or axis=1
            is_min = data == data.min()
            return [attr if v else '' for v in is_min]
        else:  # DataFrame from .apply(axis=None): compare each cell with its group's minimum
            is_min = data.groupby(level=0).transform('min') == data
            return pd.DataFrame(np.where(is_min, attr, ''),
                                index=data.index, columns=data.columns)

    df.style.apply(highlight_min, axis=None)
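A slightly more general sketch, which is not from the original answer: the hypothetical helper highlight_group_extreme takes the aggregation name ('min' or 'max') and the color as parameters, so the same function also covers the maximum case mentioned in the question. It assumes it is used with Styler.apply(..., axis=None), so it always receives the whole DataFrame and only needs the grouped branch.

```python
import numpy as np
import pandas as pd

def highlight_group_extreme(data, op='min', color='red', level=0):
    """Return a DataFrame of CSS strings marking the per-group min or max.

    Hypothetical helper meant for Styler.apply(..., axis=None), so `data`
    is the whole DataFrame. `op` is an aggregation name accepted by
    GroupBy.transform, e.g. 'min' or 'max'. Ties are all highlighted.
    """
    attr = 'background-color: {}'.format(color)
    # Broadcast each group's extreme back to every row of that group
    extreme = data.groupby(level=level).transform(op)
    is_extreme = data == extreme
    return pd.DataFrame(np.where(is_extreme, attr, ''),
                        index=data.index, columns=data.columns)

df = pd.DataFrame([['one', 'A', 100, 3], ['two', 'A', 101, 4],
                   ['three', 'A', 102, 6], ['one', 'B', 103, 6],
                   ['two', 'B', 104, 0], ['three', 'B', 105, 3]],
                  columns=['c1', 'c2', 'c3', 'c4']).set_index(['c1', 'c2']).sort_index()

# Inspect the CSS frame directly; no Styler (and so no jinja2) needed here
css_max = highlight_group_extreme(df, op='max', color='yellow')
print(css_max)
```

With jinja2 installed (Styler needs it), the styled table would then be df.style.apply(highlight_group_extreme, axis=None, op='max', color='yellow'), since Styler.apply forwards extra keyword arguments to the function.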