
What is the difference between drawing plots using plot, axes or figure in matplotlib?


I'm somewhat confused about what is going on in the backend when I draw plots in matplotlib. To be honest, I'm not clear on the hierarchy of plot, axes and figure. I read the documentation and it was helpful, but I'm still confused...

The code below draws the same plot in three different ways:

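import numpy as np
import matplotlib.pyplot as plt

# creating the arrays for testing
x = np.arange(1, 100)
y = np.sqrt(x)

# 1st way: pyplot function, draws on the current Axes
plt.plot(x, y)

# 2nd way: get an Axes object, then call its plot method
ax = plt.subplot()
ax.plot(x, y)

# 3rd way: create a Figure explicitly, then add an Axes to it
figure = plt.figure()
new_plot = figure.add_subplot(111)
new_plot.plot(x, y)

plt.show()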

My questions are:

  1. What is the difference between the three? That is, what happens under the hood when each of the three methods is called?

  2. When should each method be used, and what are the pros and cons of each?


Solution

  • Method 1

    plt.plot(x, y)
    

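    This uses pyplot's implicit, state-based interface: plt.plot draws the (x, y) data on the current Axes, creating a Figure and an Axes first if none exist yet. If you just want one quick graph, this is the simplest way.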
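To see what plt.plot does behind the scenes, here is a minimal sketch that asks pyplot for the Figure and Axes it created implicitly. It assumes the non-interactive Agg backend so it can run without a display. plt.gcf() and plt.gca() return the current Figure and the current Axes.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 100)
y = np.sqrt(x)

plt.close("all")        # start with no open figures
lines = plt.plot(x, y)  # returns a list of Line2D objects

fig = plt.gcf()  # the Figure that plt.plot created implicitly
ax = plt.gca()   # the Axes (inside that Figure) the line was drawn on

# Hierarchy: Figure -> Axes -> Line2D
print(ax.figure is fig, lines[0].axes is ax, len(fig.axes))  # True True 1
```

Most pyplot plotting functions work this way. They forward the call to the current Axes, so plt.plot(x, y) ends up doing roughly what plt.gca().plot(x, y) does.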

    Method 2

    ax = plt.subplot()
    ax.plot(x, y)
    

    This gets an Axes object (a single plotting area inside a Figure) and calls its plot method directly. As written, it creates just one Axes, but the same object-oriented style lets you put several Axes in one Figure (the window), for example:

    fig1, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
    

    This creates one Figure containing four Axes, named ax1, ax2, ax3 and ax4, laid out as a 2x2 grid in the same window.
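Here is a minimal sketch of the grid version, again assuming the Agg backend. plt.subplots(2, 2) returns one Figure plus a 2x2 NumPy array of Axes, and each Axes keeps a reference back to that same Figure. The random data is only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig1, axs = plt.subplots(2, 2)  # one Figure, a 2x2 array of Axes
((ax1, ax2), (ax3, ax4)) = axs  # same unpacking as in the answer

rng = np.random.default_rng(0)
for ax in (ax1, ax2, ax3, ax4):
    ax.plot(rng.random(10), rng.random(10))  # each call draws on its own Axes

print(axs.shape, len(fig1.axes))                  # (2, 2) 4
print(all(ax.figure is fig1 for ax in axs.flat))  # True
```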

    Method 3

    fig = plt.figure()
    new_plot = fig.add_subplot(111)
    new_plot.plot(x, y)
    

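    I haven't used this one myself, but it is the fully explicit version. plt.figure() creates a new Figure, fig.add_subplot(111) adds one Axes to it (111 means a 1x1 grid, first cell), and new_plot.plot(x, y) draws on that Axes. See the Figure.add_subplot documentation for details.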
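Here is a minimal sketch of the explicit route, assuming the Agg backend as before. The checks walk the Figure -> Axes -> Line2D hierarchy step by step. add_subplot(111) is shorthand for add_subplot(1, 1, 1).

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

x = np.arange(1, 100)
y = np.sqrt(x)

fig = plt.figure()               # a new, empty Figure (no Axes yet)
print(len(fig.axes))             # 0

new_plot = fig.add_subplot(111)  # add a single Axes filling a 1x1 grid
new_plot.plot(x, y)              # draw the line on that Axes

print(fig.axes == [new_plot])    # True: the Figure holds exactly this Axes
print(len(new_plot.lines))       # 1: one Line2D on the Axes
```

The only difference from Method 2 is that the Figure is created explicitly and kept in a variable. That is handy when you later want to call Figure methods such as fig.savefig(...).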

    Example:

    import numpy as np
    import matplotlib.pyplot as plt
    
    # Method 1 #
    
    x = np.random.rand(10)
    y = np.random.rand(10)
    
    lines = plt.plot(x, y)  # plt.plot returns a list of Line2D objects, not a Figure
    
    # Method 2 #
    
    x1 = np.random.rand(10)
    x2 = np.random.rand(10)
    x3 = np.random.rand(10)
    x4 = np.random.rand(10)
    y1 = np.random.rand(10)
    y2 = np.random.rand(10)
    y3 = np.random.rand(10)
    y4 = np.random.rand(10)
    
    figure2, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
    ax1.plot(x1,y1)
    ax2.plot(x2,y2)
    ax3.plot(x3,y3)
    ax4.plot(x4,y4)
    
    plt.show()
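The Method 2 part of the example above can be written more compactly by looping over the array that plt.subplots returns, instead of naming eight separate arrays. This is a minimal sketch that uses random data as above. It assumes the Agg backend; drop that line and add plt.show() to open a window.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend; remove it and call plt.show() for a window
import matplotlib.pyplot as plt
import numpy as np

figure2, axs = plt.subplots(2, 2)

# axs.flat iterates over the four Axes in row-major order (ax1, ax2, ax3, ax4)
for ax in axs.flat:
    x = np.random.rand(10)
    y = np.random.rand(10)
    ax.plot(x, y)

figure2.tight_layout()  # reduce overlap between tick labels of neighboring Axes
```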
    

