python matplotlib plot scatter-plot

Adding y=x to a matplotlib scatter plot if I haven't kept track of all the data points that went in


Here's some code that draws a scatter plot of several different series using matplotlib and then adds the line y=x:

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

nseries = 10
colors = cm.rainbow(np.linspace(0, 1, nseries))

all_x = []
all_y = []
for i in range(nseries):
    x = np.random.random(12)+i/10.0
    y = np.random.random(12)+i/5.0
    plt.scatter(x, y, color=colors[i])
    all_x.extend(x)
    all_y.extend(y)

# Could I somehow do the next part (add identity_line) if I haven't been keeping track of all the x and y values I've seen?
identity_line = np.linspace(max(min(all_x), min(all_y)),
                            min(max(all_x), max(all_y)))
plt.plot(identity_line, identity_line, color="black", linestyle="dashed", linewidth=3.0)

plt.show()

In order to achieve this I've had to keep track of all the x and y values that went into the scatter plot so that I know where identity_line should start and end. Is there a way I can get y=x to show up even if I don't have a list of all the points that I plotted? I would think that something in matplotlib can give me a list of all the points after the fact, but I haven't been able to figure out how to get that list.


Solution

  • Starting with matplotlib 3.3, this is very simple with the axline method, which needs only a point and a slope. To plot y=x:

    ax.axline((0, 0), slope=1)
    

    You don't need to look at your data to use this, because the point you specify (here, (0, 0)) doesn't need to be in your data or plotting range.
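For completeness, here is a minimal sketch of how the question's loop might look with axline, so that no all_x / all_y bookkeeping is needed. It assumes matplotlib 3.3 or newer; the seeded random generator and the names fig, ax and identity are illustrative choices, not part of the original code.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm

nseries = 10
colors = cm.rainbow(np.linspace(0, 1, nseries))
rng = np.random.default_rng(0)  # seeded only to make the sketch reproducible

fig, ax = plt.subplots()
for i in range(nseries):
    x = rng.random(12) + i / 10.0
    y = rng.random(12) + i / 5.0
    ax.scatter(x, y, color=colors[i])  # no need to remember x and y

# axline draws an infinite line through (0, 0) with slope 1, i.e. y = x.
# It takes the usual Line2D styling keywords.
identity = ax.axline((0, 0), slope=1, color="black",
                     linestyle="dashed", linewidth=3.0)

# Call plt.show() or fig.savefig(...) here to view the result.
```

If you are working with the pyplot interface rather than an Axes object, plt.axline((0, 0), slope=1) should behave the same way on the current axes.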