Search code examples
python matplotlib visualization

How to create a plot in Matplotlib that looks like a swarmplot but with overlapping points?


I'm trying to create a (sort of) swarmplot — that is, a plot that clearly shows the shapes of the distributions but allows quick plotting of tens of thousands of datapoints by letting the dots that represent them overlap. Such as this:

enter image description here

My idea is to essentially create a dot plot, but divide each distribution into segments (equal-width bins between its minimum and maximum) and apply jitter to the horizontal position of each datapoint, with a magnitude proportional to the number of points in that datapoint's segment. This works fine when the distributions are of the same size, but I need some way of scaling the jitter so that when one of the distributions has only a few datapoints, the dots representing them are arranged on a (nearly) vertical line, i.e. NOT like below:

enter image description here

Here is my code for plot creation:

import matplotlib.pyplot as plt
import numpy as np


def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                            number_of_segments=12,
                            separation_between_plots=0.1,
                            separation_between_subplots=0.1,
                            vertical_limits=None,
                            grid=False,
                            remove_outlier_above_segment=None,
                            remove_outlier_below_segment=None,
                            y_label=None,
                            title=None):
    """Show each distribution as a column of horizontally jittered dots.

    The value range of every distribution is cut into ``number_of_segments``
    equal-width segments. A point's horizontal jitter is drawn uniformly from
    an interval whose half-width is proportional to how many points share its
    segment, so the column is widest where the data is densest.

    NOTE: the jitter is normalised by the fullest segment of *the same*
    distribution, so every column reaches the full ``max_plot_width``
    regardless of how many points it holds.

    Args:
        distributions: one 1-D numpy array of values per column.
        tick_labels: x-axis label for each column.
        max_plot_width: full width of a column at its fullest segment.
        alpha: opacity passed to ``ax.scatter``.
        number_of_segments: number of equal-width segments per distribution.
        separation_between_plots: horizontal gap between neighbouring columns.
        separation_between_subplots: accepted but not used by this function.
        vertical_limits: optional ``(bottom, top)`` pair for the y-axis.
        grid: when falsy, the axes grid is switched off.
        remove_outlier_above_segment: optional per-distribution segment index;
            points in a higher segment are dropped.
        remove_outlier_below_segment: optional per-distribution value ``b``;
            points in a segment below ``b - 1`` are dropped.
        y_label: optional y-axis label.
        title: optional axes title.
    """
    fig, ax = plt.subplots()

    number_of_plots = len(distributions)
    slot_width = max_plot_width + separation_between_plots
    ax.set_xlim(left=0, right=number_of_plots * slot_width + separation_between_plots)

    # Horizontal centre of every column.
    ticks = [separation_between_plots + max_plot_width / 2 + slot_width * plot_no
             for plot_no in range(0, number_of_plots)]
    print(ticks)

    for plot_no, column_center in enumerate(ticks):
        samples = distributions[plot_no]

        # Interior boundaries of the equal-width segments (outer two dropped).
        lowest, highest = np.min(samples), np.max(samples)
        inner_edges = np.linspace(lowest, highest, number_of_segments + 1)[1:-1]

        # A point's segment index follows from how many interior boundaries
        # lie at or above it: the minimum lands in 0, the maximum in the last.
        edges_at_or_above = (inner_edges[:, None] >= samples[None, :]).sum(axis=0)
        bin_of_sample = number_of_segments - 1 - edges_at_or_above

        if remove_outlier_above_segment:
            keep = bin_of_sample <= remove_outlier_above_segment[plot_no]
            samples, bin_of_sample = samples[keep], bin_of_sample[keep]

        if remove_outlier_below_segment:
            keep = bin_of_sample >= remove_outlier_below_segment[plot_no] - 1
            samples, bin_of_sample = samples[keep], bin_of_sample[keep]

        # Points per segment, with zeros for segments that hold no point.
        occupied_bins, occupancy = np.unique(bin_of_sample, return_counts=True)
        occupancy_by_bin = dict(zip(occupied_bins.tolist(), occupancy.tolist()))
        counts_filled = [occupancy_by_bin.get(segment_no, 0) for segment_no in range(number_of_segments)]

        # Jitter half-width per segment, relative to this distribution's fullest segment.
        half_widths = (max_plot_width / 2) * (np.asarray(counts_filled) / np.max(occupancy))

        unit_jitter = np.random.uniform(-1, 1, len(samples))
        jitter = np.take(half_widths, bin_of_sample) * unit_jitter
        ax.scatter(jitter + column_center, samples, alpha=alpha)

    ax.set_xticks(ticks)
    ax.set_xticklabels(tick_labels)
    if vertical_limits:
        ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
    if not grid:
        ax.grid(False)
    if y_label:
        ax.set_ylabel(y_label)
    if title:
        ax.set_title(title)
    plt.show()

And the code to recreate the second chart above:

# Reproducible toy data: three normal samples of very different sizes
# (4, 10 and 1000 points), drawn in this order from the seeded generator.
np.random.seed(0)
sample_specs = [(0, 2, 4), (1, 1, 10), (2, 3, 1000)]  # (mean, std, size)
distro1, distr2, distr3 = (np.random.normal(*spec) for spec in sample_specs)

