Tags: python, matplotlib, alignment, axes

Aligning tick coordinates of top and bottom axis


I would like to build a figure that has one set of ticks on its bottom x-axis and, on its top x-axis, another set of ticks that is aligned with the bottom ones. In my case these are batches and epochs: for every n batches on the bottom axis (not necessarily at tick positions) I want one epoch tick on the top axis. Consider this example:

import numpy as np
import matplotlib.pyplot as plt

batches = np.arange(1,101)
epoch_ends = batches[[(i*10)-1 for i in range(1,11)]]
accuracy = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : x/len(batches))
loss = np.apply_along_axis(arr=batches, axis=0, func1d=lambda x : 1 - (x/len(batches)))

fig, ax1 = plt.subplots( nrows=1, ncols=1 )
ax2 = ax1.twinx()
ax3 = ax1.twiny()

ax1.set_xlabel('batches')
ax1.set_xticks(np.arange(1, len(batches)+1, 9))
ax1.set_ylabel('accuracy')
ax1.grid()

ax2.set_ylabel('loss')
ax2.set_yticklabels(np.linspace(3, 10, 9))

ax3.set_xlabel('epochs')
ax3.set_xticks(epoch_ends)
ax3.set_xticklabels(range(1, len(epoch_ends)+1))

acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
loss_plt = ax2.plot(batches, loss, color='blue', label='loss')

lns = acc_plt+loss_plt
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc=2)

plt.show()

batches and epoch_ends look like this, respectively:

[  1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17  18
  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35  36
  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53  54
  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71  72
  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89  90
  91  92  93  94  95  96  97  98  99 100]
[ 10  20  30  40  50  60  70  80  90 100]
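Because each epoch here spans a fixed number of batches, epoch_ends can also be computed directly with np.arange instead of indexing into batches. A small sketch, assuming a hypothetical batches_per_epoch variable set to 10 as in the example (a trailing partial epoch, if len(batches) were not a multiple of it, would simply get no tick):

```python
import numpy as np

batches = np.arange(1, 101)
batches_per_epoch = 10  # hypothetical name; 10 batches per epoch as in the example

# Original approach: pick every 10th element of batches by index
epoch_ends_indexed = batches[[(i * 10) - 1 for i in range(1, 11)]]

# Direct approach: the last batch of each epoch is a multiple of batches_per_epoch
epoch_ends = np.arange(batches_per_epoch, len(batches) + 1, batches_per_epoch)

print(epoch_ends)  # [ 10  20  30  40  50  60  70  80  90 100]
print(np.array_equal(epoch_ends, epoch_ends_indexed))  # True
```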

So I would like epoch tick 1 to align with batch x-coordinate 10, epoch tick 2 with 20, and so on.

But when the figure is drawn, the epoch ticks do not line up with the corresponding batch positions.


What do I need to change in my code to make this work?


Solution

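  • Here is one way to align them. The axis created by ax1.twiny() shares the y-axis with ax1 but has its own independent x-limits, so a tick at x = 10 on the top axis only sits directly above x = 10 on the bottom axis when both axes have identical x-limits. The idea is as follows: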

    • First, plot the data against the lower x-axis (ax1).
    • Then set the limits of the upper x-axis (ax3) to match those of the lower x-axis, using ax3.set_xlim(ax1.get_xlim()).
    • Then place the ticks of the upper x-axis at the lower x-axis values where each epoch ends (10, 20, 30, ..., 90, 100).
    • Finally, relabel those ticks with the epoch numbers using ax3.set_xticklabels().
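    To see why these steps are needed, the following minimal sketch (using the non-interactive Agg backend so it runs without a display; the data is a stripped-down version of the question's) prints the x-limits of both axes before and after copying them. Before the copy, ax1 has autoscaled to the plotted data plus a default margin, while ax3 holds no data and has only expanded its limits enough to show the ticks set on it. After the copy, it checks via transData that batch 10 on the bottom axis and epoch tick 1 on the top axis map to the same horizontal pixel position.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

batches = np.arange(1, 101)
accuracy = batches / len(batches)
epoch_ends = np.arange(10, 101, 10)

fig, ax1 = plt.subplots()
ax3 = ax1.twiny()  # top x-axis: shares y with ax1, but has independent x-limits

ax1.plot(batches, accuracy, color="red")  # ax1 autoscales to the data plus a margin
ax3.set_xticks(epoch_ends)                # ax3 only expands its limits to show these ticks
ax3.set_xticklabels(range(1, len(epoch_ends) + 1))

xlim_before = (ax1.get_xlim(), ax3.get_xlim())
print("before:", xlim_before)  # the two limits differ, so the ticks are misaligned

# The fix: give the top axis exactly the same x-limits as the bottom axis
ax3.set_xlim(ax1.get_xlim())
print("after: ", ax1.get_xlim(), ax3.get_xlim())

# Data x = 10 on ax1 and data x = 10 on ax3 now map to the same pixel column
x_bottom = ax1.transData.transform((10, 0))[0]
x_top = ax3.transData.transform((10, 0))[0]
print(np.isclose(x_bottom, x_top))
```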

    Here is the code. Parts that are identical to your code are replaced by a # comment.

    # Imports, data definition, and accuracy/loss computation go here (unchanged)
    
    # Create the figure and the ax1, ax2 = ax1.twinx() and ax3 = ax1.twiny() axes here (unchanged)
    
    ax1.set_xlabel('batches')
    ax1.set_xticks(np.arange(1, len(batches)+1, 9))
    ax1.set_ylabel('accuracy')
    ax1.grid()
    
    acc_plt = ax1.plot(batches, accuracy, color='red', label='acc')
    loss_plt = ax2.plot(batches, loss, color='blue', label='loss')
    
    ax2.set_ylabel('loss')
    ax2.set_yticklabels(np.linspace(3, 10, 9))
    
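    new_tick_locations = np.arange(1, 11)*10  # batch positions 10, 20, ..., 100 (the epoch ends)
    new_tick_labels = range(1, 11)            # epoch numbers 1, 2, ..., 10
    
    ax3.set_xlabel('epochs')
    ax3.set_xlim(ax1.get_xlim())  # same x-limits as the bottom axis, so equal data values line up
    ax3.set_xticks(new_tick_locations)
    ax3.set_xticklabels(new_tick_labels)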
    
    # Set legends and show the plot
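    Note that ax3.set_xlim(ax1.get_xlim()) copies the limits only once. If ax1's x-limits change afterwards (for example because more data is plotted on it later), the two axes drift apart again. One way to keep them in sync, sketched here as an optional extra rather than as part of the answer, is to connect to the 'xlim_changed' callback that Matplotlib axes emit whenever their x-limits are set. The helper name sync_top_xlim and the extra gray line are hypothetical, chosen only for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

fig, ax1 = plt.subplots()
ax3 = ax1.twiny()

def sync_top_xlim(changed_ax):
    """Copy the bottom axis' x-limits to the top axis (hypothetical helper name)."""
    ax3.set_xlim(changed_ax.get_xlim())

# Axes emit 'xlim_changed' whenever their x-limits are set, including by autoscaling
ax1.callbacks.connect("xlim_changed", sync_top_xlim)

batches = np.arange(1, 101)
ax1.plot(batches, batches / len(batches), color="red")
ax3.set_xticks(np.arange(10, 101, 10))
ax3.set_xticklabels(range(1, 11))

# Later, more batches are plotted on ax1, which widens its x-range
more_batches = np.arange(1, 151)
ax1.plot(more_batches, more_batches / len(more_batches), color="gray")

xlim_bottom = ax1.get_xlim()  # triggers ax1's autoscaling, which fires the callback
xlim_top = ax3.get_xlim()
print(xlim_bottom == xlim_top)  # the top axis followed the bottom one
```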
    

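    A different technique, not the one used in the answer above: Matplotlib 3.1 and newer provide Axes.secondary_xaxis, which draws a top axis whose coordinates are derived from the bottom axis through a pair of conversion functions, so it stays tied to the bottom axis without copying limits by hand. The sketch below assumes a fixed epoch length of 10 batches (so epoch = batch / 10); BATCHES_PER_EPOCH, batch_to_epoch and epoch_to_batch are hypothetical names. The ticks are set with the secondary axis' own set_ticks method, so they are kept when the secondary axis applies its scale at draw time.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend; use plt.show() in an interactive session
import matplotlib.pyplot as plt
import numpy as np

BATCHES_PER_EPOCH = 10  # assumption: fixed epoch length, as in the question

def batch_to_epoch(x):
    """Forward conversion: bottom-axis batch units -> top-axis epoch units."""
    return np.asarray(x) / BATCHES_PER_EPOCH

def epoch_to_batch(x):
    """Inverse conversion: top-axis epoch units -> bottom-axis batch units."""
    return np.asarray(x) * BATCHES_PER_EPOCH

batches = np.arange(1, 101)
accuracy = batches / len(batches)

fig, ax1 = plt.subplots()
ax1.plot(batches, accuracy, color="red", label="acc")
ax1.set_xlabel("batches")
ax1.set_ylabel("accuracy")

# Top axis in epoch units, derived from the bottom axis via the two functions
secax = ax1.secondary_xaxis("top", functions=(batch_to_epoch, epoch_to_batch))
secax.set_xlabel("epochs")
secax.set_ticks(np.arange(1, 11))  # epoch 1 sits above batch 10, epoch 2 above batch 20, ...

fig.canvas.draw()  # render once with Agg
```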