Search code examples
python matplotlib scatter-plot pad

Scatter plot does not properly adjust the plot range in matplotlib


I am plotting two gaussians (one centered at 0 and the other at 100) with plt.plot and plt.scatter in matplotlib version 2.2.3. For some reason, the subplot does not automatically adjust the plot range for the second curve in the scatter plot.

Of course I can do it manually --in this simple case-- but actually what I have is a big grid and I don't want to set the range one by one.

Why is this happening? Is there any way to fix it?

This is my code:

import numpy as np
import matplotlib.pyplot as plt

# First sample: 10000 draws with mean 0 and standard deviation 1, reduced to
# a density-normalised histogram.
mu1, sigma1 = 0, 1
x1 = mu1 + sigma1 * np.random.randn(10000)
hist1, bins1 = np.histogram(x1, bins='auto', density=True)
# Midpoint of every bin; used as the x coordinate for the histogram values.
center1 = (bins1[:-1] + bins1[1:]) / 2

# Second sample: same procedure with mean 100 and standard deviation 15.
mu2, sigma2 = 100, 15
x2 = mu2 + sigma2 * np.random.randn(10000)
hist2, bins2 = np.histogram(x2, bins='auto', density=True)
center2 = (bins2[:-1] + bins2[1:]) / 2

# 2x2 grid: top row shows the first sample, bottom row the second; the left
# column uses plt.plot and the right column plt.scatter on the same data.
# Each plt.subplot call selects the axes that the following calls draw into.
plt.subplot(2, 2, 1)
plt.plot(center1, hist1)
plt.text(2, 0.27, 'plot\n$\\mu$ = 0 \n$\\sigma$ = 1')
plt.subplot(2, 2, 2)
plt.scatter(center1, hist1)
plt.text(2, 0.27, 'scatter\n$\\mu$ = 0 \n$\\sigma$ = 1')
plt.subplot(2, 2, 3)
plt.plot(center2, hist2)
plt.text(127, 0.02, 'plot\n$\\mu$ = 100 \n$\\sigma$ = 15')
# This is the panel whose plot range is reported as not adjusting automatically.
plt.subplot(2, 2, 4)
plt.scatter(center2, hist2)
plt.text(127, 0.02, 'scatter\n$\\mu$ = 100 \n$\\sigma$ = 15')

plt.show()

So the output is: grid_gaussians

I'd be glad if someone could help with that, thanks in advance. Any answer or comment would be appreciated.


