Tags: python, matplotlib, jupyter-notebook, seaborn, jupyter

Change marker style by a dataframe column (categorical) in seaborn stripplot


I want to show a categorical variable as the marker style in a seaborn stripplot, but there does not seem to be an easy way to do it. Is there one? For example, I'm trying to run this code:

import seaborn as sns

tips = sns.load_dataset("tips")
sns.stripplot(x="day", y="total_bill", hue="time", style="sex", jitter=True, data=tips)

This fails because stripplot has no style parameter. An alternative is relplot, which does provide the option but has no way to add jitter, so the points within each category overlap and the visualisation is less readable.

sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")

This works and produces the plot below, without jitter:

[Image: relplot output]

Is there any way of doing this using stripplot/catplot/swarmplot?

EDIT: This question is related. However the solution there does not seem to allow generation of a legend for size (and is quite dated).


Solution

  • sns.relplot is a figure-level function that relies on the axes-level function sns.scatterplot. sns.scatterplot has an x_jitter parameter, which unfortunately has no effect as of seaborn 0.11.2.

    You can mimic the functionality by retrieving the positions of the points, adding some random jitter, and assigning the new positions back to the points.

    Here is an example:

    from matplotlib import pyplot as plt
    import seaborn as sns
    import numpy as np
    
    tips = sns.load_dataset("tips")
    ax = sns.scatterplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
    for points in ax.collections:
        vertices = points.get_offsets().data
        if len(vertices) > 0:
            vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
            points.set_offsets(vertices)
    xticks = ax.get_xticks()
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5) # the limits need to be moved to show all the jittered dots
    sns.move_legend(ax, bbox_to_anchor=(1.01, 1.02), loc='upper left')  # needs seaborn 0.11.2
    sns.despine()
    plt.tight_layout()
    plt.show()
    

    [Image: sns.scatterplot with jitter]

    With sns.relplot you could iterate through all the subplots:

    g = sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
    for ax in g.axes.flat:
        for points in ax.collections:
            vertices = points.get_offsets().data
            if len(vertices) > 0:
                vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
                points.set_offsets(vertices)
        xticks = ax.get_xticks()
        ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5) # the limits need to be moved to show all the jittered dots
    plt.show()
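
The jitter loop appears in both snippets above, so it could be pulled into a small helper. The following is only a sketch: `jitter_x`, `width` and `seed` are hypothetical names and not part of seaborn or matplotlib. The demo uses a plain matplotlib categorical scatter with made-up values so that it runs without downloading the tips dataset. It also uses a seeded NumPy generator, which is an addition to the approach above, so that the jitter is reproducible between runs.

```python
import numpy as np
import matplotlib.pyplot as plt


def jitter_x(ax, width=0.3, seed=None):
    """Shift every scatter point on `ax` horizontally by a uniform random offset.

    Hypothetical helper: `width` is the maximum shift in data units and
    `seed` makes the random offsets reproducible.
    """
    rng = np.random.default_rng(seed)
    for points in ax.collections:
        # Work on a plain float copy of the (possibly masked) offsets array.
        offsets = np.array(points.get_offsets(), dtype=float)
        if len(offsets) > 0:
            offsets[:, 0] += rng.uniform(-width, width, len(offsets))
            points.set_offsets(offsets)
    # Widen the x-limits so that the shifted points are not clipped.
    xticks = ax.get_xticks()
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)


# Small demo with made-up values; categorical x positions are 0, 1, 2.
fig, ax = plt.subplots()
ax.scatter(["Thur", "Fri", "Sat", "Thur", "Fri", "Sat"], [10, 12, 15, 11, 13, 16])
jitter_x(ax, width=0.3, seed=0)
```

With seaborn, the same helper should apply to the axes returned by sns.scatterplot, or to each axes in g.axes.flat for sns.relplot, since both snippets above manipulate ax.collections in the same way.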