Tags: python, seaborn, heatmap

Curve on top of a Seaborn heatmap


I'm trying to reproduce this graph:

[Image: the graph I'm trying to reproduce]

Here's my code:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# get_expected_max_SR and get_max_SR_distribution are defined in Chapter_8.py (linked below)

#trials_per_sim_list = np.logspace(1, 6, 1000).astype(int)
trials_per_sim_list = np.logspace(1, 5, 100).astype(int)
trials_per_sim_list.sort()

sharpe_ratio_theoretical = pd.Series({num_trials: get_expected_max_SR(num_trials, mean_SR = 0, std_SR = 1)
                                      for num_trials in trials_per_sim_list})
sharpe_ratio_theoretical = pd.DataFrame(sharpe_ratio_theoretical, columns = ['max{SR}'])
sharpe_ratio_theoretical.index.names = ['num_trials']

sharpe_ratio_sims = get_max_SR_distribution(
                                    #num_sims = 1e3,
                                    num_sims = 100,
                                    trials_per_sim_list = trials_per_sim_list,
                                    mean_SR = 0.0,
                                    std_SR = 1.0)

heatmap_df = sharpe_ratio_sims.copy()

# Count occurrences of each (num_trials, rounded max{SR}) pair
heatmap_df['count'] = 1
heatmap_df['max{SR}'] = heatmap_df['max{SR}'].round(3)
heatmap_df = heatmap_df.groupby(['num_trials', 'max{SR}']).count().reset_index()
heatmap_df = heatmap_df.pivot(index = 'max{SR}', columns = 'num_trials',
                              values = 'count')
heatmap_df = heatmap_df.fillna(0)

heatmap_df = heatmap_df.sort_index(ascending = False)

fig, ax = plt.subplots()

sns.heatmap(heatmap_df, cmap = 'Blues', ax = ax)
sns.lineplot(x = sharpe_ratio_theoretical.index,
             y = sharpe_ratio_theoretical['max{SR}'],
             linestyle = 'dashed', ax = ax)

plt.show()

I think the issue is that the heatmap is plotted on a log scale because I passed in log-spaced values, while my line plot isn't mapped onto the same values. My result so far is this:

[Image: my current result]
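One way to test that hypothesis is to inspect the axes that sns.heatmap draws on. The minimal sketch below uses a made-up 3 by 4 table (random values, not simulation output) in place of the real heatmap_df. As far as I can tell from seaborn's behaviour, the cells are laid out on an integer grid (x from 0 to the number of columns, y from 0 to the number of rows, with the y-axis inverted), and the column labels are only tick text. A line drawn in data units such as num_trials = 10,000 therefore lands somewhere unrelated to the cell it is meant to sit on.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Made-up stand-in for heatmap_df: 3 y-bins (rows, descending) by
# 4 log-spaced column labels. The values are random, not simulation output.
toy = pd.DataFrame(
    np.random.default_rng(0).random((3, 4)),
    index=[0.3, 0.2, 0.1],
    columns=[10, 100, 1000, 10000],
)

fig, ax = plt.subplots()
sns.heatmap(toy, ax=ax)

# sns.heatmap works in "cell units": cell (i, j) covers x in [j, j + 1]
# and y in [i, i + 1], with the y-axis inverted so row 0 is at the top.
print("x limits:", ax.get_xlim())
print("y limits:", ax.get_ylim())
# The labels 10 ... 10000 are only tick text placed at the cell centres.
print("x ticks:", ax.get_xticks())
```

If the x limits come out as 0 to 4 rather than 10 to 10,000, the line plot's x values (and likewise its y values) need converting to cell positions before they can line up with the cells.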

If you would like to see the code for these functions, it's here: https://github.com/charlesrambo/ML_for_asset_managers/blob/main/Chapter_8.py

Edit: No response so far. If it's the quant finance part that's confusing, here's a more straightforward example. I would like to add the graph of y = 1/sqrt(x) to my plot. Here's the code:

import numpy as np
import pandas as pd
import scipy.stats as stats
import matplotlib.pyplot as plt
import seaborn as sns

vals = np.logspace(0.5, 3.5, 100).astype(int)

# Theoretical standard deviation of a mean of n standard normals: 1/sqrt(n)
theoretical_values = pd.Series(1/np.sqrt(vals), index = vals)

num_runs = 10000
trials_per_run = 10

experimental_values = np.zeros(shape = (num_runs * len(vals), 2))

for i, n in enumerate(vals):

    for j in range(num_runs):

        # trials_per_run sample means, each over n standard-normal draws
        dist = stats.norm.rvs(size = (trials_per_run, n)).mean(axis = 1)

        experimental_values[num_runs * i + j, 0] = n
        experimental_values[num_runs * i + j, 1] = np.std(dist, ddof = 1)

experimental_values = pd.DataFrame(experimental_values, columns = ['num', 'std'])

heatmap_df = experimental_values.copy()
heatmap_df['count'] = 1
heatmap_df['std'] = heatmap_df['std'].round(3)

# Count occurrences of each (num, std) pair
heatmap_df = heatmap_df.groupby(['num', 'std'])['count'].sum().reset_index()

heatmap_df = heatmap_df.pivot(index = 'std', columns = 'num', values = 'count')

heatmap_df = heatmap_df.fillna(0)

# Normalize each column (each num) so that it sums to 1
heatmap_df = heatmap_df.div(heatmap_df.sum(axis = 0), axis = 1)

heatmap_df = heatmap_df.sort_index(ascending = False)

fig, ax = plt.subplots()

sns.heatmap(heatmap_df, cmap = 'Blues', ax = ax)

sns.lineplot(x = theoretical_values.index, y = theoretical_values, ax = ax)

plt.show()

I'm getting this:

[Image: resulting plot]
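The theoretical curve in this example rests on a standard fact: the mean of n independent N(0, 1) draws has variance 1/n, so its standard deviation is 1/sqrt(n). A quick NumPy sketch, separate from the plotting code (the sample sizes and seed are arbitrary choices), can confirm that the simulated values agree with that curve, which separates a data problem from a plotting problem:

```python
import numpy as np

rng = np.random.default_rng(42)
num_means = 10_000  # sample means per n (arbitrary)

results = {}
for n in [4, 25, 400]:
    # Each row holds n standard-normal draws; average across the row.
    means = rng.standard_normal((num_means, n)).mean(axis=1)
    results[n] = means.std(ddof=1)
    print(f"n={n:4d}  simulated sd={results[n]:.4f}  1/sqrt(n)={1 / np.sqrt(n):.4f}")
```

If the two columns agree, the data are fine and the mismatch in the figure comes from how the curve is placed on the heatmap's axes.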


Solution

  • I replicated the visual effect below using matplotlib rather than seaborn:

    [Image: replicated heatmap with curve overlay]

    To speed things up, I append values to a list rather than growing a DataFrame. To speed things up further, you could use one or a combination of the following:

    • Compute the inner-loop values in a single shot by adding an extra dimension to dist (this removes the need for the inner loop).
    • Pre-allocate a numpy array rather than appending to a list.
    • Use numba for easy parallelisation of loops.
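A minimal sketch of the first two suggestions combined, assuming the same (num, std) layout as the sample-data code in this answer. simulate_stds is a hypothetical helper name, not part of any library. The runs become a leading axis of the random draws, so only the loop over vals remains, and results go straight into a pre-allocated array. The trade-off is memory: each call to stats.norm.rvs allocates num_runs × trials_per_run × n floats, which for n = 10,000 and 1,000 runs is about 800 MB, so the largest n values may need to be processed in chunks of runs.

```python
import numpy as np
import scipy.stats as stats


def simulate_stds(vals, num_runs=1000, trials_per_run=10, seed=None):
    """Hypothetical helper: return a (len(vals) * num_runs, 2) array of (num, std) rows."""
    rng = np.random.default_rng(seed)
    # Pre-allocate the output instead of appending to a list
    out = np.empty((len(vals) * num_runs, 2))
    for i, n in enumerate(vals):
        # Extra leading axis for the runs: shape (num_runs, trials_per_run, n)
        draws = stats.norm.rvs(size=(num_runs, trials_per_run, n), random_state=rng)
        dist = draws.mean(axis=2)  # shape (num_runs, trials_per_run)
        rows = slice(i * num_runs, (i + 1) * num_runs)
        out[rows, 0] = n
        out[rows, 1] = dist.std(axis=1, ddof=1)  # one std per run
    return out


values = simulate_stds([10, 100], num_runs=200, seed=0)
```

Wrapping the result as pd.DataFrame(values, columns=['num', 'std']) gives the same two-column layout that the rest of the pipeline expects.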

    I do the pandas manipulations in a chained manner.

    Reproducible example

    Sample data (modified to append to a list rather than concatenate DataFrames):

    import numpy as np
    import pandas as pd
    import scipy.stats as stats
    import matplotlib.pyplot as plt
    
    #
    # Data for testing (using fewer vals compared to OP)
    #
    vals = np.logspace(1, 6 - 2, 100).astype(int)
    
    theoretical_values = pd.Series(1 / vals**0.5, index=vals)
    
    num_runs = 1000
    trials_per_run = 10
    
    #This list keeps track of (num, std) values
    experimental_values_list = []
    for idx, val in enumerate(vals):
        #Report progress (nb. the `tqdm` package gives you nicer progress bars)
        print(
            f'val {idx + 1} of {len(vals)} ({(idx + 1) / len(vals):.1%}) [val={val:,}]',
            ' ' * 100, end='\r'
        )
    
        for _ in range(num_runs):
            dist = stats.norm.rvs(size=(trials_per_run, val)).mean(axis=1)
            
            experimental_values_list.append(
                (val, dist.std(ddof=1))
            )
    
    #List of [(num, std), ...] to DataFrame 
    experimental_values_df = pd.DataFrame(
        experimental_values_list,
        columns=['num', 'std'],
    )
    
    #View dataframe
    experimental_values_df
    

    Manipulate the data into heatmap_df, a std-by-num table:

    #
    # Manipulate into an std-by-num table, akin to a heatmap
    #
    heatmap_df = (
        experimental_values_df
    
        #Round std to 3dp
        .assign(std_round=lambda df_: df_['std'].round(3))
        
        #Count occurrences of (num, std_round) pairs, and
        # unstack (pivot) num to the columns axis
        .value_counts(['std_round', 'num']).unstack().fillna(0)
        
        #Sort by std_round, largest first
        # Could equivalently remove this line and append .iloc[::-1] to line above
        .sort_index(ascending=False)
        
        #Divide each row by the row's sum
        .pipe(lambda df_: df_.div(df_.sum(axis=1), axis='index'))
    )
    display(heatmap_df)  # display() is available in Jupyter/IPython; use print() elsewhere
    
    #Optionally display a styled df for a quick visualisation
    (
        #Subsample for brevity
        heatmap_df.iloc[::30, ::7]
    
        #Style for coloured bolded text
        .style
        .format(precision=4)
        .text_gradient(cmap='magma')
        .set_table_styles([{'selector': 'td', 'props': 'font-weight: bold'}])
        .set_caption('subsampled table coloured by value')
    )
    

    Intermediate output is a styled dataframe for quick visualisation:

    [Image: subsampled styled table coloured by value]
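Note that the chain above divides each row (each std bin) by its row sum, whereas the OP's code divides each column (each num) by its column sum, so that every column is the distribution of std for that number of trials. Which one you want depends on what the colours should mean. A tiny hand-made example (the numbers are illustrative only) shows how the two choices differ on the same counts:

```python
import pandas as pd

# Five made-up (num, std) results
toy = pd.DataFrame({
    'num': [10, 10, 10, 100, 100],
    'std': [0.3161, 0.3164, 0.2904, 0.1002, 0.0998],
})

# Same counting step as in the chain above
counts = (
    toy
    .assign(std_round=lambda df_: df_['std'].round(3))
    .value_counts(['std_round', 'num']).unstack().fillna(0)
    .sort_index(ascending=False)
)

# Row-normalised (as in the chain above): each std bin sums to 1
by_row = counts.div(counts.sum(axis=1), axis='index')

# Column-normalised (as in the OP's code): each num sums to 1
by_col = counts.div(counts.sum(axis=0), axis='columns')

print(counts, by_row, by_col, sep='\n\n')
```

Both std bins under num = 10 get the value 1 in by_row, because each of those rows only ever occurs at num = 10, while by_col keeps the 2:1 split between them.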

    Heatmap using matplotlib, with a curve overlay:

    #
    # Heatmap and curve
    #
    
    f, ax = plt.subplots(figsize=(8, 3), layout='tight')
    
    mappable = ax.pcolormesh(
        heatmap_df.columns, heatmap_df.index, heatmap_df,
        cmap='inferno', vmax=0.07,
    )
    ax.set_ylim(heatmap_df.index.min(), 0.1)
    ax.set_xscale('log')
    
    #Title and axis labels
    ax.set_title('Dispersion across strategies', fontweight='bold')
    ax.set(xlabel='Number of trials', ylabel='standard deviation')
    
    #Add colorbar (and optionally adjust its positioning)
    f.colorbar(mappable, label='normalized counts', aspect=7, pad=0.02)
    
    #Overlay a curve through the mean
    curve = experimental_values_df.groupby('num')['std'].mean().rename('mean std')
    curve.plot(ax=ax, color='black', lw=2, legend=True)
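If you would rather keep sns.heatmap, the curve has to be converted into seaborn's cell coordinates first: cell j of the table spans x in [j, j + 1] and row i spans y in [i, i + 1] (with the y-axis inverted), so a label sits at its cell centre k + 0.5. Below is a minimal sketch under that assumption. to_cell_coords is a hypothetical helper that places each value by interpolating between the cell centres of the neighbouring labels (in log10 space for the log-spaced num axis). The table is synthetic, a ridge around 1/sqrt(num), and is used only so that the alignment is visible.

```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns


def to_cell_coords(values, labels, log=False):
    """Hypothetical helper: map data values onto sns.heatmap's cell axis.

    Cell k spans [k, k + 1], so label k sits at k + 0.5. Values between labels
    are linearly interpolated (in log10 space if log=True); values outside the
    label range are clamped to the first/last cell centre by np.interp.
    """
    labels = np.asarray(labels, dtype=float)
    values = np.asarray(values, dtype=float)
    if log:
        labels, values = np.log10(labels), np.log10(values)
    centres = np.arange(len(labels)) + 0.5
    if labels[0] > labels[-1]:  # np.interp needs increasing x-coordinates
        labels, centres = labels[::-1], centres[::-1]
    return np.interp(values, labels, centres)


# Synthetic stand-in for heatmap_df: std bins (descending) by log-spaced nums
nums = np.unique(np.logspace(1, 3, 20).astype(int))
stds = np.linspace(0.40, 0.01, 40).round(2)
S, N = np.meshgrid(stds, nums, indexing='ij')
heatmap_df = pd.DataFrame(np.exp(-((S - 1 / np.sqrt(N)) / 0.02) ** 2),
                          index=stds, columns=nums)

theory = pd.Series(1 / np.sqrt(nums), index=nums)

fig, ax = plt.subplots()
sns.heatmap(heatmap_df, cmap='Blues', ax=ax)
ax.plot(to_cell_coords(theory.index, heatmap_df.columns, log=True),
        to_cell_coords(theory.to_numpy(), heatmap_df.index),
        color='black', linestyle='--', label='1/sqrt(num)')
ax.legend()
```

With real data, pass your own heatmap_df and theoretical_values (or sharpe_ratio_theoretical['max{SR}'] from the question) in place of the synthetic table, and call plt.show() outside a notebook. The matplotlib approach above avoids this conversion entirely because pcolormesh works in data coordinates.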