
How to refine heatmap?


I have a pandas data frame that looks like this:

         SPX   RYH  RSP   RCD   RYE  ...   RTM   RHS   RYT   RYU  EWRE
Date                                     ...                              
2022-02-25   NaN   NaN  NaN   NaN   NaN  ...   NaN   NaN   NaN   NaN   NaN
2022-03-04   9.0   5.0  8.0  12.0   1.0  ...   6.0   4.0  11.0   2.0   3.0 
2022-03-11   8.0  12.0  6.0  11.0   1.0  ...   3.0  13.0   9.0   2.0   4.0
2022-03-18   5.0   6.0  8.0   1.0  13.0  ...   9.0  10.0   2.0  12.0  11.0
2022-03-25   5.0  12.0  9.0  13.0   1.0  ...   2.0   4.0  10.0   3.0   7.0

Here is the info on it:

>>> a.ranks.info()
<class 'pandas.core.frame.DataFrame'>
DatetimeIndex: 52 entries, 2022-02-25 to 2023-02-17
Data columns (total 13 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   SPX     51 non-null     float64
 1   RYH     51 non-null     float64
 2   RSP     51 non-null     float64
 3   RCD     51 non-null     float64
 4   RYE     51 non-null     float64
 5   RYF     51 non-null     float64
 6   RGI     51 non-null     float64
 7   EWCO    51 non-null     float64
 8   RTM     51 non-null     float64
 9   RHS     51 non-null     float64
 10  RYT     51 non-null     float64
 11  RYU     51 non-null     float64
 12  EWRE    51 non-null     float64
dtypes: float64(13)
memory usage: 5.7 KB
>>> 

I plot a heatmap of it like so:

    cmap = sns.diverging_palette(133, 10, as_cmap=True)
    sns.heatmap(self.ranks, cmap=cmap, annot=True, cbar=False)
    plt.show()

This is the result: [image]

What I would like to have is the image flipped, with symbols on the y-axis and dates on the x-axis. I have tried .imshow() and the various pivot methods, to no avail.

I suspect that I have two questions: Is seaborn or imshow the right way to go about this? How do I pivot a pandas dataframe whose index is a datetime?


Solution

  • You can flip the x and y axes by transposing the dataframe (df.T, which interchanges index and columns). Because the default conversion of datetimes to tick labels also includes the time, the dates need to be converted to strings manually. sns.heatmap has parameters (xticklabels, yticklabels) to set or change the tick labels explicitly. Optionally, you can drop the all-NaN rows.

    import matplotlib.pyplot as plt
    import seaborn as sns
    import pandas as pd
    import numpy as np
    
    # first create some test data similar to the given data
    rank_df = pd.DataFrame(
        np.random.randint(1, 15, size=(52, 13)).astype(float),
        columns=['SPX', 'RYH', 'RSP', 'RCD', 'RYE', 'RYF', 'RGI', 'EWCO', 'RTM', 'RHS', 'RYT', 'RYU', 'EWRE'],
        index=pd.date_range('2022-02-25', '2023-02-17', freq='W-FRI'))
    rank_df.iloc[0, :] = np.nan
    
    rank_df_transposed = rank_df.dropna(how='all').T
    xticklabels = [t.strftime('%Y-%m-%d') for t in rank_df_transposed.columns]
    # optionally remove repeating months
    xticklabels = [t1[8:] + ('\n' + t1[:7] if t1[:7] != t0[:7] else '')
                   for t0, t1 in zip([' ' * 10] + xticklabels[:-1], xticklabels)]
    
    fig, ax = plt.subplots(figsize=(15, 7))
    sns.heatmap(data=rank_df_transposed,
                xticklabels=xticklabels, yticklabels=True,
                annot=True, cbar=False, ax=ax)
    ax.tick_params(axis='x', rotation=0)
    ax.tick_params(axis='y', rotation=0)
    plt.tight_layout()  # fit all the labels nicely into the surrounding figure
    plt.show()
    

    [image: sns.heatmap interchanging x and y]
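
As a smaller sketch of just the data-preparation step described above (drop the all-NaN rows, transpose, turn the dates into strings), the following uses only pandas. The helper name transpose_with_date_labels and the three-symbol test frame are made up for illustration and are not part of the original answer; DatetimeIndex.strftime is used here in place of the list comprehension.

```python
import numpy as np
import pandas as pd


def transpose_with_date_labels(ranks: pd.DataFrame, fmt: str = "%Y-%m-%d") -> pd.DataFrame:
    """Return the frame with symbols as rows and date strings as columns.

    Hypothetical helper, assuming `ranks` has a DatetimeIndex.
    """
    # Drop dates where every symbol is NaN, then swap index and columns.
    flipped = ranks.dropna(how="all").T
    # Format the dates as plain strings so no time part appears in the labels.
    flipped.columns = flipped.columns.strftime(fmt)
    return flipped


# Tiny illustrative frame shaped like the question's data (values are invented).
dates = pd.date_range("2022-02-25", periods=4, freq="W-FRI")
ranks = pd.DataFrame(
    [[np.nan, np.nan, np.nan], [3.0, 1.0, 2.0], [2.0, 3.0, 1.0], [1.0, 2.0, 3.0]],
    index=dates,
    columns=["SPX", "RYH", "RSP"],
)

flipped = transpose_with_date_labels(ranks)
print(flipped)
```

Because the column labels are already strings, the result should be usable directly as sns.heatmap(flipped, cmap=cmap, annot=True, cbar=False) with the cmap from the question, without passing xticklabels separately.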