Solution

  • Autoscaling of collections (scatter produces a PathCollection) is still an unsolved problem, although there are ideas for workarounds being discussed.

    A strange hacky solution in the case of the example above is to add an empty plot, plt.plot() to the axes before creating the scatter.

    import numpy as np
    import matplotlib.pyplot as plt
    
    # First sample: 10000 draws with mean 0 and standard deviation 1, reduced
    # to a density-normalised histogram.
    mu1, sigma1 = 0, 1
    x1 = mu1 + sigma1 * np.random.randn(10000)
    hist1, bins1 = np.histogram(x1, bins='auto', density=True)
    # Midpoint of every bin; used as the x coordinate for the histogram values.
    center1 = (bins1[:-1] + bins1[1:]) / 2
    
    # Second sample: same procedure with mean 100 and standard deviation 15.
    mu2, sigma2 = 100, 15
    x2 = mu2 + sigma2 * np.random.randn(10000)
    hist2, bins2 = np.histogram(x2, bins='auto', density=True)
    center2 = (bins2[:-1] + bins2[1:]) / 2
    
    
    # Same 2x2 grid as in the question (left: plt.plot, right: plt.scatter).
    # The only change is a plt.plot() call with no data issued on each scatter
    # axes before plt.scatter.
    plt.subplot(2, 2, 1)
    plt.plot(center1, hist1)
    plt.text(2, 0.27, 'plot\n$\\mu$ = 0 \n$\\sigma$ = 1')
    plt.subplot(2, 2, 2)
    # Hack: plot nothing first, then scatter.
    plt.plot()                      ## <== empty plot
    plt.scatter(center1, hist1)
    plt.text(2, 0.27, 'scatter\n$\\mu$ = 0 \n$\\sigma$ = 1')
    plt.subplot(2, 2, 3)
    plt.plot(center2, hist2)
    plt.text(127, 0.02, 'plot\n$\\mu$ = 100 \n$\\sigma$ = 15')
    plt.subplot(2, 2, 4)
    # Hack: plot nothing first, then scatter.
    plt.plot()                      ## <== empty plot
    plt.scatter(center2, hist2)
    plt.text(127, 0.02, 'scatter\n$\\mu$ = 100 \n$\\sigma$ = 15')
    
    plt.show()
    

    enter image description here

    The above is more a joke, though it works in this particular case. A more serious solution would be to create a plot of the actual data and remove it directly afterwards. This is enough to let the autoscaling work as expected for the data range of the scatter.

    import numpy as np
    import matplotlib.pyplot as plt
    
    # First sample: 10000 draws with mean 0 and standard deviation 1, reduced
    # to a density-normalised histogram.
    mu1, sigma1 = 0, 1
    x1 = mu1 + sigma1 * np.random.randn(10000)
    hist1, bins1 = np.histogram(x1, bins='auto', density=True)
    # Midpoint of every bin; used as the x coordinate for the histogram values.
    center1 = (bins1[:-1] + bins1[1:]) / 2
    
    # Second sample: same procedure with mean 100 and standard deviation 15.
    mu2, sigma2 = 100, 15
    x2 = mu2 + sigma2 * np.random.randn(10000)
    hist2, bins2 = np.histogram(x2, bins='auto', density=True)
    center2 = (bins2[:-1] + bins2[1:]) / 2
    
    
    # Same 2x2 grid as in the question (left: plt.plot, right: plt.scatter).
    # On each scatter axes the data is first drawn as a line and that line is
    # removed immediately, before the scatter is created.
    plt.subplot(2, 2, 1)
    plt.plot(center1, hist1)
    plt.text(2, 0.27, 'plot\n$\\mu$ = 0 \n$\\sigma$ = 1')
    plt.subplot(2, 2, 2)
    # plt.plot returns a list of lines; the trailing comma unpacks the single one.
    sentinel, = plt.plot(center1, hist1)            ## <== sentinel plot
    sentinel.remove()
    plt.scatter(center1, hist1)
    plt.text(2, 0.27, 'scatter\n$\\mu$ = 0 \n$\\sigma$ = 1')
    plt.subplot(2, 2, 3)
    plt.plot(center2, hist2)
    plt.text(127, 0.02, 'plot\n$\\mu$ = 100 \n$\\sigma$ = 15')
    plt.subplot(2, 2, 4)
    # Same sentinel trick for the second sample.
    sentinel, = plt.plot(center2, hist2)            ## <== sentinel plot
    sentinel.remove()
    plt.scatter(center2, hist2)
    plt.text(127, 0.02, 'scatter\n$\\mu$ = 100 \n$\\sigma$ = 15')
    
    
    plt.show()
    

    enter image description here

    Finally, consider that in the case of a large grid of plots you currently need to adjust the position of the text manually anyways. So the real solution here would be to create a function that is called for each axes and let that do everything automatically.

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.offsetbox import AnchoredText
    
    def plot_my_hist(mu, sigma, ax=None):
        """Scatter-plot the density histogram of a random normal sample.

        Draws 10000 samples as ``mu + sigma * np.random.randn(10000)``, bins
        them with ``np.histogram(..., bins='auto', density=True)`` and scatters
        the histogram values against the bin centres. A line plot of the same
        data is created and removed right before the scatter so that the axes
        limits are autoscaled to the data range. An anchored text box with the
        parameters is placed in the upper right corner.

        Parameters
        ----------
        mu : float
            Offset added to the standard-normal draws (mean of the sample).
        sigma : float
            Factor applied to the standard-normal draws (standard deviation).
        ax : matplotlib axes, optional
            Axes to draw into. When ``None`` the current axes (``plt.gca()``)
            is used.

        Returns
        -------
        None
        """
        # Test for "argument not given" explicitly instead of relying on the
        # truthiness of whatever object was passed in.
        if ax is None:
            ax = plt.gca()
        x = mu + sigma * np.random.randn(10000)
        hist, bins = np.histogram(x, bins='auto', density=True)
        # Midpoint of every bin; used as the x coordinate of the points.
        center = (bins[:-1] + bins[1:]) / 2
        # Plot
        # ax.plot returns a list of lines; the trailing comma unpacks the one
        # line so it can be removed again before the scatter is drawn.
        sentinel, = ax.plot(center, hist)      ## <== sentinel plot
        sentinel.remove()
        ax.scatter(center, hist)
        # Annotation
        # Anchored to a corner of the axes, so no per-axes text coordinates
        # have to be chosen by hand.
        at = AnchoredText(f'scatter\n$\\mu$ = {mu} \n$\\sigma$ = {sigma}',
                          loc='upper right')
        ax.add_artist(at)
    
    # One (mu, sigma) pair per axes of the grid.
    mus = [0, 0, 12, 12, 100, 100]
    sigmas = [1, 15, 1, 15, 1, 15]

    fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 6))

    # axes.T.flat walks the 2x3 grid column by column, so with the lists above
    # every column shares one mu and every row shares one sigma.
    for ax, (mu, sigma) in zip(axes.T.flat, zip(mus, sigmas)):
        plot_my_hist(mu, sigma, ax=ax)

    fig.tight_layout()
    plt.show()
    

    enter image description here