
Plotting colored lines connecting individual data points of two swarmplots


I have:

import numpy as np
import pandas as pd
import seaborn as sb
import matplotlib.pyplot as plt

# Generate random data
set1 = np.random.randint(0, 40, 24)
set2 = np.random.randint(0, 100, 24)

# Put into dataframe and plot
df = pd.DataFrame({'set1': set1, 'set2': set2})
data = pd.melt(df)
sb.swarmplot(data=data, x='variable', y='value')

[Figure: the two random distributions plotted with seaborn's swarmplot function]

I want the individual points of both distributions to be connected with a colored line, such that the first data point of set1 in the dataframe is connected to the first data point of set2, and so on. I realize that this would probably be relatively simple without seaborn, but I want to keep the feature that the individual data points do not overlap. Is there any way to access the coordinates of the individual points plotted by seaborn's swarmplot function?


Solution

  • EDIT: Thanks to @Mead, who pointed out a bug in my post prior to 2021-08-23 (I forgot to sort the locations in the prior version).

    I gave the nice answer by Paul Brodersen a try, and despite him saying that

    Madness lies this way

    ... I actually think it's pretty straightforward and yields nice results:

    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    # Generate random data
    rng = np.random.default_rng(42)
    set1 = rng.integers(0, 40, 5)
    set2 = rng.integers(0, 100, 5)
    
    # Put into dataframe
    df = pd.DataFrame({"set1": set1, "set2": set2})
    print(df)
    data = pd.melt(df)
    
    # Plot
    fig, ax = plt.subplots()
    sns.swarmplot(data=data, x="variable", y="value", ax=ax)
    
    # Now connect the dots
    # Find idx0 and idx1 by inspecting the elements returned by ax.get_children()
    # ... or find a way to automate it
    idx0 = 0
    idx1 = 1
    locs1 = ax.get_children()[idx0].get_offsets()
    locs2 = ax.get_children()[idx1].get_offsets()
    
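    # Before plotting, we need to sort so that the data points
    # correspond to each other as they did in "set1" and "set2".
    # This assumes seaborn stores each swarm's offsets sorted by value,
    # as seaborn 0.11 did. Newer seaborn (0.12+) appears to keep the
    # original data order, in which case locs1[i] and locs2[i] already
    # correspond and this re-sorting step should be skipped.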
    sort_idxs1 = np.argsort(set1)
    sort_idxs2 = np.argsort(set2)
    
    # revert "ascending sort" through sort_idxs2.argsort(),
    # and then sort into order corresponding with set1
    locs2_sorted = locs2[sort_idxs2.argsort()][sort_idxs1]
    
    for i in range(locs1.shape[0]):
        x = [locs1[i, 0], locs2_sorted[i, 0]]
        y = [locs1[i, 1], locs2_sorted[i, 1]]
        ax.plot(x, y, color="black", alpha=0.1)
    

    It prints:

       set1  set2
    0     3    85
    1    30     8
    2    26    69
    3    17    20
    4    17     9
    

    The plot shows that each point of set1 is linked to its corresponding point of set2.
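    The index arithmetic in the re-sorting step is the least obvious part. Below is a minimal NumPy-only sketch of it, using the values printed above and assuming, as the answer does, that each swarm's offsets come back sorted by value. The "offsets" here are simulated rather than read from a real plot. The key fact is that `np.argsort` applied to a permutation returns the inverse permutation.

```python
import numpy as np

# Values from the printed dataframe above
set1 = np.array([3, 30, 26, 17, 17])
set2 = np.array([85, 8, 69, 20, 9])

# Simulated swarm offsets: column 0 is x (category position), column 1 is y
# (the value), with rows stored in ascending order of value (the assumption)
sort_idxs1 = np.argsort(set1)
sort_idxs2 = np.argsort(set2)
locs1 = np.column_stack([np.zeros(len(set1)), set1[sort_idxs1]])
locs2 = np.column_stack([np.ones(len(set2)), set2[sort_idxs2]])

# argsort of a permutation is its inverse, so this undoes the ascending
# sort and puts locs2 back into the original row order of the dataframe
locs2_original = locs2[sort_idxs2.argsort()]

# Reorder into locs1's (sorted-by-set1) order: row i of locs1 and
# row i of locs2_sorted now come from the same dataframe row
locs2_sorted = locs2_original[sort_idxs1]

# Each (set1, set2) pair that a connecting line would join
pairs = [(int(y1), int(y2)) for y1, y2 in zip(locs1[:, 1], locs2_sorted[:, 1])]
print(pairs)
```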

    [Figure: the swarmplot with each point of set1 connected by a line to its corresponding point of set2]

    UPDATE

    If you don't want to find the indices into ax.get_children() manually or by some other means, you can use the function below instead; for the present example, call it as locs1, locs2 = find_locs(ax, 2, 5).

    def find_locs(ax, ncols, ndots):
        """Find objects in axes corresponding to plotted dots.
    
        Parameters
        ----------
        ax : plt.Axes
            The axes of the plot.
        ncols : int
            The number of swarmplot (or stripplot) columns in the plot.
        ndots : int
            The number of dots per column in the plot.
    
        Returns
        -------
        locs : list of np.ndarray
            `locs` is of length `ncols`, with each np.ndarray in `locs` corresponding
            to a column in the plot. Each np.ndarray is of shape (`ndots`, 2),
            corresponding to the (x, y) offset of each dot in that column.
        """
        # see also https://stackoverflow.com/a/63171175/5201771
        locs = []
        for child in ax.get_children():
            try:
                offsets = child.get_offsets()
            except AttributeError:
                continue
    
            _r, _c = offsets.shape  # _c is 2 for "x" and "y"
            if _c == 2 and _r == ndots:
                locs.append(offsets)
    
        if len(locs) == ncols:
            return locs
        else:
            raise RuntimeError("Encountered problems identifying dots in plot.")
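    As an alternative sketch (not the answer's method), the dots can be read from matplotlib's `Axes.collections` list, which contains only collection artists, and all connecting lines can be drawn as a single `matplotlib.collections.LineCollection` instead of calling `ax.plot` once per pair. `connect_columns` is a hypothetical helper name. It assumes the two offset arrays are already in corresponding row order (for example after the re-sorting described in the answer), and that the swarms are the first two collections on the axes. The demo uses two plain `ax.scatter` calls as a stand-in for the swarmplot, so it runs without seaborn.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import LineCollection


def connect_columns(ax, locs1, locs2, **line_kws):
    """Hypothetical helper: draw one segment from locs1[i] to locs2[i]."""
    # segments has shape (n, 2, 2): n lines, each with two (x, y) endpoints
    segments = np.stack([locs1, locs2], axis=1)
    lines = LineCollection(segments, **line_kws)
    ax.add_collection(lines)
    return lines


# Stand-in for the swarmplot: two scatter columns at x=0 and x=1
rng = np.random.default_rng(42)
set1 = rng.integers(0, 40, 5)
set2 = rng.integers(0, 100, 5)
fig, ax = plt.subplots()
ax.scatter(np.zeros(5), set1)
ax.scatter(np.ones(5), set2)

# ax.collections lists the PathCollections in the order they were added;
# np.asarray drops the masked-array wrapper that scatter offsets may carry
locs1 = np.asarray(ax.collections[0].get_offsets())
locs2 = np.asarray(ax.collections[1].get_offsets())
lines = connect_columns(ax, locs1, locs2, colors="black", alpha=0.1)
```

    With a real swarmplot, `locs1` and `locs2` would come from `find_locs` (or the `ax.get_children()` indices) instead, and `alpha` keeps overlapping lines readable as in the original loop.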