Tags: python-3.x, pandas, seaborn, heatmap, xticks

Custom xticks in seaborn heatmap


I have the following heatmap (just a minimal working example; my real data is much larger):

import pandas as pd
import seaborn as sns

df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

sns.heatmap(df, yticklabels=df.index, xticklabels=df.columns)

How can I show xticks only for the columns where all rows are >= 2? In this example, that means showing only the '1' xtick.

So in this image the '0' and '2' column names should not appear, only '1', because that is the only column where all values are greater than or equal to 2.

The problem is that the x axis gets too crowded. I want to keep plotting every column but show only the xticklabels of the columns where all values are >= 2.



Solution

  • Mask the DataFrame

    • Removes every column in which any value is not >= the specified value, so those columns are no longer plotted.
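    # uses df from the question
    import seaborn as sns

    # create an element-wise Boolean mask of df
    mask = df.ge(2)
    
    # apply the mask (failing values become NaN), then drop every column that contains NaN
    dfm = df[mask].dropna(axis=1)
    
    # plot the masked df
    ax = sns.heatmap(dfm)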
    


    Contents of mask:

              0     1      2
    set1  False  True   True
    set2  False  True  False
    set3  False  True  False
    set4  False  True  False
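The same column selection can also be written without creating intermediate NaN values. The sketch below is an alternative, not part of the original answer: it reduces the element-wise comparison to one Boolean per column with `.all()` and passes that to `.loc`, which for this data should keep the same columns as `df[mask].dropna(axis=1)`. The name `threshold` is a hypothetical variable introduced here for the cutoff of 2.

```python
import pandas as pd
import seaborn as sns

# Example data copied from the question
df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

threshold = 2  # hypothetical variable holding the cutoff from the question

# One Boolean per column: True where every row is >= threshold
keep = df.ge(threshold).all()

# Select the qualifying columns directly, so no NaN values are created
dfm = df.loc[:, keep]

# Plot only the selected columns
ax = sns.heatmap(dfm)
```

On the example data, `dfm` keeps only column `1`. Like the answer above, this removes the other columns from the plot rather than just hiding their labels.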
    

  • Mask the xtick labels

    • Labels of columns in which any value is not >= the specified value are replaced with '', so every column is still plotted but only the qualifying columns are labeled.
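    # uses df from the question
    import seaborn as sns

    # Boolean Series with one entry per column: True where every row is >= 2
    mask = df.ge(2).all()
    
    # keep the column name where the mask is True, otherwise use an empty label
    cols = [col if m else '' for (col, m) in zip(df.columns, mask)]
    
    # plot every column with the custom labels
    ax = sns.heatmap(df, xticklabels=cols)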
    


    Contents of mask:

    0    False
    1     True
    2    False
    dtype: bool
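With empty-string labels, seaborn still places a tick (and its tick mark) under every column; only the label text is blank. If those marks also add clutter, one possible variation (a sketch, not from the original answer) is to plot all columns with `xticklabels=False` and then place ticks only at the qualifying columns with matplotlib's `Axes.set_xticks` and `Axes.set_xticklabels`. It assumes seaborn's heatmap layout, where column `i` is drawn over the interval `[i, i + 1]`, so its centre is at `i + 0.5`.

```python
import numpy as np
import pandas as pd
import seaborn as sns

# Example data copied from the question
df = pd.DataFrame({'set1': ['0', '2', '2'],
                   'set2': ['1', '2', '0'],
                   'set3': ['0', '2', '1'],
                   'set4': ['1', '4', '1']
                  }).T.astype(float)

# One Boolean per column: True where every row is >= 2
keep = df.ge(2).all()

# Plot every column, but suppress seaborn's automatic x tick labels
ax = sns.heatmap(df, xticklabels=False)

# Integer positions of the qualifying columns
positions = np.flatnonzero(keep.to_numpy())

# Column i is centred at i + 0.5, so place one tick per qualifying column
ax.set_xticks(positions + 0.5)
ax.set_xticklabels([str(c) for c in df.columns[positions]])
```

The heatmap itself is unchanged; only one tick, labeled '1', remains on the x axis. The trade-off is that the tick positions are computed by hand, so the `+ 0.5` offset depends on seaborn drawing cells on an integer grid.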