Tags: algorithm, python-3.5, dimensions, subplot

Determine subplot square dimensions from total number of plots


I'm trying to figure out how to calculate subplot dimensions when I know the total number of plots I need and I'd like the arrangement to be a square (with possibly a few empty subplots).

For example, if I need 22 subplots, I would make a 5x5 grid for a total of 25 subplots and just leave three of them empty.

So I guess I'm looking for an algorithm where I'd input 22 and get out 5, for example. Does anyone know of a short way to do this in Python (maybe a lambda function, if possible)?
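A minimal sketch of the computation being asked for: the side length is the square root of the plot count rounded up, so 22 plots give a side of 5 and leave 25 - 22 = 3 cells empty. The name `grid_side` is only an illustrative choice; in Python 3, `math.ceil` already returns an `int`.

```python
import math

# Side length of the smallest square grid that holds n plots:
# the square root of n, rounded up to the next whole number.
grid_side = lambda n: math.ceil(math.sqrt(n))

n_plots = 22
side = grid_side(n_plots)
empty = side * side - n_plots  # cells left over in the square grid
print(side, empty)  # 5 3
```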

(I'm also open to other alternatives or pre-made solutions for doing this; I'm building multiple subplot matrices for a dictionary of pandas DataFrames.)


Solution

  • This should work for what you're trying to do. I haven't tried anything with a lambda function, but I doubt it would be difficult to modify this. There won't be any empty plots, because the loop stops creating subplots once it runs out of values to plot.

    I broke up the dictionary into key and value lists because I was originally working with lists when I wrote this. Everything up to the try clause would work without converting your values to a list. If you want to fill the leftover cells with empty plots rather than using the somewhat hacky break_test bit, you can put all of your subplot code inside a try clause.

    Weird break version:

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure()
    
    # Makes organizing the plots easier (df_dict is your dictionary of DataFrames)
    key_list = list(df_dict.keys())
    val_list = list(df_dict.values())
    
    # int() truncates, which rounds a non-negative square root down
    floor = int(np.sqrt(len(val_list)))
    
    # If the number of plots is a perfect square, we're done.
    # Otherwise, we take the next highest perfect square to build our subplots
    if floor ** 2 == len(val_list):
        sq_chk = floor
    else:
        sq_chk = floor + 1
    
    plot_count = 0
    
    # The try/except makes sure we can gracefully stop building plots once 
    # we've exhausted our dictionary values.
    for i in range(sq_chk):
        for j in range(sq_chk):
            try:
                break_test = val_list[plot_count]
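            except IndexError:
                # Out of values: this exits the inner loop, and every
                # remaining row breaks immediately as well
                break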
    
            ax = fig.add_subplot(sq_chk, sq_chk, plot_count + 1)
            ax.set_title(key_list[plot_count])
    
            ...
            # Whatever you want to do with your plots
            ...
    
            plot_count += 1
    
    plt.show()
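    The same layout can be planned without the break_test lookup by looping over the keys directly and working out each key's cell with divmod. This is a pure-Python sketch: plan_grid is a hypothetical helper and the keys are placeholders; in the real code each subplot index would be passed to fig.add_subplot(side, side, index). The test at the end of grid_side also confirms that the floor-then-add-one rule always matches rounding the square root up.

```python
import math


def grid_side(n):
    """Floor-then-add-one rule from the answer: the smallest k with k * k >= n."""
    floor = int(math.sqrt(n))  # int() rounds a non-negative float down
    return floor if floor ** 2 == n else floor + 1


def plan_grid(keys):
    """Map each key to (row, col, subplot_index) in a square grid.

    Hypothetical helper: subplot_index is the 1-based number that
    fig.add_subplot(side, side, subplot_index) expects.
    """
    side = grid_side(len(keys))
    plan = {}
    for count, key in enumerate(keys):
        row, col = divmod(count, side)  # fill the grid row by row
        plan[key] = (row, col, count + 1)
    return side, plan


side, plan = plan_grid(["df%d" % i for i in range(22)])
print(side)          # 5
print(plan["df21"])  # (4, 1, 22)
```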
    

    No break version:

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure()
    key_list = list(df_dict.keys())
    val_list = list(df_dict.values())
    
    floor = int(np.sqrt(len(df_dict)))
    
    if floor ** 2 == len(df_dict):
        sq_chk = floor
    else:
        sq_chk = floor + 1
    
    plot_count = 0
    
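    # Everything from the original snippet is nested in the try clause. Once
    # key_list runs out, add_subplot has already created the axes for that
    # cell, so the leftover cells stay as empty subplots.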
    for i in range(sq_chk):
        for j in range(sq_chk):
            try:
    
                ax = fig.add_subplot(sq_chk, sq_chk, plot_count + 1)
                ax.set_title(key_list[plot_count])
    
                ...
                # Whatever you want to do with your plots
                ...
    
                plot_count += 1
    
            except IndexError:
                plot_count += 1
    
    plt.show()
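    As an alternative to the try/except bookkeeping, matplotlib's plt.subplots can create the whole square grid in one call, after which the unused axes are hidden. This is a sketch under assumptions: the data dictionary of plain lists stands in for the dictionary of DataFrames, and the Agg backend is selected only so the sketch runs without a display (drop that line to show the figure interactively).

```python
import math

import matplotlib
matplotlib.use("Agg")  # non-interactive backend; remove to use plt.show()
import matplotlib.pyplot as plt

# Placeholder data standing in for the dictionary of DataFrames
data = {"plot %d" % i: [i, i + 1, i + 2] for i in range(22)}

side = math.ceil(math.sqrt(len(data)))
# squeeze=False keeps axes as a 2-D array even when side == 1
fig, axes = plt.subplots(side, side, squeeze=False)
flat = axes.ravel()  # 1-D view of the grid, in row-by-row order

for ax, (key, values) in zip(flat, data.items()):
    ax.plot(values)
    ax.set_title(key)

# Cells past the last plot are the empty ones; hide them
for ax in flat[len(data):]:
    ax.set_visible(False)
```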