Search code examples
pythonmatplotlibsubplotcolorbar

How to change the height of each image grid with mpl_toolkits.axes_grid1.ImageGrid


I'm plotting a figure made of a 4×6 arrangement of grids, each created with mpl_toolkits.axes_grid1.ImageGrid:

target = np.reshape(target, (12, 41, 81))
output = np.reshape(output, (12, 41, 81))
pred_error = target - output
# target: C x D x H x W
sfmt = ticker.ScalarFormatter(useMathText=True)
sfmt.set_powerlimits((-2, 2))
cmap = 'jet'
fig = plt.figure(1, (11, 5.5))
axes_pad = 0.1
cbar_pad = 0.1
label_size = 6
plt.rcParams["mpl_toolkits.legacy_colorbar"] = False
subplots_position = [(4,6,i) for i in range(1, 25)]
for i, subplot_i in enumerate(subplots_position):
    if i in [i for i in range(6)]+[i for i in range(12, 18)]:
        # share one colorbar
        grid = ImageGrid(fig, subplot_i,          # as in plt.subplot(111)
                          nrows_ncols=(2, 1),
                          axes_pad=axes_pad,
                          share_all=False,
                          cbar_location="right",
                          cbar_mode="single",
                          cbar_size="3%",
                          cbar_pad=cbar_pad,
                          )
        if i <6:
            data = (target[i], output[i])
        else:
            data = (target[i-6], output[i-6])
        channel = np.concatenate(data)
        vmin, vmax = np.amin(channel), np.amax(channel)
        # Add data to image grid
        for j, ax in enumerate(grid):
            im = ax.imshow(data[j], vmin=vmin, vmax=vmax, cmap=cmap)
            ax.set_axis_off()
        # ticks=np.linspace(vmin, vmax, 10)
        #set_ticks, set_ticklabels
        cbar = grid.cbar_axes[0].colorbar(im, format=sfmt)
        # cbar.ax.set_yticks((vmin, vmax))
        cbar.ax.yaxis.set_offset_position('left')
        cbar.ax.tick_params(labelsize=label_size)
        cbar.ax.toggle_label(True)

    else:
        grid = ImageGrid(fig, subplot_i,  # as in plt.subplot(111)
                          nrows_ncols=(1, 1),
                          axes_pad=axes_pad,
                        #  share_all=True,
                        #  aspect=True,
                          cbar_location="right",
                          cbar_mode="single",
                          cbar_size="6%",
                          cbar_pad=cbar_pad,
                          )
        data = [pred_error[i%12]]
        for j, ax in enumerate(grid):
            im = ax.imshow(data[j], cmap=cmap)
            ax.set_axis_off()
            ax.set_axes_locator
            cbar = grid.cbar_axes[j].colorbar(im, format=sfmt)
            grid.cbar_axes[j].tick_params(labelsize=label_size)
            grid.cbar_axes[j].toggle_label(True)

plt.tight_layout()##pad=0.25, w_pad=0.25, h_pad=0.25)
# fig.subplots_adjust(wspace=0.075, hspace=0.075)
plt.show()
plt.savefig('test.pdf', 
            dpi=300, bbox_inches='tight')

plt.show()
plt.close(fig)

target is an array of shape (2, 6, 41, 81) that is reshaped to (12, 41, 81) for plotting; output has the same dimensions, and pred_error is the difference between them. I want to show target in the 1st and 4th rows, output in the 2nd and 5th rows (sharing a colorbar with target), and pred_error in the 3rd and 6th rows.

The image I'm getting now: the red boxes are my annotation, and each box marks one grid. I plot it this way so that the first two rows can share a colorbar.

The image I want has no huge gaps: I want to resize the grids in the green image rows so that the big gap shrinks.

I know the problem is that all the grids in my figure are the same size, but I couldn't find a way to change the height of the grids containing the green images. I appreciate your help!


Solution

  • I ended up making the figure I wanted (still keeping the shared colorbars) using plain subplots:

    # Demo data: 12 channels of 41 x 81 images; the last 6 channels are
    # scaled by 2 so the two halves have visibly different value ranges.
    target = np.random.randn(12, 41, 81)
    target[6:, :, :] *= 2
    output = np.random.randn(12, 41, 81)
    output[6:, :, :] *= 2
    pred_error = target - output

    # 6 x 6 axes, flattened row-major: rows 0/3 show target, rows 1/4 show
    # output, rows 2/5 show the prediction error.
    fig, axes = plt.subplots(nrows=6, ncols=6, figsize=(11, 5))
    axes = axes.flat
    # Flat axes positions of each (target, output) pair that shares a colorbar.
    index = [[i, i + 6] for i in range(6)] + [[i, i + 6] for i in range(18, 24)]
    # Flat axes positions of the error images, in channel order.
    error_index = list(range(12, 18)) + list(range(30, 36))
    label_size = 6

    for y_ind, (target_pos, output_pos) in enumerate(index):
        axy = axes[target_pos]
        axout = axes[output_pos]
        data = (target[y_ind], output[y_ind])
        # Common colour limits so one colorbar is valid for both images.
        vmin = np.min(data)
        vmax = np.max(data)
        axy.imshow(data[0], vmin=vmin, vmax=vmax, cmap='jet')
        imout = axout.imshow(data[1], vmin=vmin, vmax=vmax, cmap='jet')
        axy.set_axis_off()
        axout.set_axis_off()
        # Passing both axes makes the colorbar span the stacked pair.
        cbar = fig.colorbar(imout, ax=[axy, axout],
                            format='%.2f', aspect=20, shrink=0.95)
        cbar.set_ticks(np.linspace(vmin, vmax, 7, endpoint=True))
        cbar.ax.tick_params(labelsize=label_size)

    # enumerate() advances the channel index for every panel; a counter that
    # is never incremented would draw pred_error[0] in all twelve panels.
    for e_ind, ind_error in enumerate(error_index):
        axe = axes[ind_error]
        error_image = pred_error[e_ind]
        ime = axe.imshow(error_image, cmap='jet')
        axe.set_axis_off()
        cbar = fig.colorbar(ime, ax=axe, format='%.2f', aspect=8, shrink=0.85)
        cbar.set_ticks(np.linspace(np.min(error_image), np.max(error_image),
                                   4, endpoint=True))
        cbar.ax.tick_params(labelsize=label_size)

    # plt.tight_layout()
    plt.savefig('a.pdf',
                dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    

    The figure looks like this: enter image description here