Tags: python-3.x, matplotlib, figure, subplot, multiple-axes

How to subplot two alternate x scales and two alternate y scales for more than one subplot?


I am trying to make a 2x2 grid of subplots in which each subplot has two x axes and two y axes: the first x/y pair uses a linear scale and the second x/y pair uses a logarithmic scale. Before you assume this question has already been asked: the matplotlib docs and examples show how to add multiple scales for either x or y, but not for both. This post on Stack Overflow is the closest thing to my question, and I have tried to use its idea to implement what I want. My attempt is below.

First, we initialize the data, ticks, and tick labels. The idea is that the alternate scale uses the same tick positions, with the tick labels changed to reflect that scale.

import numpy as np
import matplotlib.pyplot as plt

# xy data (global)
X = np.linspace(5, 13, 9, dtype=int)
Y = np.linspace(7, 12, 9)

# xy ticks for linear scale (global)
dtick = dict(X=X, Y=np.linspace(7, 12, 6, dtype=int))

# xy ticklabels for linear and logarithmic scales (global)
init_xt = 2**dtick['X']
dticklabel = dict(X1=dtick['X'], Y1=dtick['Y']) # linear scale
dticklabel['X2'] = ['{}'.format(init_xt[idx]) if idx % 2 == 0 else '' for idx in range(len(init_xt))] # log_2 scale
dticklabel['Y2'] = 2**dticklabel['Y1'] # log_2 scale
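To make the idea concrete, the following sketch lists, for each tick position, the linear label and the log_2 label built above. The helper `label_table` is hypothetical, and the definitions are repeated so that the snippet runs on its own.

```python
import numpy as np

# Same definitions as above, repeated so this snippet is self-contained
X = np.linspace(5, 13, 9, dtype=int)       # x tick positions 5..13
yticks = np.linspace(7, 12, 6, dtype=int)  # y tick positions 7..12
x2_labels = ['{}'.format(v) if i % 2 == 0 else '' for i, v in enumerate(2**X)]
y2_labels = 2**yticks


def label_table(positions, log_labels):
    """Hypothetical helper: (position, linear label, log_2 label) per tick."""
    return [(int(p), str(p), str(lab)) for p, lab in zip(positions, log_labels)]


for entry in label_table(X, x2_labels):
    print(entry)
for entry in label_table(yticks, y2_labels):
    print(entry)
```

Every other log_2 x label is blank, matching the `idx % 2 == 0` filter used for `dticklabel['X2']`.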

Borrowing from the linked SO post, I plot the same thing in each of the 4 subplots. Since the same steps are used for both scalings in every subplot, they go inside a for-loop, which needs the row number, column number, and plot number of each subplot.

# 2x2 subplot
# fig.add_subplot(row, col, pnum); corresponding iterables = (irows, icols, iplts)
irows = (1, 1, 2, 2)
icols = (1, 2, 1, 2)
iplts = (1, 2, 1, 2)
ncolors = ('red', 'blue', 'green', 'black')
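For reference, `Figure.add_subplot` reads three positional integers as `(nrows, ncols, index)`: the first two give the shape of the grid, and `index` counts cells starting from 1, left to right and then top to bottom. The pure-Python sketch below mimics that rule with a hypothetical helper, `subplot_region` (not a matplotlib function, and ignoring the spacing between subplots), to show which region each tuple above requests.

```python
# Hypothetical helper (not part of matplotlib): describe the region that
# fig.add_subplot(nrows, ncols, index) occupies, ignoring spacing.
def subplot_region(nrows, ncols, index):
    """Return (row, col, width_fraction, height_fraction) with 0-based row/col."""
    if not 1 <= index <= nrows * ncols:
        raise ValueError('index must be in 1..nrows*ncols')
    # index counts left to right, then top to bottom, starting at 1
    row, col = divmod(index - 1, ncols)
    return row, col, 1 / ncols, 1 / nrows


irows = (1, 1, 2, 2)
icols = (1, 2, 1, 2)
iplts = (1, 2, 1, 2)
for args in zip(irows, icols, iplts):
    print(args, '->', subplot_region(*args))
```

Under that rule, the four tuples request the whole figure, the right half, the top half, and the top-right quarter, rather than the four cells of one 2x2 grid.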

Putting all of this together, the function to output the plot is below:

def initialize_figure(irows, icols, iplts, ncolors, figsize=None):
    """ """
    fig = plt.figure(figsize=figsize)
    for row, col, pnum, color in zip(irows, icols, iplts, ncolors):
        ax1 = fig.add_subplot(row, col, pnum) # linear scale
        ax2 = fig.add_subplot(row, col, pnum, frame_on=False) # logarithmic scale ticklabels
        ax1.plot(X, Y, '-', color=color)
        # ticks in same positions
        for ax in (ax1, ax2):
            ax.set_xticks(dtick['X'])
            ax.set_yticks(dtick['Y'])
        # remove xaxis xtick_labels and labels from top row
        if row == 1:
            ax1.set_xticklabels([])
            ax2.set_xticklabels(dticklabel['X2'])
            ax1.set_xlabel('')
            ax2.set_xlabel('X2', color='gray')
        # initialize xaxis xtick_labels and labels for bottom row
        else:
            ax1.set_xticklabels(dticklabel['X1'])
            ax2.set_xticklabels([])
            ax1.set_xlabel('X1', color='black')
            ax2.set_xlabel('')
        # linear scale on left
        if col == 1:
            ax1.set_yticklabels(dticklabel['Y1'])
            ax1.set_ylabel('Y1', color='black')
            ax2.set_yticklabels([])
            ax2.set_ylabel('')
        # logarithmic scale on right
        else:
            ax1.set_yticklabels([])
            ax1.set_ylabel('')
            ax2.set_yticklabels(dticklabel['Y2'])
            ax2.set_ylabel('Y2', color='black')

        ax1.tick_params(axis='x', colors='black')
        ax1.tick_params(axis='y', colors='black')
        ax2.tick_params(axis='x', colors='gray')
        ax2.tick_params(axis='y', colors='gray')
        ax1.xaxis.tick_bottom()
        ax1.yaxis.tick_left()
        ax1.xaxis.set_label_position('top')
        ax1.yaxis.set_label_position('right')
        ax2.xaxis.tick_top()
        ax2.yaxis.tick_right()
        ax2.xaxis.set_label_position('top')
        ax2.yaxis.set_label_position('right')
        for ax in (ax1, ax2):
            ax.set_xlim([4, 14])
            ax.set_ylim([6, 13])
    fig.tight_layout()
    plt.show()
    plt.close(fig)

Calling initialize_figure(irows, icols, iplts, ncolors) produces the figure below.

[Figure: output of initialize_figure(irows, icols, iplts, ncolors); image not available]

I am applying the same xlim and ylim, so I do not understand why the subplots are all different sizes. Also, the axis labels and axis tick labels are not in the specified positions (since fig.add_subplot(...) indexing starts from 1 instead of 0).

What is my mistake and how can I achieve the desired result?

(In case it isn't clear: I am trying to put the xticklabels and xlabels for the linear scale on the bottom row, the xticklabels and xlabels for the logarithmic scale on the top row, the yticklabels and ylabels for the linear scale on the left side of the left column, and the yticklabels and ylabels for the logarithmic scale on the right side of the right column. The color='black' kwarg corresponds to the linear scale and the color='gray' kwarg corresponds to the logarithmic scale.)


Solution

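  • The irows and icols tuples in the code do not do what you intend: fig.add_subplot(nrows, ncols, index) treats its first two arguments as the number of rows and columns of the grid, not as the position of the subplot. To create 4 subplots in a 2x2 grid you would loop over range(1, 5):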

    for pnum in range(1,5):
        ax1 = fig.add_subplot(2, 2, pnum)
    

    This might not be the only problem in the code, but as long as the subplots aren't created correctly it's not worth looking further down.
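Applying that fix to the question's function, a minimal sketch might look like the following. It keeps the question's overlay idea (a second, frameless axes added at the same grid position) and makes a few choices that go beyond the answer: the 0-based cell comes from `divmod(pnum - 1, 2)`, tick labels are passed as strings, `ax1` keeps its default bottom/left label positions (the question's version moves them to top/right, which contradicts the layout described above), and the function returns the figure and axes pairs instead of calling `plt.show()`.

```python
import numpy as np
import matplotlib.pyplot as plt

# Data, tick positions and tick labels as in the question (labels as str)
X = np.linspace(5, 13, 9, dtype=int)
Y = np.linspace(7, 12, 9)
xticks = X
yticks = np.linspace(7, 12, 6, dtype=int)
x1_labels = [str(v) for v in xticks]
y1_labels = [str(v) for v in yticks]
x2_labels = [str(v) if i % 2 == 0 else '' for i, v in enumerate(2**xticks)]
y2_labels = [str(v) for v in 2**yticks]


def initialize_figure(colors=('red', 'blue', 'green', 'black'), figsize=None):
    """Sketch: 2x2 grid, linear labels bottom/left, log_2 labels top/right."""
    fig = plt.figure(figsize=figsize)
    pairs = []
    for pnum, color in zip(range(1, 5), colors):
        row, col = divmod(pnum - 1, 2)  # 0-based cell of subplot pnum
        ax1 = fig.add_subplot(2, 2, pnum)                  # linear scale
        ax2 = fig.add_subplot(2, 2, pnum, frame_on=False)  # log_2 labels
        ax1.plot(X, Y, '-', color=color)
        for ax in (ax1, ax2):
            ax.set_xticks(xticks)
            ax.set_yticks(yticks)
            ax.set_xlim(4, 14)
            ax.set_ylim(6, 13)
        # ax1 keeps its default bottom/left ticks; ax2 uses top/right
        ax2.xaxis.tick_top()
        ax2.yaxis.tick_right()
        ax2.xaxis.set_label_position('top')
        ax2.yaxis.set_label_position('right')
        ax2.tick_params(colors='gray')
        if row == 0:  # top row: log_2 x labels only
            ax1.set_xticklabels([])
            ax2.set_xticklabels(x2_labels)
            ax2.set_xlabel('X2', color='gray')
        else:         # bottom row: linear x labels only
            ax1.set_xticklabels(x1_labels)
            ax2.set_xticklabels([])
            ax1.set_xlabel('X1', color='black')
        if col == 0:  # left column: linear y labels only
            ax1.set_yticklabels(y1_labels)
            ax2.set_yticklabels([])
            ax1.set_ylabel('Y1', color='black')
        else:         # right column: log_2 y labels only
            ax1.set_yticklabels([])
            ax2.set_yticklabels(y2_labels)
            ax2.set_ylabel('Y2', color='gray')
        pairs.append((ax1, ax2))
    fig.tight_layout()
    return fig, pairs
```

Call it as `fig, pairs = initialize_figure()` followed by `plt.show()`. Because both axes in a cell come from the same grid position, `tight_layout` places them identically. The overlay's limits are set once rather than linked, so interactive zooming would desynchronize the two label sets.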