Tags: python, matplotlib, seaborn, jitter

How can I add jitter to my seaborn and matplotlib plots?


I am trying to add jitter to my plots using seaborn and matplotlib, and I am getting mixed information from what I read online. Some sources say that extra coding is needed, while others show it as being as simple as jitter = True. Is there another library or something else that I should be importing that I am not aware of? Below is the code that I am running and trying to add jitter to:

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
headcount_df = pd.read_csv(filename)
headcount_df.describe()

%matplotlib inline
ax = plt.figure(figsize=(12, 6)).gca() # define axis
headcount_df.plot.scatter(x = 'Hour', y = 'TablesOpen', ax = ax, alpha = 0.2)
# auto_price.plot(kind = 'scatter', x = 'city-mpg', y = 'price', ax = ax)
ax.set_title('Hour vs TablesOpen') # Give the plot a main title
ax.set_ylabel('TablesOpen')# Set text for y axis
ax.set_xlabel('Hour')

ax = sns.kdeplot(headcount_df.loc[:, ['TablesOpen', 'Hour']], shade = True, cmap = 'PuBu')
headcount_df.plot.scatter(x = 'Hour', y = 'TablesOpen', ax = ax, jitter = True)
ax.set_title('Hour vs TablesOpen') # Give the plot a main title
ax.set_ylabel('TablesOpen')# Set text for y axis
ax.set_xlabel('Hour')

When I try to add the jitter, I receive this error: AttributeError: 'PathCollection' object has no property 'jitter'. Any help or further information on this would be much appreciated.


Solution

  • To add jitter to a scatter plot, first get a handle to the collection that contains the scatter dots. Right after a scatter plot has been created on an ax, ax.collections[-1] is that collection.

    Calling get_offsets() on the collection returns the xy coordinates of all the dots. Add a small random number to each of them. Because all coordinates are integers in this case, adding a uniform random number between 0 and 1 spreads the dots evenly over each unit cell.

    In this case the number of dots is very large. To better see where the dots are concentrated, they can be made very small (marker=',', linewidth=0, s=1) and very transparent (e.g. alpha=0.1).

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    
    filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
    headcount_df = pd.read_csv(filename)
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    headcount_df.plot.scatter(x='Hour', y='TablesOpen', marker=',', linewidth=0, s=1, alpha=.1, color='crimson', ax=ax)
    dots = ax.collections[-1]
    offsets = dots.get_offsets()
    jittered_offsets = offsets + np.random.uniform(0, 1, offsets.shape)
    dots.set_offsets(jittered_offsets)
    
    ax.set_title('Hour vs TablesOpen')  # Give the plot a main title
    ax.set_ylabel('TablesOpen')  # Set text for y axis
    ax.set_xlabel('Hour')
    ax.set_xticks(range(25))
    ax.autoscale(enable=True, tight=True)
    
    plt.tight_layout()
    plt.show()
    

    scatter plot with jitter
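
As an alternative to editing the offsets after plotting, the noise can be added to the data before it is plotted. The following is only a minimal sketch under stated assumptions: add_jitter is a hypothetical helper (not part of matplotlib, pandas or seaborn), and demo_df is synthetic stand-in data with integer Hour and TablesOpen columns so that the example runs without downloading the CSV.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd


def add_jitter(values, width=1.0, rng=None):
    """Hypothetical helper: return values plus uniform noise drawn from [0, width)."""
    rng = np.random.default_rng() if rng is None else rng
    values = np.asarray(values, dtype=float)
    return values + rng.uniform(0, width, size=values.shape)


# Synthetic stand-in for the real data (assumption): integer-valued columns
rng = np.random.default_rng(42)
demo_df = pd.DataFrame({
    "Hour": rng.integers(0, 24, size=5000),
    "TablesOpen": rng.integers(0, 25, size=5000),
})

fig, ax = plt.subplots(figsize=(12, 6))
# Jitter both coordinates first, then draw small, transparent dots
ax.scatter(
    add_jitter(demo_df["Hour"], rng=rng),
    add_jitter(demo_df["TablesOpen"], rng=rng),
    marker=",", linewidths=0, s=1, alpha=0.3, color="crimson",
)
ax.set_title("Hour vs TablesOpen")
ax.set_xlabel("Hour")
ax.set_ylabel("TablesOpen")
ax.set_xticks(range(25))
plt.tight_layout()
```

With the real data, demo_df would be replaced by headcount_df. The width argument controls how far the dots spread; a width of 1 fills the whole unit cell for integer data.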

    Because there are so many points, drawing the 2D kde takes a long time. Taking a random sample of the rows reduces that time. Note that to draw a 2D kde, recent versions of Seaborn expect each column to be passed as a separate argument (x and y).

    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    import seaborn as sns
    
    filename = 'https://library.startlearninglabs.uw.edu/DATASCI410/Datasets/JitteredHeadCount.csv'
    headcount_df = pd.read_csv(filename)
    
    fig, ax = plt.subplots(figsize=(12, 6))
    
    N = 5000
    rand_sel_df = headcount_df.iloc[np.random.choice(range(len(headcount_df)), N)]
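    # recent seaborn versions need x= and y= as keywords; fill=True replaces the deprecated shade=True
    ax = sns.kdeplot(x=rand_sel_df['Hour'], y=rand_sel_df['TablesOpen'], fill=True, cmap='PuBu', ax=ax)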
    
    ax.set_title('Hour vs TablesOpen')
    ax.set_xticks(range(25))
    
    plt.tight_layout()
    plt.show()
    

    kdeplot
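
On the jitter = True seen online: that keyword most likely comes from seaborn's categorical plots such as sns.stripplot, which accept a jitter parameter. pandas' plot.scatter forwards unknown keywords to matplotlib's scatter, which has no such property, hence the AttributeError. The sketch below shows stripplot on synthetic stand-in data (demo_df is an assumption, not the real CSV). Note that stripplot treats Hour as a categorical axis and jitters only along that axis, so TablesOpen stays on integer values.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for scripts; omit this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Synthetic stand-in for the real data (assumption): integer-valued columns
rng = np.random.default_rng(0)
demo_df = pd.DataFrame({
    "Hour": rng.integers(0, 24, size=2000),
    "TablesOpen": rng.integers(0, 25, size=2000),
})

fig, ax = plt.subplots(figsize=(12, 6))
# jitter=0.4 spreads the dots up to 0.4 category widths to either side of each hour
sns.stripplot(
    data=demo_df, x="Hour", y="TablesOpen", orient="v",
    jitter=0.4, size=2, alpha=0.3, color="crimson", ax=ax,
)
ax.set_title("Hour vs TablesOpen")
plt.tight_layout()
```

If jitter is needed along both axes, as in the scatter plot above, the random offsets still have to be added manually.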