I am trying to reproduce stripplots exactly so that I can draw lines and write on them reliably. However, when I produce a stripplot with jitter, the jitter is random and prevents me from achieving my goal.
I have blindly tried some rcParams I found in other Stack Overflow posts, such as mpl.rcParams['svg.hashsalt'], which hasn't worked. I also tried setting a seed with random.seed(), without success.
The code I am running looks like the following.
import seaborn as sns
import matplotlib.pyplot as plt
import random

plt.figure(figsize=(14, 9))
random.seed(123)

categories = []
values = []
for i in range(0, 200):
    n = random.randint(1, 3)
    categories.append(n)
for i in range(0, 200):
    n = random.randint(1, 100)
    values.append(n)

# Newer seaborn versions expect the data variables as keyword arguments
sns.stripplot(x=categories, y=values, size=5)
plt.title('Random Jitter')
plt.xticks([0, 1, 2], [1, 2, 3])
plt.show()
This code generates a stripplot just like I want. But if you run the code twice, you will get different placements for the points, due to the jitter. The graph I am making requires jitter to not look ridiculous, but I want to write on the graph. However, there is no way to know the exact positions of the points before running the code, and they change every time the code is run.
Is there any way to set a seed for the jitter in seaborn stripplots to make them perfectly reproducible?
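The jitter is drawn from scipy.stats.uniform. uniform is an instance of class uniform_gen(scipy.stats._distn_infrastructure.rv_continuous), which in turn is a subclass of class rv_continuous(rv_generic). rv_continuous has a seed parameter and, when no seed is given, draws from np.random. Therefore, seed NumPy's global generator with np.random.seed() before plotting. It has to be reseeded before each plot, so in the examples below np.random.seed(123) must be inside the loop.

From the stripplot docstring:
jitter : float, ``True``/``1`` is special-cased, optional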
Amount of jitter (only along the categorical axis) to apply. This
can be useful when you have many points and they overlap, so that
it is easier to see the distribution. You can specify the amount
of jitter (half the width of the uniform random variable support),
or just use ``True`` for a good default.
In class _StripPlotter in seaborn's categorical.py, the jitter function is built from scipy.stats.uniform:
from scipy import stats

class _StripPlotter(_CategoricalScatterPlotter):
    """1-d scatterplot with categorical organization."""
    def __init__(self, x, y, hue, data, order, hue_order,
                 jitter, dodge, orient, color, palette):
        """Initialize the plotter."""
        self.establish_variables(x, y, hue, data, orient, order, hue_order)
        self.establish_colors(color, palette, 1)

        # Set object attributes
        self.dodge = dodge
        self.width = .8

        if jitter == 1:  # Use a good default for `jitter = True`
            jlim = 0.1
        else:
            jlim = float(jitter)
        if self.hue_names is not None and dodge:
            jlim /= len(self.hue_names)
        self.jitterer = stats.uniform(-jlim, jlim * 2).rvs
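A minimal sketch of the mechanism, rebuilding only the jitterer line from the excerpt with SciPy alone (assuming jitter=True, so jlim = 0.1). Because rvs is called without a random_state, it falls back to NumPy's global RandomState, so reseeding that generator replays exactly the same jitter.

```python
import numpy as np
from scipy import stats

# Same construction as the _StripPlotter excerpt, with jitter=True -> jlim = 0.1
jlim = 0.1
jitterer = stats.uniform(-jlim, jlim * 2).rvs

# No random_state is passed, so rvs draws from NumPy's global RandomState
np.random.seed(123)
first = jitterer(size=5)

np.random.seed(123)
second = jitterer(size=5)

print(first)
print(np.array_equal(first, second))  # True: same seed, same jitter
```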
From the rv_continuous docstring:
seed : {None, int, `~np.random.RandomState`, `~np.random.Generator`}, optional
This parameter defines the object to use for drawing random variates.
If `seed` is `None` the `~np.random.RandomState` singleton is used.
If `seed` is an int, a new ``RandomState`` instance is used, seeded
with seed.
If `seed` is already a ``RandomState`` or ``Generator`` instance,
then that object is used.
Default is None.
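The same seed semantics are also available per call through the random_state argument of rvs. Seaborn's stripplot does not expose this (the excerpt above calls rvs without it), so the sketch below only illustrates the docstring; dist, a, b, c and d are illustrative names.

```python
import numpy as np
from scipy import stats

jlim = 0.1
dist = stats.uniform(-jlim, jlim * 2)

# An int creates a fresh RandomState seeded with that int, so these draws
# neither depend on nor modify NumPy's global state
a = dist.rvs(size=5, random_state=123)
np.random.seed(0)  # change the global state; it is not used below
b = dist.rvs(size=5, random_state=123)

# A Generator instance is accepted too and is used as-is
c = dist.rvs(size=5, random_state=np.random.default_rng(123))
d = dist.rvs(size=5, random_state=np.random.default_rng(123))

print(np.array_equal(a, b), np.array_equal(c, d))  # True True
```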
Using np.random.seed for both the data and the plot, all six panels are identical:
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):
    # Reseed NumPy's global generator on every iteration; the data and the
    # jitter are both drawn from it, so every panel comes out the same
    np.random.seed(123)
    categories = []
    values = []
    for i in range(0, 200):
        n = np.random.randint(1, 4)  # upper bound is exclusive: 1, 2 or 3
        categories.append(n)
    for i in range(0, 200):
        n = np.random.randint(1, 101)  # 1..100, like random.randint(1, 100)
        values.append(n)
    row = x // 3
    col = x % 3
    axcurr = axes[row, col]
    sns.stripplot(x=categories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()
Using only random.seed, the data is the same in every panel but the jitter changes:
import seaborn as sns
import matplotlib.pyplot as plt
import random

fig, axes = plt.subplots(2, 3, figsize=(12, 12))
for x in range(6):
    # Seeds Python's random module only; NumPy's generator is untouched
    random.seed(123)
    categories = []
    values = []
    for i in range(0, 200):
        n = random.randint(1, 3)
        categories.append(n)
    for i in range(0, 200):
        n = random.randint(1, 100)
        values.append(n)
    row = x // 3
    col = x % 3
    axcurr = axes[row, col]
    sns.stripplot(x=categories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'random jitter {x+1}')
plt.show()
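The reason random.seed has no effect on the jitter is that Python's random module and NumPy's global generator keep separate states. A small sketch; NumPy is seeded once at the top only so that the demo is deterministic.

```python
import random
import numpy as np

np.random.seed(0)  # seeded once, only to make this demo deterministic

# Seed Python's random module and draw once from each generator
random.seed(123)
py_first = random.random()
np_first = np.random.random()

# Reseeding random replays random's sequence...
random.seed(123)
py_second = random.random()
np_second = np.random.random()

print(py_first == py_second)  # True
# ...but NumPy's generator was not reseeded, so it simply moved on
print(np_first == np_second)  # False
```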
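Using random for the data and np.random.seed for the plot, the six panels are identical again:
import random
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(12, 12))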
for x in range(6):
    random.seed(123)
    categories = []
    values = []
    for i in range(0, 200):
        n = random.randint(1, 3)
        categories.append(n)
    for i in range(0, 200):
        n = random.randint(1, 100)
        values.append(n)
    row = x // 3
    col = x % 3
    axcurr = axes[row, col]
    # Reseed NumPy's global generator right before plotting so the jitter repeats
    np.random.seed(123)
    sns.stripplot(x=categories, y=values, size=5, ax=axcurr)
    axcurr.set_title(f'np.random jitter {x+1}')
plt.show()
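To draw and write on the points, the seeded plot can be wrapped in a helper and the exact coordinates read back from the Axes. seeded_stripplot and point_positions are hypothetical helper names. The sketch assumes, as in the excerpt above, that seaborn draws its jitter from NumPy's global generator, and that a fresh Axes holds only the strip plot's scatter collections. The data comes from a separate np.random.default_rng Generator, so generating it does not advance the global state that the jitter uses.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; not needed in a notebook
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def seeded_stripplot(x, y, seed=123, **kwargs):
    """Hypothetical helper: reseed NumPy's global generator, then draw a stripplot."""
    np.random.seed(seed)
    return sns.stripplot(x=x, y=y, **kwargs)


def point_positions(ax):
    """Hypothetical helper: (x, y) data coordinates of every scatter point on ax."""
    return np.concatenate([np.asarray(c.get_offsets()) for c in ax.collections])


# Data from an independent Generator, so it does not consume the global state
rng = np.random.default_rng(0)
categories = rng.integers(1, 4, size=200)  # 1, 2 or 3
values = rng.integers(1, 101, size=200)    # 1..100

fig, (ax1, ax2) = plt.subplots(1, 2)
pos1 = point_positions(seeded_stripplot(categories, values, size=5, ax=ax1))
pos2 = point_positions(seeded_stripplot(categories, values, size=5, ax=ax2))

print(np.array_equal(pos1, pos2))  # True

# The recovered coordinates can be used to annotate a specific point
ax1.annotate("first point", xy=pos1[0], xytext=(10, 10), textcoords="offset points")
```

The only requirement for repeatable positions is that np.random.seed is called immediately before sns.stripplot, with nothing in between drawing from NumPy's global generator.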