Tags: python, matplotlib, figure

Create a figure based on another one


I want to generate a figure inside a function and then add more elements to that plot in another function. Both figures (the original and the edited one) should remain available for later use. Something like:

import numpy as np
import matplotlib.pyplot as plt

def plot_1():
    X, Y = np.meshgrid(np.linspace(0, 10, 100), np.linspace(0, 10, 50))
    z = np.random.rand(50, 100)

    fig, ax = plt.subplots()
    ax.contourf(X, Y, z, cmap="viridis")

    return fig

def plot_2(fig):
    ax = fig.axes[0]
    ax.scatter([2, 5], [1, 4], zorder=2.5, color="r")

    return ax

f = plot_1()
f2 = plot_2(f)

However, this modifies the original figure (which I would like to keep as it was), and f2 does not hold a printable figure. At first I thought the scatter plot was not being drawn, but as Lucas pointed out in the comments, that was not the actual issue: it can be solved with a suitable value for zorder.

How can I get this right?


Solution

  • If I understand correctly, you want two figures with the same base plot, and some extra plots drawn on only one of them.

    The way to do this is to create the figures outside the functions and pass the axes to the functions:

    import numpy as np
    import matplotlib.pyplot as plt
    
    def plot_1(ax):
        X, Y = np.meshgrid(np.linspace(0, 10, 100), np.linspace(0, 10, 50))
        z = np.random.rand(50, 100)
        ax.contourf(X, Y, z, cmap="viridis")
    
    def plot_2(ax):
        ax.scatter([2, 5], [1, 4], zorder=2.5, color="r")
    
    
    fig_1, ax_1 = plt.subplots()
    fig_2, ax_2 = plt.subplots()
    
    plot_1(ax_1)
    plot_1(ax_2)
    plot_2(ax_2)
    

    This draws the contourf plot on both figures, but the scatter on only one of them.
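
One caveat with the code above: plot_1 calls np.random.rand on every call, so the two figures end up with different contours. Below is a minimal sketch of one way around that, assuming it is acceptable to generate the data once and pass it to both plotting functions. The names make_data, plot_base, plot_with_points and the seed parameter are hypothetical, chosen only for illustration. Each function returns a Figure, so both results can be saved or shown later.

```python
import numpy as np
import matplotlib.pyplot as plt


def make_data(seed=0):
    """Hypothetical helper: build the grid and random field once."""
    rng = np.random.default_rng(seed)  # seeded so results are reproducible
    X, Y = np.meshgrid(np.linspace(0, 10, 100), np.linspace(0, 10, 50))
    z = rng.random((50, 100))
    return X, Y, z


def plot_base(X, Y, z):
    """Create a new figure that contains only the contour plot."""
    fig, ax = plt.subplots()
    ax.contourf(X, Y, z, cmap="viridis")
    return fig


def plot_with_points(X, Y, z):
    """Create a second, independent figure: same contours plus the scatter."""
    fig = plot_base(X, Y, z)
    fig.axes[0].scatter([2, 5], [1, 4], zorder=2.5, color="r")
    return fig


# Both figures are built from the same data, so only the scatter differs.
X, Y, z = make_data()
fig_base = plot_base(X, Y, z)
fig_points = plot_with_points(X, Y, z)
```

Because nothing is drawn onto fig_base after it is created, it stays unchanged. Either figure can then be written out with fig_base.savefig(...) or displayed with plt.show().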