
tabular legend layout for matplotlib


I have a plot with 9 lines, representing datasets with two varying parameters, say f_11, f_12, f_13, ..., f_33. To make the plot (a bit) clearer, I encode the first parameter as the color of the line and the second one as the linestyle (so f_11 is red & dashed, f_12 is red & dotted, f_21 is green & dashed, f_22 is green & dotted, etc.). So, for the legend, I would like to make a 3x3 table that looks like this:

       | value1 | value2 | value3
---------------------------------
value1 |
value2 |    <artists go there>
value3 |

Is there any way to do this with matplotlib? One idea would be to draw the box with LaTeX, but I would still need a way to place the legend artists at the right positions.

Thanks!

(cross-posted from matplotlib-users)
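For reference, the encoding described above (i as line color, j as linestyle) can be set up in a loop instead of one plot call per curve. This is a minimal sketch: the random data stands in for the real f_ij datasets, and the specific colors and linestyles are arbitrary choices for illustration, not part of the question.

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np

# Assumed encoding (colors and linestyles chosen arbitrarily for illustration):
# first parameter i -> line color, second parameter j -> linestyle.
colors = {1: "red", 2: "green", 3: "blue"}
linestyles = {1: "--", 2: ":", 3: "-."}

fig, ax = plt.subplots()
x = np.arange(10)
lines = {}  # maps (i, j) to the Line2D drawn for f_ij
for i, j in itertools.product(colors, linestyles):
    # Random placeholder data standing in for the real f_ij dataset.
    (line,) = ax.plot(x, np.random.randn(10),
                      color=colors[i], linestyle=linestyles[j])
    lines[(i, j)] = line
```

Keeping the lines in a dict keyed by (i, j) makes it easy to pick out the handle for any cell of the legend table later.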


Solution

  • Not a very easy question, but I figured it out. The trick is to create an empty rectangle and use it as a placeholder legend handle. These blank handles fill the header row and header column of the table, and I get rid of the excess space they leave next to the header labels with a negative handletextpad:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    
    fig, ax = plt.subplots()
    x = np.arange(10)
    
    # Nine curves: color encodes i (r, g, b), line/marker style encodes j
    im1, = ax.plot(x, np.random.randn(10), "r--")
    im2, = ax.plot(x, np.random.randn(10), "g--")
    im3, = ax.plot(x, np.random.randn(10), "b--")
    im4, = ax.plot(x, np.random.randn(10), "r.")
    im5, = ax.plot(x, np.random.randn(10), "g.")
    im6, = ax.plot(x, np.random.randn(10), "b.")
    im7, = ax.plot(x, np.random.randn(10), "r^")
    im8, = ax.plot(x, np.random.randn(10), "g^")
    im9, = ax.plot(x, np.random.randn(10), "b^")
    
    # Invisible rectangle used as a placeholder handle for the header cells
    extra = Rectangle((0, 0), 1, 1, fill=False, edgecolor="none", linewidth=0)
    
    # The legend fills its entries column by column (ncol=4), so the handles
    # are listed column-major: header column first, then one column per j
    legend_handle = [extra, extra, extra, extra,  # header column
                     extra, im1, im2, im3,        # j = 1
                     extra, im4, im5, im6,        # j = 2
                     extra, im7, im8, im9]        # j = 3
    
    # Labels in the same order; the real artists get empty labels
    label_col_1 = [r"$f_{i,j}$", r"$i = 1$", r"$i = 2$", r"$i = 3$"]
    legend_labels = (label_col_1
                     + [r"$j = 1$"] + [""] * 3
                     + [r"$j = 2$"] + [""] * 3
                     + [r"$j = 3$"] + [""] * 3)
    
    # loc="upper center" is equivalent to loc=9; the negative handletextpad
    # pulls each header label over its invisible placeholder handle
    ax.legend(legend_handle, legend_labels,
              loc="upper center", ncol=4, shadow=True, handletextpad=-2)
    
    plt.show()
    
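The handle and label lists in the answer are ordered that way because `ax.legend` with `ncol` fills its entries column by column. A small helper can build both lists from a 2-D grid of handles, so the table size is not hard-coded. This is a minimal sketch: `table_legend_entries` is a hypothetical helper, not a matplotlib function, and the red/green/blue by dashed/dotted/dash-dot styles are placeholders.

```python
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle


def table_legend_entries(grid, row_labels, col_labels, corner_label, blank):
    """Hypothetical helper: flatten grid[r][c] into the column-major
    (handles, labels) order used by ax.legend(..., ncol=len(col_labels) + 1).
    `blank` is an invisible placeholder handle for the header cells."""
    n_rows = len(row_labels)
    # First legend column: corner label plus row labels, all on blank handles.
    handles = [blank] * (n_rows + 1)
    labels = [corner_label] + list(row_labels)
    # Each further column: a header cell, then one real handle per row
    # with an empty label (the headers already identify it).
    for c, col_label in enumerate(col_labels):
        handles += [blank] + [grid[r][c] for r in range(n_rows)]
        labels += [col_label] + [""] * n_rows
    return handles, labels


fig, ax = plt.subplots()
x = np.arange(10)
# Placeholder styles: rows are i (color), columns are j (linestyle).
styles = [["r--", "r:", "r-."], ["g--", "g:", "g-."], ["b--", "b:", "b-."]]
grid = [[ax.plot(x, np.random.randn(10), fmt)[0] for fmt in row]
        for row in styles]

blank = Rectangle((0, 0), 1, 1, fill=False, edgecolor="none", linewidth=0)
col_labels = [r"$j = 1$", r"$j = 2$", r"$j = 3$"]
handles, labels = table_legend_entries(
    grid,
    row_labels=[r"$i = 1$", r"$i = 2$", r"$i = 3$"],
    col_labels=col_labels,
    corner_label=r"$f_{i,j}$",
    blank=blank,
)
ax.legend(handles, labels, loc="upper center",
          ncol=len(col_labels) + 1, handletextpad=-2)
```

Deriving `ncol` from the number of column labels keeps the legend layout in sync with the grid; call `plt.show()` afterwards to display the figure.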
