
Creating a categorical heatmap with sparklines?


Does anyone know of an example of how to create a categorical heat map with individual sparklines within each cell? Or have a suggestion on how to use matplotlib's annotation to produce this (or something similar)?

Essentially turning this: Matplotlib heatmap annotation

into this: Heatmap with sparkline



Solution

  • Assume the input is in the following long format, with an arbitrary number of rows for each row/col combination. The goal is a heatmap of the mean value per row/col combination, with a small line in each cell that traces that combination's consecutive values:

        row col     value
    0     A   a -2.911793
    1     A   a -3.066935
    2     A   a -0.940881
    3     A   a  1.838795
    4     A   a  2.359492
    ..   ..  ..       ...
    595   E   f -3.233857
    596   E   f -4.348279
    597   E   f -4.236598
    598   E   f -4.697110
    599   E   f -3.618638
    
    [600 rows x 3 columns]
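    Both layers of the plot come from this one frame: a rows x cols matrix of means for the heatmap, and the ordered values of each row/col group for the lines. The minimal sketch below computes both on a small made-up frame (the names toy, means and series are only illustrative). A row/col combination with no data becomes NaN in the matrix.

```python
import pandas as pd

# made-up frame in the same long format: several rows per (row, col) pair
toy = pd.DataFrame({'row': ['A', 'A', 'A', 'A', 'B', 'B'],
                    'col': ['a', 'a', 'b', 'b', 'a', 'a'],
                    'value': [1.0, 3.0, 2.0, 6.0, -1.0, 1.0]})

# heatmap layer: one mean per cell, rows x cols (missing pairs become NaN)
means = toy.pivot_table(index='row', columns='col', values='value', aggfunc='mean')

# line layer: the values of each (row, col) group, in their original order
series = toy.groupby(['row', 'col'])['value'].agg(list)
```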
    

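    You could plot a heatmap with sns.heatmap on the reshaped data (pivot_table, here taking the mean per group), then map the original data into heatmap coordinates and plot it as one line on top. In sns.heatmap, the cell for the i-th row and j-th column spans x in [j, j+1] and y in [i, i+1], with the y-axis pointing down, so each group only needs to be rescaled into its own cell: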

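    import matplotlib.pyplot as plt
    import pandas as pd
    import seaborn as sns
    
    # heatmap of the mean value per row/col combination
    ax = sns.heatmap(df.pivot_table(index='row', columns='col',
                                    values='value', aggfunc='mean'))
    
    margin = 0.1  # vertical padding of each line inside its cell
    
    def norm(s, margin=0):
        '''Normalizes the input Series between 0+margin and 1-margin'''
        MIN = s.min()
        return (s-MIN)/(s.max()-MIN)*(1-2*margin)+margin
    
    tmp = (df
           .sort_values(by=['row', 'col']) # ensure data is sorted
           # compute index/col position per group to match the heatmap
           # (sort=True so the codes follow the sorted labels, like the pivot table)
           .assign(row_id=lambda d: pd.factorize(d['row'], sort=True)[0],
                   col_id=lambda d: pd.factorize(d['col'], sort=True)[0],
                   # number the points within each group, scale to [0, 1] and shift per col
                   x=lambda d: (x:=d.groupby(['row_id', 'col_id']
                               ).cumcount())/x.max()+d['col_id'],
                   # normalize the data, flip it (the heatmap y-axis points down) and shift per row;
                   # the first point of each group is set to NaN to break the line between cells
                   norm_value=lambda d: (norm(d['value'], margin=margin).rsub(1)
                                         + d['row_id']
                                        ).mask(d[['row_id', 'col_id']]
                                               .ne(d[['row_id', 'col_id']].shift())
                                               .any(axis=1)),
                   )
          )
    
    tmp.plot(x='x', y='norm_value', ax=ax, legend=False)
    plt.show()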
    

    Example output:

    heatmap + lineplot for each cell
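    The single-line trick above relies on NaN separators, which costs the first point of every group. As an alternative, here is a loop-based sketch that calls ax.plot once per cell, so nothing is masked. To stay dependency-light it draws the heatmap with matplotlib's pcolormesh plus an inverted y-axis, which gives the same cell coordinates as sns.heatmap. The function name heatmap_with_sparklines, the white line color and the horizontal margin are assumptions of this sketch, not part of the answer above.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt


def heatmap_with_sparklines(df, margin=0.1, ax=None):
    """Hypothetical helper: mean 'value' per (row, col) as a heatmap, one line per cell.

    Assumes a long DataFrame with 'row', 'col' and 'value' columns in which
    the row order inside each (row, col) group is the order to plot.
    """
    if ax is None:
        _, ax = plt.subplots()
    means = df.pivot_table(index='row', columns='col',
                           values='value', aggfunc='mean')
    n_rows, n_cols = means.shape

    # same layout as sns.heatmap: cell (i, j) spans x in [j, j+1], y in [i, i+1]
    mesh = ax.pcolormesh(means.to_numpy())
    ax.set_xlim(0, n_cols)
    ax.set_ylim(0, n_rows)
    ax.invert_yaxis()  # first row at the top, so y grows downward
    ax.set_xticks(np.arange(n_cols) + 0.5)
    ax.set_xticklabels(means.columns)
    ax.set_yticks(np.arange(n_rows) + 0.5)
    ax.set_yticklabels(means.index)
    ax.figure.colorbar(mesh, ax=ax)

    # one global scale, so line amplitudes are comparable between cells
    vmin, vmax = df['value'].min(), df['value'].max()
    span = (vmax - vmin) or 1.0  # avoid dividing by zero for constant data
    row_pos = {r: i for i, r in enumerate(means.index)}
    col_pos = {c: j for j, c in enumerate(means.columns)}

    # one ax.plot call per cell, so no NaN separators are needed
    for (r, c), values in df.groupby(['row', 'col'])['value']:
        v = values.to_numpy()
        i, j = row_pos[r], col_pos[c]
        # spread the points evenly across the cell width, inside the margin
        x = j + margin + np.linspace(0, 1, len(v)) * (1 - 2 * margin)
        # scale into [margin, 1 - margin] of the cell height; flipped because y grows downward
        y = i + 1 - (margin + (v - vmin) / span * (1 - 2 * margin))
        ax.plot(x, y, color='white', linewidth=1)
    return ax
```

    With the reproducible input, heatmap_with_sparklines(df) followed by plt.show() should give a figure similar to the one above. The loop is slower than one vectorized plot call when there are many cells, but each line keeps all of its points.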

    Reproducible input:

    import numpy as np
    import pandas as pd
    from string import ascii_uppercase, ascii_lowercase
    
    R, C, N = 5, 6, 20  # number of rows, columns, and values per cell
    
    np.random.seed(0)
    a = np.arange(R*C*N)
    df = (pd.DataFrame({'row': np.array(list(ascii_uppercase))[a//(C*N)],
                        'col': np.array(list(ascii_lowercase))[a%(C*N)//N],
                        'value': 10*np.sin(a/N*5)+np.random.normal(scale=2, size=R*C*N),
                       })
            # random scale and offset for each (row, col) group
            .assign(value=lambda d: d.groupby(['row', 'col'])['value']
                    .transform(lambda s: s*np.random.uniform(0.1, 1)+np.random.uniform(-10, 10)))
            # uncomment to randomly drop ~30% of the rows (uneven groups)
            #.sample(frac=0.7).sort_index()
         )
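    Because this input gives each group its own random scale and offset, and norm uses a single min/max for the whole column, some lines come out much flatter than others. That keeps amplitudes comparable between cells. If each line should fill its cell's height instead, one option is to apply the same norm inside each group. The helper norm_per_cell below is hypothetical, a minimal sketch of that variant:

```python
import pandas as pd


def norm(s, margin=0):
    '''Normalizes the input Series between 0+margin and 1-margin'''
    MIN = s.min()
    return (s-MIN)/(s.max()-MIN)*(1-2*margin)+margin


def norm_per_cell(d, margin=0.1):
    """Hypothetical variant: apply norm separately to each (row, col) group.

    A group with a single or constant value gets NaN (0/0) and is not drawn.
    """
    return d.groupby(['row', 'col'])['value'].transform(norm, margin=margin)
```

    In the answer's chain, norm(d['value'], margin=margin) would then be replaced with norm_per_cell(d, margin=margin). The result is aligned on the same index, so the .rsub(1) and + d['row_id'] steps stay unchanged.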
    

    Alternative output when .sample(frac=0.7).sort_index() is uncommented (to simulate uneven groups):

    heatmap + lineplot for each cell
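    With uneven groups, the x computation divides each cumcount by x.max(), the size of the largest group minus one. Lines for smaller groups therefore stop before the right edge of their cell. If every line should span its whole cell, one option is to divide by each group's own size minus one. A minimal sketch, using a hypothetical helper spread_x that expects the row_id and col_id columns built in the chain above:

```python
import pandas as pd


def spread_x(d):
    """Hypothetical helper: x positions spanning [col_id, col_id + 1] for every group."""
    g = d.groupby(['row_id', 'col_id'])
    pos = g.cumcount()                           # 0, 1, ..., n-1 within each group
    last = pos + g.cumcount(ascending=False)     # n-1 on every row of the group
    # a group with a single point would divide by zero: put it at the left edge
    frac = (pos / last.where(last > 0)).fillna(0)
    return frac + d['col_id']
```

    Inside the .assign(...) call, the x=lambda d: ... entry would become x=lambda d: spread_x(d). Pandas evaluates the assign arguments in order, so row_id and col_id already exist at that point.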