I was looking to visualise a categorical variable as marker style in a seaborn stripplot, but it does not seem to be easily possible. Is there an easy way to do this? For example, I'm trying to run this code:
tips = sns.load_dataset("tips")
sns.stripplot(x="day", y="total_bill", hue="time", style="sex", jitter=True, data=tips)
which fails, since stripplot has no style parameter. An alternative is relplot, which does provide the option but cannot add jitter, so the visualisation is less clear:
sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
works, but all points for a given day fall on a single vertical line.
Is there any way of doing this using stripplot/catplot/swarmplot?
EDIT: This question is related. However, the solution there does not seem to allow generating a legend for size (and it is quite dated).
sns.relplot is a figure-level function which relies on the axes-level function sns.scatterplot. sns.scatterplot has a parameter x_jitter, which unfortunately currently has no effect (seaborn 0.11.2).
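Because x_jitter currently does nothing, one alternative that is not part of the original answer is to jitter the data itself before plotting. The sketch below assumes the x variable is categorical: it maps each category to an integer position with pd.Categorical, adds uniform noise, plots the resulting numeric column with sns.scatterplot (so hue and style, and their legend, still work), and then puts the category names back on the ticks. The helper name jittered_scatter and the temporary column _x_jit are hypothetical.

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import seaborn as sns


def jittered_scatter(data, x, y, hue=None, style=None, width=0.3, seed=None, ax=None):
    """Hypothetical helper: scatterplot with x-jitter applied to the data itself."""
    rng = np.random.default_rng(seed)  # pass a seed for a reproducible layout
    # Map each category to an integer position (0, 1, 2, ...); a column that is
    # already categorical keeps its own category order.
    cats = pd.Categorical(data[x])
    positions = cats.codes.astype(float) + rng.uniform(-width, width, len(data))
    plot_data = data.assign(_x_jit=positions)
    # scatterplot supports both hue and style, so both appear in the legend.
    ax = sns.scatterplot(data=plot_data, x="_x_jit", y=y, hue=hue, style=style, ax=ax)
    # Put the category names back on the integer tick positions.
    ax.set_xticks(range(len(cats.categories)))
    ax.set_xticklabels(list(cats.categories))
    ax.set_xlabel(x)
    return ax
```

With the tips data this would be called as jittered_scatter(tips, "day", "total_bill", hue="time", style="sex", seed=0). Since the noise is added before plotting, the axis limits are autoscaled to include the jittered points, so no manual set_xlim call is needed.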
You can mimic the missing functionality after plotting by retrieving the positions of the points, adding some random jitter, and assigning the new positions back.
Here is an example:
from matplotlib import pyplot as plt
import seaborn as sns
import numpy as np
tips = sns.load_dataset("tips")
ax = sns.scatterplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
# Shift every non-empty point collection horizontally by a random amount;
# empty collections (legend proxies in some seaborn versions) are skipped.
for points in ax.collections:
    vertices = np.array(points.get_offsets(), dtype=float)  # plain float copy of the (x, y) positions
    if len(vertices) > 0:
        vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
        points.set_offsets(vertices)
xticks = ax.get_xticks()
ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)  # the limits need to be moved to show all the jittered dots
sns.move_legend(ax, bbox_to_anchor=(1.01, 1.02), loc='upper left')  # requires seaborn >= 0.11.2
sns.despine()
plt.tight_layout()
plt.show()
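The loop above can be packaged as a reusable function, which is convenient when the same jitter has to be applied to several Axes. This is a minimal sketch, not from the original answer: the name jitter_collections and its width, margin and seed parameters are hypothetical. One design difference is that it recomputes the x-limits from the jittered points rather than from the tick positions, so it also works on a plain numeric axis.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt


def jitter_collections(ax, width=0.3, margin=0.2, seed=None):
    """Hypothetical helper: shift every non-empty collection on ax horizontally.

    Each point's x position moves by a value drawn from [-width, width).
    Empty collections (e.g. legend proxies some seaborn versions add) are skipped.
    """
    rng = np.random.default_rng(seed)  # pass a seed for a reproducible layout
    jittered_x = []
    for points in ax.collections:
        # Work on a plain float copy of the (x, y) offsets.
        offsets = np.array(points.get_offsets(), dtype=float)
        if len(offsets) == 0:
            continue
        offsets[:, 0] += rng.uniform(-width, width, len(offsets))
        points.set_offsets(offsets)
        jittered_x.append(offsets[:, 0])
    if jittered_x:
        # Recompute the x-limits from the jittered data so no point is clipped.
        all_x = np.concatenate(jittered_x)
        ax.set_xlim(all_x.min() - margin, all_x.max() + margin)
    return ax
```

With the tips example it would replace the loop and the set_xlim call: jitter_collections(sns.scatterplot(x="day", y="total_bill", hue="time", data=tips, style="sex"), seed=0).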
With sns.relplot you could iterate through all the subplots:
g = sns.relplot(x="day", y="total_bill", hue="time", data=tips, style="sex")
for ax in g.axes.flat:
    for points in ax.collections:
        vertices = np.array(points.get_offsets(), dtype=float)  # plain float copy of the (x, y) positions
        if len(vertices) > 0:
            vertices[:, 0] += np.random.uniform(-0.3, 0.3, vertices.shape[0])
            points.set_offsets(vertices)
    xticks = ax.get_xticks()
    ax.set_xlim(xticks[0] - 0.5, xticks[-1] + 0.5)  # the limits need to be moved to show all the jittered dots
plt.show()
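The question also asks whether stripplot itself can do this. stripplot has no style parameter, but one workaround, not taken from the original answer, is to overlay one stripplot per level of the style variable on the same Axes, passing a different marker each time (stripplot forwards extra keyword arguments such as marker to matplotlib's Axes.scatter), and then build a combined legend by hand. In this sketch the helper name stripplot_with_style and the default marker list are hypothetical, and the legend colours assume the default seaborn palette.

```python
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import seaborn as sns


def stripplot_with_style(data, x, y, hue, style, markers=("o", "X", "s", "^", "D"),
                         jitter=0.3, ax=None):
    """Hypothetical helper: one stripplot per style level, each with its own marker."""
    if ax is None:
        ax = plt.gca()
    # Fix the category orders up front so every overlaid stripplot uses the
    # same x positions and the same hue colours.
    order = list(pd.Categorical(data[x]).categories)
    hue_order = list(pd.Categorical(data[hue]).categories)
    style_levels = list(pd.Categorical(data[style]).categories)
    if len(style_levels) > len(markers):
        raise ValueError("need at least one marker per style level")
    palette = sns.color_palette(n_colors=len(hue_order))
    for level, marker in zip(style_levels, markers):
        subset = data[data[style] == level]
        sns.stripplot(data=subset, x=x, y=y, hue=hue, order=order,
                      hue_order=hue_order, palette=palette, marker=marker,
                      jitter=jitter, ax=ax)
    # Each call may add its own hue legend; replace it with one legend that
    # shows colours for the hue levels and markers for the style levels.
    handles = [Line2D([], [], marker="o", linestyle="", color=c, label=str(h))
               for h, c in zip(hue_order, palette)]
    handles += [Line2D([], [], marker=m, linestyle="", color="gray", label=str(s))
                for s, m in zip(style_levels, markers)]
    ax.legend(handles=handles, bbox_to_anchor=(1.01, 1), loc="upper left")
    return ax
```

With the tips data this would be stripplot_with_style(tips, "day", "total_bill", hue="time", style="sex"). Because the two sexes are drawn in separate calls, their points are jittered independently but share the same column for each day.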