Tags: python, pandas, matplotlib, scatter-plot, boxplot

matplotlib boxplot doesn't align with overlaid scatterplot


I'm trying to overlay a scatter series on a boxplot. Here is a minimal example that reproduces the problem.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

names = ['a','b','c','d','e','f']

df = pd.DataFrame(np.random.rand(6,6), columns=names)
display(df)  # display() is provided by Jupyter/IPython; use print(df) in a plain script

plt.boxplot(df, labels=names)
plt.show()

plt.scatter(names, df.head(1))
plt.show()

plt.boxplot(df, labels=names)
plt.scatter(names, df.head(1))
plt.show()

Results:

[Three figures, one per plt.show() call: the boxplot alone, the scatter alone, and the two overlaid]

As you can see, when both the boxplot and the scatter are drawn on the same figure, the points no longer line up with the labelled boxes. How can I fix this alignment?


Solution

    • Tested in python 3.8.11, pandas 1.3.2, matplotlib 3.4.3, seaborn 0.11.2
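    • Notice that the xtick locations differ between the two plots, which is why the overlay is misaligned.
    • As per matplotlib.pyplot.boxplot, the positions parameter defaults to range(1, N+1), whereas the scatter against string labels puts its ticks at 0 to N-1 (see the output below).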
    fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(6, 8))
    ax1.boxplot(df, labels=names)
    print(ax1.get_xticks())
    ax2.scatter(names, df.head(1))
    print(ax2.get_xticks())
    
    ax3.boxplot(df, labels=names)
    ax3.scatter(names, df.head(1))
    [out]:
    [1 2 3 4 5 6]
    [0, 1, 2, 3, 4, 5]
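
The mismatch can also be checked numerically rather than by eye. The sketch below is an illustration and not part of the original answer: it assumes a seeded random dataframe, selects the non-interactive Agg backend so it runs without a display, and uses the illustrative names box_x and scatter_x. It reads the box positions from the median lines returned by boxplot and the scatter positions from the collection's offsets.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

names = ['a', 'b', 'c', 'd', 'e', 'f']
rng = np.random.default_rng(0)  # seeded so the example is repeatable
df = pd.DataFrame(rng.random((6, 6)), columns=names)

fig, ax = plt.subplots()

# Boxes are drawn at the default positions 1..N.
result = ax.boxplot(df.to_numpy())
# Each median is a short horizontal line centred on its box position.
box_x = np.array([line.get_xdata().mean() for line in result['medians']])

# String x-values go through matplotlib's categorical axis, which maps
# each distinct label to an integer starting at 0.
points = ax.scatter(names, df.iloc[0])
scatter_x = np.asarray(points.get_offsets(), dtype=float)[:, 0]

print(box_x)      # x positions of the boxes
print(scatter_x)  # x positions of the scatter points
```

The two arrays differ by one at every index, so each point lands one slot to the left of its box.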
    


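    • A correct solution, given the existing code, is to set the positions parameter to range(len(df.columns)), so the boxes sit at 0 to N-1 and match the categorical axis used by the scatter.
    • Plotting every value as a scatter also requires converting the dataframe to long form with pandas.DataFrame.melt.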
    plt.boxplot(df, labels=names, positions=range(len(df.columns)))
    plt.scatter(data=df.melt(), x='variable', y='value')
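
If only the first row should be overlaid, as with df.head(1) in the question, melting is not needed. A minimal sketch, assuming a seeded random dataframe and the Agg backend so it runs headless: it passes the same numeric positions to both calls and sets the tick labels explicitly instead of relying on the categorical axis. The names positions and points are illustrative and do not come from the original answer.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend, so no display is needed
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

names = ['a', 'b', 'c', 'd', 'e', 'f']
rng = np.random.default_rng(0)  # seeded so the example is repeatable
df = pd.DataFrame(rng.random((6, 6)), columns=names)

# One shared set of x positions (0..N-1) for boxes and points.
positions = np.arange(len(df.columns))

fig, ax = plt.subplots()
ax.boxplot(df.to_numpy(), positions=positions)

# Overlay only the first row, one point per column, at the same positions.
points = ax.scatter(positions, df.iloc[0], color='tab:red', zorder=3)

# Label the ticks with the column names.
ax.set_xticks(positions)
ax.set_xticklabels(names)

# Call plt.show() with an interactive backend, or fig.savefig(...), to view the result.
```

Because both artists use explicit numeric positions, the alignment does not depend on the order in which boxplot and scatter are called.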
    

    • The same fix with pandas.DataFrame.plot, which also accepts positions for box plots:

    ax = df.plot(kind='box', positions=range(len(df.columns)))
    df.melt().plot(kind='scatter', x='variable', y='value', ax=ax)
    

    • With seaborn, both boxplot and swarmplot place the categories at 0 to N-1, so the two layers align without any adjustment:

    import seaborn as sns
    
    sns.boxplot(data=df, boxprops={'facecolor':'None'})
    print(plt.xticks())
    sns.swarmplot(data=df)
    print(plt.xticks())
    
    [out]:
    (array([0, 1, 2, 3, 4, 5]), [Text(0, 0, 'a'), Text(1, 0, 'b'), Text(2, 0, 'c'), Text(3, 0, 'd'), Text(4, 0, 'e'), Text(5, 0, 'f')])
    (array([0, 1, 2, 3, 4, 5]), [Text(0, 0, 'a'), Text(1, 0, 'b'), Text(2, 0, 'c'), Text(3, 0, 'd'), Text(4, 0, 'e'), Text(5, 0, 'f')])
    
