python matplotlib plot scatter multiple-axes

Matplotlib: add twin y axis without using its values in the plots


This is to clarify the question title. Say you have four lists of integers, with which you want to produce a scatter plot:

a=[3,7,2,8,12,17]
b=[9,4,11,7,6,3]
c=[9,3,17,13,10,5]
d=[5,1,1,14,5,8]

You also have a function, for simplicity f(x)=1/x, that applies to all lists, so that:

from __future__ import division
a1=[1/i for i in a]
b1=[1/i for i in b]
c1=[1/i for i in c]
d1=[1/i for i in d]
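
As a quick check of the range these values span, the smallest and largest results of f over all four lists can be computed directly. This is only a small sketch; the names `all_inverses`, `lowest` and `highest` are made up for illustration.

```python
from __future__ import division  # only matters on Python 2

a = [3, 7, 2, 8, 12, 17]
b = [9, 4, 11, 7, 6, 3]
c = [9, 3, 17, 13, 10, 5]
d = [5, 1, 1, 14, 5, 8]

# Apply f(x) = 1/x to every value in the four lists
all_inverses = [1 / i for lst in (a, b, c, d) for i in lst]

lowest = min(all_inverses)   # 1/17, from the largest input value
highest = max(all_inverses)  # 1/1, from the smallest input value
print('%.2f to %.2f' % (lowest, highest))  # prints "0.06 to 1.00"
```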

My question: how can I add a second y axis showing the values returned by the function, which range from about 0.06 to 1.0, without using any of the a1, b1, c1, d1 lists in the scatter plots?

What I am saying is: if you produce the following scatter plots in the traditional way, how can you then add the second y axis based on the values of a1, b1, c1, d1, without having any series using them in the plot itself?

import matplotlib.pyplot as plt
plt.scatter(a,b,c='red',label='reds')
plt.scatter(c,d,c='blue',label='blues')
plt.legend(loc='best')

This is the scatter plot without the second y axis: [image not shown]

And this is a made-up version of the same plot, including the second y axis discussed so far: [image not shown]

NB: This question is different from this, in that I am not trying to plot with different scales. I only want to add a second axis with the relevant values.


Solution

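  • To make sure the numbers on the new axis sit level with the values they are the inverses of, give the twin axis the same limits and tick positions as the original axis, then relabel its ticks with the inverse values: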

    import matplotlib.pyplot as plt

    a=[3,7,2,8,12,17]
    b=[9,4,11,7,6,3]
    c=[9,3,17,13,10,5]
    d=[5,1,1,14,5,8]

    fig = plt.figure()
    ax = fig.add_subplot(111)

    ax.scatter(a,b,c='red',label='reds')
    ax.scatter(c,d,c='blue',label='blues')
    ax.legend(loc='best')
    ax.set_ylabel('Y')
    # make a twin axis that shares x and draws its own y axis on the right
    axi = ax.twinx()
    # put the twin's ticks at the same positions as the original y ticks
    ticks = ax.get_yticks()
    axi.set_yticks(ticks)
    # match the limits last, since setting ticks can expand them
    axi.set_ylim(ax.get_ylim())
    # label each tick with the inverse of its position
    inverse_ticks = []
    label_format = '%.3f'
    for tick in ticks:
        if tick != 0:
            inverse_ticks.append(label_format % (1/tick,))
        else:
            # 1/0 is undefined, so leave a tick at 0 unlabeled
            inverse_ticks.append('')
    axi.set_yticklabels(inverse_ticks)
    axi.set_ylabel('1/Y')
    fig.tight_layout()
    plt.show()
    


    And you can also do it for the X axis:

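    # (run this before the figure is shown)
    # make a twin axis that shares y and draws its own x axis on top
    xaxi = ax.twiny()
    # put the twin's ticks at the same positions as the original x ticks
    xticks = ax.get_xticks()
    xaxi.set_xticks(xticks)
    # match the limits last, since setting ticks can expand them
    xaxi.set_xlim(ax.get_xlim())
    # label each tick with the inverse of its position
    inverse_ticks = []
    label_format = '%.3f'
    for tick in xticks:
        if tick != 0:
            inverse_ticks.append(label_format % (1/tick,))
        else:
            # 1/0 is undefined, so leave a tick at 0 unlabeled
            inverse_ticks.append('')
    xaxi.set_xticklabels(inverse_ticks)
    xaxi.set_xlabel('1/X')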
    

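
As an alternative to relabeling a twin axis by hand, newer Matplotlib versions offer a secondary axis that is tied to the primary one through a pair of conversion functions. The sketch below assumes Matplotlib 3.1 or later, where `Axes.secondary_yaxis` is available; the `one_over` helper is a name chosen here for illustration and is not part of the original answer. It is a different technique from the twin-axis approach above, not a drop-in equivalent.

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.ticker import FormatStrFormatter


def one_over(x):
    """Vectorized 1/x; zeros map to inf instead of causing a division warning."""
    x = np.array(x, dtype=float)
    near_zero = np.isclose(x, 0)
    x[near_zero] = np.inf
    x[~near_zero] = 1 / x[~near_zero]
    return x


a = [3, 7, 2, 8, 12, 17]
b = [9, 4, 11, 7, 6, 3]
c = [9, 3, 17, 13, 10, 5]
d = [5, 1, 1, 14, 5, 8]

fig, ax = plt.subplots()
ax.scatter(a, b, c='red', label='reds')
ax.scatter(c, d, c='blue', label='blues')
ax.legend(loc='best')
ax.set_ylabel('Y')

# 1/x is its own inverse, so the same function serves as forward and inverse
secax = ax.secondary_yaxis('right', functions=(one_over, one_over))
# Place the secondary ticks at the inverses of the positive primary ticks,
# so each label sits level with the primary tick it corresponds to
secax.set_ticks([1 / t for t in ax.get_yticks() if t > 0])
secax.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
secax.set_ylabel('1/Y')

fig.tight_layout()

if __name__ == "__main__":
    plt.show()
```

Because the secondary axis is derived from the primary one, its limits follow the primary axis when the view changes; the tick positions set here are fixed, though, so they would need recomputing if the primary ticks change.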