Search code examples
Tags: python, matplotlib, plot

Is there a way to know the type of plot in matplotlib?


I want to automate the look of my plots. My function takes an axes object created with the plt.subplots() method. However, I'm having trouble automating the size of the spines, as the right size seems to vary a lot depending on whether it is a bar plot, horizontal bar plot, scatter plot, etc.

So far, I have tried using ax.get_children() and looking for Rectangle objects to determine whether it's a bar plot, but I cannot tell whether it's a horizontal or a vertical bar plot.

Is there an attribute or parameter in the axes or fig object that can help determine the type of plot? Or is there an easy way to determine the type of plot?

Thanks!


Solution

  • There is no direct way, but here is a possible workaround:

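    When you assign the return value of a plotting call to a variable, e.g. plot_1 = ax.plot(x, y, linewidth=2.0), the plot is still drawn as usual. The variable then holds whatever the call returned (a list, a single artist, a tuple, a dict or a matplotlib container), which carries some information about the kind of plot.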

    Here are those variables printed for 10 examples taken from the "Plot types" page of the matplotlib homepage:
    (the full code to get this printout is added below)

    plot:
             [<matplotlib.lines.Line2D object at 0x000001F00CF91B80>]
    scatter:
             <matplotlib.collections.PathCollection object at 0x000001F00CF91E50>
    bar:
             <BarContainer object of 8 artists>
    step:
             [<matplotlib.lines.Line2D object at 0x000001F00D17B250>]
    stem:
             <StemContainer object of 3 artists>
    hist:
             (array([ 6.,  9., 32., 47., 53., 37., 10.,  6.]), array([0.34774335, 
              1.2783365 , 2.20892966, 3.13952281, 4.07011596,
              5.00070911, 5.93130226, 6.86189541, 7.79248856]), <BarContainer object of 8 artists>)
    boxplot:
             {'whiskers': [<matplotlib.lines.Line2D object at 0x000001F00D2C3D60>, 
              <matplotlib.lines.Line2D object at 0x000001F00D2C3EE0>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2D4D00>, <matplotlib.lines.Line2D object at 0x000001F00D2D4FD0>, 
              <matplotlib.lines.Line2D object at 0x000001F00D2E2D90>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2EF0A0>], 'caps': [<matplotlib.lines.Line2D object at 0x000001F00D2D4220>,
              <matplotlib.lines.Line2D object at 0x000001F00D2D44F0>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2E22E0>, <matplotlib.lines.Line2D object at 0x000001F00D2E25B0>, 
              <matplotlib.lines.Line2D object at 0x000001F00D2EF370>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2EF640>], 'boxes': [<matplotlib.patches.PathPatch object
              at 0x000001F00D2C39A0>, <matplotlib.patches.PathPatch object 
              at 0x000001F00D2D4970>, <matplotlib.patches.PathPatch object 
              at 0x000001F00D2E2A30>], 'medians': [<matplotlib.lines.Line2D object 
              at 0x000001F00D2D47C0>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2E2880>, <matplotlib.lines.Line2D object 
              at 0x000001F00D2EF910>], 'fliers': [], 'means': []}
    pie:
             ([<matplotlib.patches.Wedge object at 0x000001F00D3592E0>, <matplotlib.patches.Wedge 
              object at 0x000001F00D3597C0>, <matplotlib.patches.Wedge object 
              at 0x000001F00D359CA0>, <matplotlib.patches.Wedge object 
              at 0x000001F00D3641C0>], [Text(7.138486499000184, 
              5.019756096129642, ''), Text(5.019756022668064, 
              7.1384865228692975, ''), Text(0.8615134293924815, 
              5.019755875744905, ''), Text(5.019756463437526, 
              0.861513620345407, '')])
    contour:
             <matplotlib.contour.QuadContourSet object at 0x000001F00D381910>
    tricontour:
             <matplotlib.tri.tricontour.TriContourSet object at 0x000001F00D205910>
    

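    For most plot types you should be able to match a word in the printed object name (e.g. PathCollection, BarContainer, StemContainer, QuadContourSet) to a plot type. For bar plots in particular, the BarContainer also has an orientation attribute ('vertical' for ax.bar, 'horizontal' for ax.barh; matplotlib >= 3.4), and it can be reached from the axes alone via ax.containers.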

    However, step and plot are examples that can't be distinguished this way, since both return a list of Line2D objects. There are probably others as well; note that this isn't a complete list.

    For plots that are hard to distinguish, you can also manually add a marker to the returned list, e.g. as below. For the matplotlib containers (see e.g. print(type(plot_5))), which are not plain lists, a different workaround is probably needed.

    # Prepend a marker string to the list returned by ax.step(); afterwards
    # the list holds a str in front of the Line2D artist(s).
    plot_4.insert(0, 'step')
    

    Full code for reference:

    # Draw ten examples from the matplotlib "Plot types" gallery, keep the
    # return value of every plotting call in plot_1 ... plot_10, and print a
    # single summary at the end showing which object each plot type returns.
    import matplotlib.pyplot as plt
    import numpy as np
    plt.style.use('_mpl-gallery')


    def _set_gallery_limits(ax):
        """Apply the 0-8 data limits and 1-7 ticks shared by most examples."""
        ax.set(xlim=(0, 8), xticks=np.arange(1, 8),
               ylim=(0, 8), yticks=np.arange(1, 8))


    ### "plot"
    # make data
    x = np.linspace(0, 10, 100)
    y = 4 + 2 * np.sin(2 * x)

    # plot
    fig, ax = plt.subplots()
    plot_1 = ax.plot(x, y, linewidth=2.0)
    _set_gallery_limits(ax)
    plt.show()

    ### "scatter"
    # make the data
    np.random.seed(3)
    x = 4 + np.random.normal(0, 2, 24)
    y = 4 + np.random.normal(0, 2, len(x))
    # size and color:
    sizes = np.random.uniform(15, 80, len(x))
    colors = np.random.uniform(15, 80, len(x))

    # plot
    fig, ax = plt.subplots()
    plot_2 = ax.scatter(x, y, s=sizes, c=colors, vmin=0, vmax=100)
    _set_gallery_limits(ax)
    plt.show()

    ### "bar"
    # make data:
    np.random.seed(3)
    x = 0.5 + np.arange(8)
    y = np.random.uniform(2, 7, len(x))

    # plot
    fig, ax = plt.subplots()
    plot_3 = ax.bar(x, y, width=1, edgecolor="white", linewidth=0.7)
    _set_gallery_limits(ax)
    plt.show()

    ### "step"
    # make data
    np.random.seed(3)
    x = 0.5 + np.arange(8)
    y = np.random.uniform(2, 7, len(x))

    # plot
    fig, ax = plt.subplots()
    plot_4 = ax.step(x, y, linewidth=2.5)
    _set_gallery_limits(ax)
    plt.show()

    ### "stem"
    # make data
    np.random.seed(3)
    x = 0.5 + np.arange(8)
    y = np.random.uniform(2, 7, len(x))

    # plot
    fig, ax = plt.subplots()
    plot_5 = ax.stem(x, y)
    _set_gallery_limits(ax)
    plt.show()

    ### "hist"
    # make data
    np.random.seed(1)
    x = 4 + np.random.normal(0, 1.5, 200)

    # plot: counts need a taller y-range than the shared 0-8 limits
    fig, ax = plt.subplots()
    plot_6 = ax.hist(x, bins=8, linewidth=0.5, edgecolor="white")
    ax.set(xlim=(0, 8), xticks=np.arange(1, 8),
           ylim=(0, 56), yticks=np.linspace(0, 56, 9))
    plt.show()

    ### "boxplot"
    # make data:
    np.random.seed(10)
    D = np.random.normal((3, 5, 4), (1.25, 1.00, 1.25), (100, 3))

    # plot
    fig, ax = plt.subplots()
    plot_7 = ax.boxplot(D, positions=[2, 4, 6], widths=1.5, patch_artist=True,
                        showmeans=False, showfliers=False,
                        medianprops={"color": "white", "linewidth": 0.5},
                        boxprops={"facecolor": "C0", "edgecolor": "white",
                                  "linewidth": 0.5},
                        whiskerprops={"color": "C0", "linewidth": 1.5},
                        capprops={"color": "C0", "linewidth": 1.5})
    _set_gallery_limits(ax)
    plt.show()

    ### "pie"
    # make data
    x = [1, 2, 3, 4]
    colors = plt.get_cmap('Blues')(np.linspace(0.2, 0.7, len(x)))

    # plot
    fig, ax = plt.subplots()
    plot_8 = ax.pie(x, colors=colors, radius=3, center=(4, 4),
                    wedgeprops={"linewidth": 1, "edgecolor": "white"}, frame=True)
    _set_gallery_limits(ax)
    plt.show()

    ### "contour"
    # this style also stays active for the tricontour example below
    plt.style.use('_mpl-gallery-nogrid')

    # make data
    X, Y = np.meshgrid(np.linspace(-3, 3, 256), np.linspace(-3, 3, 256))
    Z = (1 - X/2 + X**5 + Y**3) * np.exp(-X**2 - Y**2)
    levels = np.linspace(np.min(Z), np.max(Z), 7)

    # plot
    fig, ax = plt.subplots()
    plot_9 = ax.contour(X, Y, Z, levels=levels)
    plt.show()

    ### "tricontour"
    # make data:
    np.random.seed(1)
    x = np.random.uniform(-3, 3, 256)
    y = np.random.uniform(-3, 3, 256)
    z = (1 - x/2 + x**5 + y**3) * np.exp(-x**2 - y**2)
    levels = np.linspace(z.min(), z.max(), 7)

    # plot:
    fig, ax = plt.subplots()
    ax.plot(x, y, 'o', markersize=2, color='lightgrey')
    plot_10 = ax.tricontour(x, y, z, levels=levels)
    ax.set(xlim=(-3, 3), ylim=(-3, 3))
    plt.show()

    # Print each return value exactly once, in one consistent format
    # (previously every value was also printed inside its own section).
    for label, returned in [
        ("plot", plot_1), ("scatter", plot_2), ("bar", plot_3),
        ("step", plot_4), ("stem", plot_5), ("hist", plot_6),
        ("boxplot", plot_7), ("pie", plot_8), ("contour", plot_9),
        ("tricontour", plot_10),
    ]:
        print(f"{label}:\n         {returned}")