Search code examples
Tags: python, pandas, matplotlib, latex, visualization

Presenting complex table data in a chart on a single slide


Tables are a good way to summarise complex information. I have a table similar to the following one (produced for this question) in my LaTeX document:

% Minimal document reproducing the classification-results table.
% NOTE(review): tabularx and graphicx are loaded but nothing below uses them
% (no tabularx environment, no \includegraphics); kept as in the original.
\documentclass{article}
\usepackage{graphicx} % Required for inserting images
\usepackage{tabularx}
\usepackage{booktabs}
\usepackage{makecell}

\begin{document}

\begin{table}[bt]
    \caption{Classification results.}
    \label{tab:baseline-clsf-reprt}
    \setlength{\tabcolsep}{1pt} % Adjust column spacing
    \renewcommand{\arraystretch}{1.2} % Adjust row height
    % One left-aligned label column, then 12 centred score columns:
    % 4 datasets x (Precision, Recall, F1).
    \begin{tabular}{lcccccccccccc}
        \toprule
        % Header row 1: dataset names, each spanning its three metric columns;
        % \makecell allows the two-line labels.
        & \multicolumn{3}{c}{Data1} &
          \multicolumn{3}{c}{\makecell{Data2 \\ (original)}} &
          \multicolumn{3}{c}{\makecell{Data2 \\ (experiment 3)}} &
          \multicolumn{3}{c}{\makecell{Data2 \\ (experiment 4)}} \\
        % Partial rules under each dataset group, trimmed by 1ex on the right
        % so adjacent groups stay visually separated.
        \cmidrule(r{1ex}){2-4}
        \cmidrule(r{1ex}){5-7}
        \cmidrule(r{1ex}){8-10}
        \cmidrule(r{1ex}){11-13}
        % Header row 2: metric names repeated for each dataset group.
        & Precision & Recall & F1 & Precision & Recall & F1 & Precision & Recall & F1 & Precision & Recall & F1 \\
        \midrule
        % NOTE(review): the scores are illustrative. Several F1 values are not the
        % harmonic mean of the listed precision/recall (e.g. Apple, Data2 original:
        % P=0.61, R=0.72 gives F1 of about 0.66, not 0.91).
        Apple  & 0.61 & 0.91 & 0.71 & 0.61 & 0.72 & 0.91 & 0.83   & 0.62 & 0.71 & 0.62 & 0.54 & 0.87 \\
        
        Banana  & 0.90 & 0.32 & 0.36 & 0.86 & 0.81 & 0.53 & 0.61 & 0.69 & 0.68 & 0.72 & 0.56 & 0.57 \\
        
        Orange   & 0.23 & 0.35 & 0.18 & 0.56 & 0.56 & 0.56 & 0.54 & 0.55 & 0.55 & 0.55 & 0.57 & 0.63 \\
        
        Grapes   & 0.81 & 0.70 & 0.76 & 0.67 & 0.47 & 0.54 & 0.85 & 0.28 & 0.42 & 0.38 & 0.66 & 0.48 \\
        
        Mango & 0.31 & 0.23 & 0.45 & 0.87 & 0.54 & 0.73 & 0.63 & 0.57 & 0.63 & 0.75 & 0.29 & 0.34 \\
        \bottomrule
    \end{tabular}
\end{table}

\end{document}

Which gives: [image: the rendered table]

Now I am preparing a slide deck and need to present the classification results on a single slide, showing the results of each dataset for each fruit and metric.

My attempts did not produce a meaningful chart (one that shows all the information in the table).

First attempt:

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

datasets = ['Data1', 'Data2-Orig', 'Data2-Exp3', 'Data2-Exp4']
fruits = ['Apple', 'Banana', 'Orange', 'Grapes', 'Mango']
metrics = ['Precision', 'Recall', 'F1']
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']  # Colors for Precision, Recall, F1

