Search code examples
matplotlib, subfigure

Matplotlib - Split graph creation into multiple functions


To create figures that reuse some of the same graphs, I would like to define one function per group of graphs. Each function would be called with the subfigure it should draw into, so that its graphs end up in the right location. Consequently, I would like to split the code below into separate functions, as in the second snippet further down.

fig = plt.figure(figsize=(10, 8), constrained_layout=True)

# Top/bottom split: the top row gets 1/4 of the height, the bottom row 3/4.
subfig_t, subfig_b = fig.subfigures(2, 1, hspace=0.05, height_ratios=[1, 3])

# Top row: one axes filling the whole subfigure.
ax0 = subfig_t.subplots()
ax0.set_title('ax0')
subfig_t.supxlabel('xlabel0')

# Bottom row: split 3:1 into a wide left and a narrow right subfigure.
subfig_bl, subfig_br = subfig_b.subfigures(1, 2, wspace=0.1, width_ratios=[3, 1])

# Bottom-left: ax1-ax3 on a 1x9 grid (1, 5 and 3 columns wide). ax2 and ax3
# share ax1's y-axis, so their own y-axes are hidden.
gs = subfig_bl.add_gridspec(nrows=1, ncols=9)
ax1 = subfig_bl.add_subplot(gs[0, :1])
ax1.set_title('ax1')
ax2 = subfig_bl.add_subplot(gs[0, 1:6], sharey=ax1)
ax2.set_title('ax2')
ax2.yaxis.set_visible(False)
ax3 = subfig_bl.add_subplot(gs[0, 6:], sharey=ax1)
ax3.set_title('ax3')
ax3.yaxis.set_visible(False)
subfig_bl.supxlabel('xlabel1-3')

# Bottom-right: a single axes.
ax4 = subfig_br.subplots()
ax4.set_title('ax4')
subfig_br.supxlabel('xlabel4')

Below is the kind of code I would like to have, to avoid writing the same code multiple times.

fig = plt.figure(figsize=(10, 8), constrained_layout=True)

# Stack a short top subfigure (1/4 of the height) above a taller bottom one,
# then split the bottom one 3:1 into left and right subfigures.
subfig_t, subfig_b = fig.subfigures(2, 1, hspace=0.05, height_ratios=[1, 3])
subfig_bl, subfig_br = subfig_b.subfigures(1, 2, wspace=0.1, width_ratios=[3, 1])

def func1(subfig_t):
    """Add one axes titled 'ax0' to *subfig_t*, set its super x-label, and return it.

    The body must be indented under the ``def`` line; otherwise Python raises
    an IndentationError and the statements are not part of the function.
    """
    # put ax0 in top subfig
    ax0 = subfig_t.subplots()
    ax0.set_title('ax0')
    subfig_t.supxlabel('xlabel0')
    return subfig_t

def func2(subfig_bl):
    """Draw ax1-ax3 side by side in *subfig_bl* and return the subfigure.

    The axes sit on a 1x9 gridspec (1, 5 and 3 columns wide). ax2 and ax3
    share ax1's y-axis, so their own y-axes are hidden.
    """
    # put ax1-ax3 in gridspec of bottom-left subfig
    gs = subfig_bl.add_gridspec(nrows=1, ncols=9)
    ax1 = subfig_bl.add_subplot(gs[0, :1])
    ax2 = subfig_bl.add_subplot(gs[0, 1:6], sharey=ax1)
    ax3 = subfig_bl.add_subplot(gs[0, 6:], sharey=ax1)
    ax1.set_title('ax1')
    ax2.set_title('ax2')
    ax3.set_title('ax3')
    ax2.get_yaxis().set_visible(False)
    ax3.get_yaxis().set_visible(False)
    subfig_bl.supxlabel('xlabel1-3')
    return subfig_bl

def func3(subfig_br):
    """Add one axes titled 'ax4' to *subfig_br*, set its super x-label, and return it.

    Returns:
        The subfigure that was passed in.
    """
    # put ax4 in bottom-right subfig
    ax4 = subfig_br.subplots()
    ax4.set_title('ax4')
    subfig_br.supxlabel('xlabel4')
    # Return the argument, not the global subfig_bl: the caller does
    # `subfig_br = func3(subfig_br)` and must keep the bottom-right subfigure.
    return subfig_br

def func_save(fig, OutputPath):
    """Save *fig* to *OutputPath* as a PNG at 300 dpi with a tight bounding box."""
    fig.savefig(OutputPath, dpi=300, format='png', bbox_inches='tight')

# Destination file for func_save; change as needed. It must be defined before
# func_save is called, otherwise the call raises a NameError.
OutputPath = "output.png"

subfig_t = func1(subfig_t)
subfig_bl = func2(subfig_bl)
subfig_br = func3(subfig_br)
func_save(fig, OutputPath)

Solution

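  • The function bodies must be indented under their `def` lines; otherwise Python does not treat them as part of the functions. With a few syntax changes like this, the code runs. Python's syntax differs from many other programming languages here: it uses indentation rather than braces to delimit blocks. It is simple to learn, but rules like this can be confusing at first.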

    The code below runs as expected; I hope you find it useful.

    import numpy as np
    import matplotlib.pyplot as plt
    
    
    fig = plt.figure(figsize=(10, 8), constrained_layout=True)

    # Stack a short top subfigure (1/4 of the height) above a taller bottom one,
    # then split the bottom one 3:1 into left and right subfigures.
    subfig_t, subfig_b = fig.subfigures(2, 1, hspace=0.05, height_ratios=[1, 3])
    subfig_bl, subfig_br = subfig_b.subfigures(1, 2, wspace=0.1, width_ratios=[3, 1])
    
    
    def func1(subfig_t):
        """Fill *subfig_t* with a single axes titled 'ax0' and return the subfigure."""
        axis = subfig_t.subplots()
        subfig_t.supxlabel('xlabel0')
        axis.set_title('ax0')
        return subfig_t
    
    
    def func2(subfig_bl):
        """Draw panels 'ax1'-'ax3' side by side in *subfig_bl* and return it.

        The panels sit on a 1x9 grid (1, 5 and 3 columns wide) and share the
        leftmost panel's y-axis, so only the leftmost shows its y-axis.
        """
        grid = subfig_bl.add_gridspec(nrows=1, ncols=9)
        left = subfig_bl.add_subplot(grid[0, :1])
        left.set_title('ax1')
        middle = subfig_bl.add_subplot(grid[0, 1:6], sharey=left)
        middle.set_title('ax2')
        right = subfig_bl.add_subplot(grid[0, 6:], sharey=left)
        right.set_title('ax3')
        for shared in (middle, right):
            # The y-axis is shared with `left`, so repeating it is redundant.
            shared.yaxis.set_visible(False)
        subfig_bl.supxlabel('xlabel1-3')
        return subfig_bl
    
    def func3(subfig_br):
        """Add one axes titled 'ax4' to *subfig_br*, set its super x-label, and return it.

        Returns:
            The subfigure that was passed in.
        """
        # put ax4 in bottom-right subfig
        ax4 = subfig_br.subplots()
        ax4.set_title('ax4')
        subfig_br.supxlabel('xlabel4')
        # Return the argument, not the global subfig_bl: the caller does
        # `subfig_br = func3(subfig_br)` and must keep the bottom-right subfigure.
        return subfig_br
    
    def func_save(fig, OutputPath):
        """Write *fig* to *OutputPath*: PNG format, 300 dpi, tight bounding box."""
        png_options = {'dpi': 300, 'format': 'png', 'bbox_inches': 'tight'}
        fig.savefig(OutputPath, **png_options)
    
    
    # Destination of the saved PNG; change as needed.
    OutputPath = "output.png"

    # Draw each group of graphs into its own subfigure, top first,
    # then bottom-left and bottom-right, and finally write the figure.
    subfig_t = func1(subfig_t)
    subfig_bl = func2(subfig_bl)
    subfig_br = func3(subfig_br)
    func_save(fig, OutputPath)
    

    Happy coding :)