
Python matplotlib / Seaborn stripplot with connection between points


I'm using Python 3 and Seaborn to make categorical stripplots (see code and image below).

Each city has 2 data points (one for each gender).

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns


df = [["city2", "f", 300],
    ["city2", "m", 39],
    ["city1", "f", 95],
    ["city1", "m", 53]]

df = pd.DataFrame(df, columns = ["city", "gender", "variable"])

sns.stripplot(data=df,x='city',hue='gender',y='variable', size=10, linewidth=1)

I get the following output:

[image: the resulting stripplot]

However, I would like a line segment connecting the male and female points of each city, so that the figure looks like the one below, where I drew the red lines by hand. Is there a simple way to do this with Seaborn or matplotlib? Thank you!

[image: the same stripplot with hand-drawn red lines joining each city's f and m points]
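Because each city contributes exactly one f row and one m row, pivoting the DataFrame puts every pair on a single row, which is the shape needed to connect them. A minimal pandas sketch using the data from the question; the reindex step is an addition that restores the order of first appearance (the left-to-right order stripplot uses for string categories), since pivot sorts the index alphabetically:

```python
import pandas as pd

df = pd.DataFrame(
    [["city2", "f", 300],
     ["city2", "m", 39],
     ["city1", "f", 95],
     ["city1", "m", 53]],
    columns=["city", "gender", "variable"],
)

# One row per city, one column per gender: each row is one f/m pair.
# pivot sorts the index, so reindex to the order of first appearance,
# which matches the order stripplot draws the cities in by default.
pairs = df.pivot(index="city", columns="gender", values="variable")
pairs = pairs.reindex(df["city"].unique())
print(pairs)
```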


Solution

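  • You can group the values by city with pandas.DataFrame.groupby, turn each city's f and m values into one line segment at that city's x position, and draw all the segments at once with a matplotlib LineCollection: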

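    import matplotlib.pyplot as plt
    from matplotlib import collections as mc
    import pandas as pd
    import seaborn as sns


    df = [["city2", "f", 300],
          ["city2", "m", 39],
          ["city1", "f", 95],
          ["city1", "m", 53],
          ["city4", "f", 200],
          ["city3", "f", 100],
          ["city4", "m", 236],
          ["city3", "m", 20]]

    df = pd.DataFrame(df, columns=["city", "gender", "variable"])

    # jitter=False keeps both points of a city exactly on its x position,
    # so the segment endpoints line up with the markers.
    ax = sns.stripplot(data=df, x='city', hue='gender', y='variable',
                       jitter=False, size=10, linewidth=1)

    # stripplot puts the i-th city (in order of first appearance) at x = i;
    # groupby(sort=False) yields the cities in that same order.
    lines = [[(x, y) for y in group]
             for x, (_, group) in enumerate(df.groupby('city', sort=False)['variable'])]
    # zorder=0 draws the segments underneath the markers.
    lc = mc.LineCollection(lines, colors='red', linewidths=2, zorder=0)
    ax.add_collection(lc)

    plt.show()  # sns.plt no longer exists in current seaborn versions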
    

    Output:

    [image: stripplot with a red line joining the f and m points of each city]
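Because both points of a city share the same x position (with jitter=False and stripplot's default dodge=False), every connecting segment is vertical, so matplotlib's Axes.vlines can draw them directly instead of building a LineCollection by hand. A minimal sketch under those assumptions; passing order= to stripplot is an addition here that pins city i to x = i explicitly rather than relying on groupby(sort=False) matching seaborn's default ordering:

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(
    [["city2", "f", 300], ["city2", "m", 39],
     ["city1", "f", 95], ["city1", "m", 53],
     ["city4", "f", 200], ["city3", "f", 100],
     ["city4", "m", 236], ["city3", "m", 20]],
    columns=["city", "gender", "variable"],
)

# Pin the category order so that city i is drawn at x = i.
order = list(df["city"].unique())

fig, ax = plt.subplots()
sns.stripplot(data=df, x="city", y="variable", hue="gender", order=order,
              jitter=False, size=10, linewidth=1, ax=ax)

# Both points of a city share one x, so each segment runs from the
# city's smaller value to its larger value.
by_city = df.groupby("city")["variable"]
ymin = by_city.min().reindex(order).to_numpy()
ymax = by_city.max().reindex(order).to_numpy()
lines = ax.vlines(np.arange(len(order)), ymin, ymax,
                  colors="red", linewidth=2, zorder=0)

plt.show()
```

Using min and max per city means the segment does not depend on which gender has the larger value; with more than two points per city this would draw one line spanning all of them rather than separate pairwise segments.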