# Wide table: one row per fruit, one '<dataset>_<metric>' column per table cell.
data = {
    'Fruit': fruits,
    'Data1_Precision': [0.61, 0.90, 0.23, 0.81, 0.31],
    'Data1_Recall': [0.91, 0.32, 0.35, 0.70, 0.23],
    'Data1_F1': [0.71, 0.36, 0.18, 0.76, 0.45],
    'Data2-Orig_Precision': [0.61, 0.86, 0.56, 0.67, 0.87],
    'Data2-Orig_Recall': [0.72, 0.81, 0.56, 0.47, 0.54],
    'Data2-Orig_F1': [0.91, 0.53, 0.56, 0.54, 0.73],
    'Data2-Exp3_Precision': [0.83, 0.61, 0.54, 0.85, 0.63],
    'Data2-Exp3_Recall': [0.62, 0.69, 0.55, 0.28, 0.57],
    'Data2-Exp3_F1': [0.71, 0.68, 0.55, 0.42, 0.63],
    'Data2-Exp4_Precision': [0.62, 0.72, 0.55, 0.38, 0.75],
    'Data2-Exp4_Recall': [0.54, 0.56, 0.57, 0.66, 0.29],
    'Data2-Exp4_F1': [0.87, 0.57, 0.63, 0.48, 0.34]
}

df = pd.DataFrame(data)

# Reshape to long form (one row per fruit/column score) for Seaborn.
df_melted = df.melt(id_vars='Fruit',
                    var_name='Metric',
                    value_name='Score')

# Split '<dataset>_<metric>' into two grouping columns. '_' only appears as the
# dataset/metric separator, so a plain split yields exactly two parts.
df_melted[['Dataset', 'Measure']] = df_melted['Metric'].str.split('_', expand=True)
df_melted = df_melted.drop(columns='Metric')

plt.figure(figsize=(12, 8))
sns.set_style("whitegrid")

# Grouped bar plot.
# NOTE: every Fruit/Dataset pair has three rows (Precision, Recall, F1) and
# 'Measure' is not mapped to any visual property, so each bar shows barplot's
# default estimator (the mean of the three metrics). The individual metric
# values are therefore not visible in this chart.
sns.barplot(
    data=df_melted,
    x='Fruit',
    y='Score',
    hue='Dataset',
    errorbar=None,  # current spelling of the deprecated `ci=None`: no error bars
)

# Customize plot
plt.title('Classification Results by Fruit and Dataset')
plt.xlabel('Fruit type')
plt.ylabel('Score')
plt.legend(title='Dataset', bbox_to_anchor=(1.05, 1), loc='upper left')

# Show plot
plt.tight_layout()
plt.show()

Gives: [image: grouped bar chart from the first attempt]

Second attempt:

fig, ax = plt.subplots(figsize=(14, 8))

# Bar geometry. Each fruit gets one cluster made of len(datasets) groups, and
# each group holds one bar per metric.
bar_width = 0.2
group_spacing = 0.25   # gap between dataset groups inside a fruit cluster
cluster_spacing = 0.6  # gap between neighbouring fruit clusters
group_width = len(metrics) * bar_width
cluster_width = len(datasets) * group_width + (len(datasets) - 1) * group_spacing
# Space the clusters by their full width. With unit spacing
# (x = np.arange(len(fruits))) a cluster is wider than the distance to the next
# fruit (3.15 vs 1 with these settings), so bars of neighbouring fruits overlap
# and interleave.
x = np.arange(len(fruits)) * (cluster_width + cluster_spacing)

# Colour encodes the metric and the hatch pattern encodes the dataset, so every
# bar can be traced back to its table cell.
hatches = ['', '///', '...', 'xxx']

for i, dataset in enumerate(datasets):
    hatch = hatches[i % len(hatches)]
    group_offset = i * (group_width + group_spacing)
    for j, metric in enumerate(metrics):
        ax.bar(x + group_offset + j * bar_width,
               df[f'{dataset}_{metric}'],
               width=bar_width,
               color=colors[j],
               hatch=hatch,
               edgecolor='black',
               linewidth=0.5)

# Bars are centred on their positions, so a cluster spans
# [x - bar_width/2, x + cluster_width - bar_width/2]; put each tick in its middle.
ax.set_xticks(x + cluster_width / 2 - bar_width / 2)
ax.set_xticklabels(fruits)
ax.set_xlabel('Fruit type')
ax.set_ylabel('Score ')
ax.set_title('Classification Results by Dataset, Fruit, and Metric')