distributions = [distro1, distr2, distr3]
fancy_distribution_plot(
    distributions,
    tick_labels=['distro1', 'distro2', 'distro3'],
    number_of_segments=100,
    grid=False,
)


Solution

  • Expanding on my comment, you could scale the variance (and thus the jitter) by dividing by the max count amongst all of the distributions, rather than by each distribution's own max count.

    A possible implementation (starting from your function) is:

    import matplotlib.pyplot as plt
    import numpy as np
    
    
    def fancy_distribution_plot(distributions: list, tick_labels: list, max_plot_width: int = 1, alpha=0.7,
                                number_of_segments=12,
                                separation_between_plots=0.1,
                                separation_between_subplots=0.1,
                                vertical_limits=None,
                                grid=False,
                                remove_outlier_above_segment=None,
                                remove_outlier_below_segment=None,
                                y_label=None,
                                title=None):
        """Show each distribution as a column of horizontally jittered dots.

        The value range of every distribution is cut into ``number_of_segments``
        equal-width segments. A point's horizontal jitter is drawn uniformly
        from an interval whose half-width is proportional to the number of
        points in its segment, divided by the largest segment count found in
        *any* distribution. Small distributions therefore stay close to a
        vertical line while large ones spread out to ``max_plot_width``.

        Args:
            distributions: one 1-D array-like of values per column.
            tick_labels: x-axis label for each column.
            max_plot_width: full width of the column at the globally fullest segment.
            alpha: opacity passed to ``ax.scatter``.
            number_of_segments: number of equal-width segments per distribution.
            separation_between_plots: horizontal gap between neighbouring columns.
            separation_between_subplots: accepted but not used by this function.
            vertical_limits: optional ``(bottom, top)`` pair for the y-axis.
            grid: when falsy, the axes grid is switched off.
            remove_outlier_above_segment: optional per-distribution segment
                index; points in a higher segment are dropped.
            remove_outlier_below_segment: optional per-distribution value ``b``;
                points in a segment below ``b - 1`` are dropped.
            y_label: optional y-axis label.
            title: optional axes title.
        """
        fig, ax = plt.subplots()

        number_of_plots = len(distributions)

        ax.set_xlim(left=0, right=number_of_plots * (max_plot_width + separation_between_plots) + separation_between_plots)

        # Horizontal centre of every column.
        ticks = [separation_between_plots + max_plot_width / 2 + (max_plot_width + separation_between_plots) * i
                 for i in range(0, number_of_plots)]

        # First pass: segment every distribution. The global maximum segment
        # count is only known once all of them have been processed, so the
        # per-distribution results are kept for the plotting pass.
        kept_distributions = []
        segment_indices_list = []
        counts_list = []
        for i, distribution in enumerate(distributions):
            distribution = np.asarray(distribution)

            # Interior boundaries of the equal-width segments.
            segments = np.linspace(np.min(distribution), np.max(distribution), number_of_segments + 1)[1:-1]

            # Segment index from the number of interior boundaries at or above
            # each point: the minimum lands in 0, the maximum in the last one.
            segment_indices = number_of_segments - 1 - (segments[:, None] >= distribution[None, :]).sum(axis=0)

            if remove_outlier_above_segment:
                keep = segment_indices <= remove_outlier_above_segment[i]
                distribution = distribution[keep]
                segment_indices = segment_indices[keep]

            if remove_outlier_below_segment:
                keep = segment_indices >= remove_outlier_below_segment[i] - 1
                distribution = distribution[keep]
                segment_indices = segment_indices[keep]

            # Keep the *filtered* values: plotting the unfiltered input would
            # mismatch the segment indices and re-draw the removed outliers.
            kept_distributions.append(distribution)
            segment_indices_list.append(segment_indices)
            # Points per segment, zero-filled for empty segments.
            counts_list.append(np.bincount(segment_indices, minlength=number_of_segments))

        max_counts = max((counts.max() for counts in counts_list if counts.size), default=0)

        # Second pass: scale the jitter by the global maximum and draw.
        for tick, distribution, segment_indices, counts in zip(
                ticks, kept_distributions, segment_indices_list, counts_list):
            if max_counts:
                variances = (max_plot_width / 2) * (counts / max_counts)
            else:
                # Nothing left to plot anywhere; avoid dividing by zero.
                variances = np.zeros(len(counts))
            jitter_unadjusted = np.random.uniform(-1, 1, len(distribution))
            jitter = np.take(variances, segment_indices) * jitter_unadjusted

            ax.scatter(jitter + tick, distribution, alpha=alpha)

        ax.set_xticks(ticks)
        ax.set_xticklabels(tick_labels)
        if vertical_limits:
            ax.set_ylim(bottom=vertical_limits[0], top=vertical_limits[1])
        if not grid:
            ax.grid(False)
        if y_label:
            ax.set_ylabel(y_label)
        if title:
            ax.set_title(title)
        plt.show()
    

    That from the data in your toy example gives

    Swarmplots

    The code is quite messy, and duplicating the for loops is neither very elegant nor efficient; I hope at least the result is what you were looking for.