# Two proxy-artist legends: metric colours and dataset hatches. The first legend
# has to be re-attached with add_artist because a second ax.legend() call
# replaces the existing one.
metric_handles = [plt.Rectangle((0, 0), 1, 1, facecolor=colors[j])
                  for j in range(len(metrics))]
dataset_handles = [plt.Rectangle((0, 0), 1, 1, facecolor='white', edgecolor='black',
                                 hatch=hatches[i % len(hatches)])
                   for i in range(len(datasets))]
metric_legend = ax.legend(metric_handles, metrics, title="Metrics",
                          loc="upper left", bbox_to_anchor=(1.05, 1))
ax.add_artist(metric_legend)
ax.legend(dataset_handles, datasets, title="Datasets",
          loc="upper left", bbox_to_anchor=(1.05, 0.75))

plt.tight_layout()
plt.show()

This gives: [image: grouped bar chart from the second attempt]

None of these plots presents the results in a way that people can easily follow during a presentation. Simply adding the original table doesn't make sense either: people cannot easily follow the results in a table while I talk.

How would you recommend plotting the results in this table for adding to a slide?


Solution

  • I would definitely go for some kind of heatmap. Any barplot-like graphic would be cluttered.

    import pandas as pd
    import matplotlib.pyplot as plt
    import seaborn as sns
    
    data = {
        'Fruit': ['Apple', 'Banana', 'Orange', 'Grapes', 'Mango'],
        'Data1-Precision': [0.61, 0.90, 0.23, 0.81, 0.31],
        'Data1-Recall': [0.91, 0.32, 0.35, 0.70, 0.23],
        'Data1-F1': [0.71, 0.36, 0.18, 0.76, 0.45],
        'Data2-Orig-Precision': [0.61, 0.86, 0.56, 0.67, 0.87],
        'Data2-Orig-Recall': [0.72, 0.81, 0.56, 0.47, 0.54],
        'Data2-Orig-F1': [0.91, 0.53, 0.56, 0.54, 0.73],
        'Data2-Exp3-Precision': [0.83, 0.61, 0.54, 0.85, 0.63],
        'Data2-Exp3-Recall': [0.62, 0.69, 0.55, 0.28, 0.57],
        'Data2-Exp3-F1': [0.71, 0.68, 0.55, 0.42, 0.63],
        'Data2-Exp4-Precision': [0.62, 0.72, 0.55, 0.38, 0.75],
        'Data2-Exp4-Recall': [0.54, 0.56, 0.57, 0.66, 0.29],
        'Data2-Exp4-F1': [0.87, 0.57, 0.63, 0.48, 0.34]
    }

    df = pd.DataFrame(data)

    # Heatmap matrix: one row per fruit, one column per table cell.
    # Column names are '<dataset>-<metric>' and dataset names contain '-'
    # themselves, so split on the LAST '-' only to get a (Dataset, Metric)
    # MultiIndex.
    # Building the matrix directly from `df` keeps the row/column order of the
    # LaTeX table. melt + pivot would sort fruits, datasets and metrics
    # alphabetically (Grapes before Orange, Exp3/Exp4 before Orig, F1 before
    # Precision), which makes the slide hard to read alongside the table.
    heatmap_data = df.set_index('Fruit')
    heatmap_data.columns = heatmap_data.columns.str.rsplit('-', n=1, expand=True)
    heatmap_data.columns.names = ['Dataset', 'Metric']
    n_metrics = heatmap_data.columns.get_level_values('Metric').nunique()

    plt.figure(figsize=(14, 8))
    ax = sns.heatmap(
        heatmap_data,
        annot=True,
        fmt=".2f",
        cmap="YlGnBu",
        # Precision, recall and F1 all lie in [0, 1]: pin the colour scale to that
        # range rather than the observed min/max, so colours are comparable.
        vmin=0,
        vmax=1,
        linewidths=0.5,
        cbar_kws={'label': 'Score'}
    )
    # Thicker separators between dataset groups (n_metrics columns per group).
    for boundary in range(n_metrics, heatmap_data.shape[1], n_metrics):
        ax.axvline(boundary, color='white', linewidth=3)
    plt.title('Classification Results Heatmap')
    plt.xlabel('Dataset and Metric')
    plt.ylabel('Fruit')
    plt.tight_layout()
    plt.show()
    
    
    

    which gives

    [image: the resulting heatmap]

    But if you absolutely want to stick to bar plots, you can draw them in 3D:

    import pandas as pd
    import numpy as np
    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    
    data = {
        'Fruit': ['Apple', 'Banana', 'Orange', 'Grapes', 'Mango'],
        'Data1-Precision': [0.61, 0.90, 0.23, 0.81, 0.31],
        'Data1-Recall': [0.91, 0.32, 0.35, 0.70, 0.23],
        'Data1-F1': [0.71, 0.36, 0.18, 0.76, 0.45],
        'Data2-Orig-Precision': [0.61, 0.86, 0.56, 0.67, 0.87],
        'Data2-Orig-Recall': [0.72, 0.81, 0.56, 0.47, 0.54],
        'Data2-Orig-F1': [0.91, 0.53, 0.56, 0.54, 0.73],
        'Data2-Exp3-Precision': [0.83, 0.61, 0.54, 0.85, 0.63],
        'Data2-Exp3-Recall': [0.62, 0.69, 0.55, 0.28, 0.57],
        'Data2-Exp3-F1': [0.71, 0.68, 0.55, 0.42, 0.63],
        'Data2-Exp4-Precision': [0.62, 0.72, 0.55, 0.38, 0.75],
        'Data2-Exp4-Recall': [0.54, 0.56, 0.57, 0.66, 0.29],
        'Data2-Exp4-F1': [0.87, 0.57, 0.63, 0.48, 0.34]
    }

    df = pd.DataFrame(data)

    # Long form: one row per (Fruit, Dataset, Metric) score. The greedy first
    # group makes the split happen at the LAST '-', so 'Data2-Orig' stays whole.
    df_melted = df.melt(id_vars='Fruit', var_name='Dataset-Metric', value_name='Score')
    df_melted[['Dataset', 'Metric']] = df_melted['Dataset-Metric'].str.extract(r'(.+)-(.+)')

    # Integer code per category, numbered in order of first appearance
    # (i.e. table order), plus the matching category labels.
    fruit_idx, fruits = pd.factorize(df_melted['Fruit'])
    dataset_idx, datasets = pd.factorize(df_melted['Dataset'])
    metric_idx, metrics = pd.factorize(df_melted['Metric'])

    scores = df_melted['Score'].to_numpy()

    fig = plt.figure(figsize=(12, 8))
    ax = fig.add_subplot(111, projection='3d')

    # bar3d anchors each bar at its (x, y) corner. Give every metric its own
    # slot along the fruit axis. Without the offset, Precision, Recall and F1 of
    # the same fruit/dataset are drawn at the same spot and hide each other.
    dx = 0.25
    dy = 0.4
    x = fruit_idx + metric_idx * dx
    y = dataset_idx

    # Colour by metric (not by score) so the three bars can be told apart.
    metric_colors = plt.cm.tab10(np.arange(len(metrics)))
    ax.bar3d(x, y, np.zeros_like(scores), dx, dy, scores,
             color=metric_colors[metric_idx], alpha=0.8)

    ax.set_xlabel('Fruit')
    ax.set_ylabel('Dataset')
    ax.set_zlabel('Score')

    # Centre the ticks on the bars: fruit k spans [k, k + len(metrics) * dx],
    # dataset j spans [j, j + dy].
    ax.set_xticks(np.arange(len(fruits)) + len(metrics) * dx / 2)
    ax.set_xticklabels(fruits, rotation=45)
    ax.set_yticks(np.arange(len(datasets)) + dy / 2)
    ax.set_yticklabels(datasets)
    ax.set_zticks(np.linspace(0, 1, 6))

    legend_handles = [plt.Rectangle((0, 0), 1, 1, color=c) for c in metric_colors]
    ax.legend(legend_handles, metrics, title='Metric')

    plt.title('3D Bar Plot of Classification Results')
    plt.tight_layout()
    plt.show()
    

    which gives

    [image: the resulting 3D bar plot]

    BUT, I still think a heatmap is more readable.