
Matplotlib imshow and secondary x and y axes


Let's say I have a picture taken with a sensor whose pixel size is 1 mm. I would like to show the image with imshow: the main axes should show pixels, while the secondary axes should show the position in mm.
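The pixel-to-mm mapping is just a pair of inverse functions. A minimal sketch, assuming a hypothetical helper `make_converters` (not part of Matplotlib) that takes the pixel size in mm; the pair it returns has the `(forward, inverse)` shape that Matplotlib's secondary-axis `functions=` argument accepts.

```python
import numpy as np


def make_converters(mm_per_pixel):
    """Return (px_to_mm, mm_to_px) for a sensor with the given pixel size.

    make_converters is a hypothetical helper used only for illustration.
    """
    def px_to_mm(values):
        # pixel coordinate -> physical position in mm
        return np.asarray(values) * mm_per_pixel

    def mm_to_px(values):
        # physical position in mm -> pixel coordinate
        return np.asarray(values) / mm_per_pixel

    return px_to_mm, mm_to_px


# 1 mm pixels, as stated in the question
px_to_mm, mm_to_px = make_converters(1.0)
print(px_to_mm([0, 10, 640]))
```

Passing the pixel size as a parameter keeps both functions consistent with each other, so the inverse can never drift from the forward transform.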

frassino.png is the following picture

[Image: frassino.png]

from matplotlib import pyplot as plt
import cv2
import numpy as np

a = cv2.imread('frassino.png')
fig,ax = plt.subplots(1)
ax.imshow(a,aspect='equal')
ax.set_xlabel('pixel')
ax.set_ylabel('pixel')
ax.figure.savefig('1.png')

1.png is the following picture, and everything is fine (I need the pixels to be square, so I pass the argument aspect='equal').

[Image: 1.png]

Now I add a secondary y axis:

v2 = ax.twinx()
v2.set_yticks(np.linspace(0,48,12))
v2.set_xlabel('mm')
ax.figure.savefig('2.png')

2.png is the following picture, and I have two problems: first, the image is cropped, so the upper part of the tree, as well as the grass in the foreground, is not visible; second, the mm label is truncated.

[Image: 2.png]

Now I add the secondary x axis:

h2 = ax.twiny()
h2.set_xticks(np.linspace(0,64,8))
h2.set_xlabel('mm')
ax.figure.savefig('3.png')

The following picture is 3.png: the mm label is now there, but the image is still cropped.

[Image: 3.png]
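One likely cause of the crop can be checked directly. A minimal sketch, using a synthetic 480x640 array in place of frassino.png and the non-interactive Agg backend (both assumptions for illustration): Matplotlib switches an Axes to adjustable='datalim' when it is twinned, so with aspect='equal' the aspect ratio is enforced by changing the data limits instead of the Axes box, which can cut off part of the image.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def adjustable_around_twin(make_twin):
    """Return ax.get_adjustable() before and after creating a twin Axes."""
    fig, ax = plt.subplots()
    # Synthetic 480x640 stand-in for frassino.png
    ax.imshow(np.zeros((480, 640)), aspect='equal')
    before = ax.get_adjustable()
    make_twin(ax)  # e.g. ax.twinx() or ax.twiny()
    after = ax.get_adjustable()
    plt.close(fig)
    return before, after


print(adjustable_around_twin(lambda ax: ax.twinx()))
print(adjustable_around_twin(lambda ax: ax.twiny()))
```

Both calls should report 'box' before twinning and 'datalim' afterwards, which is consistent with the "datalim conflicts" explanation in the solution.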

How can the crop be avoided?

How can the y mm label be fixed?


Solution

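  • Based on an example found in the Matplotlib documentation, the solution seems to be to use secondary axes (ax.secondary_xaxis / ax.secondary_yaxis) rather than twin axes (ax.twinx / ax.twiny). This avoids data-limit conflicts between the axes holding the image and the extra axes, which exist only to carry a different scale and ticks, not any artists. The mm label on the right also needs set_ylabel rather than the set_xlabel used in the question.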

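    from matplotlib import pyplot as plt
    import cv2
    import numpy as np
    
    # cv2.imread returns channels in BGR order; imshow expects RGB
    a = cv2.cvtColor(cv2.imread('frassino.png'), cv2.COLOR_BGR2RGB)
    fig, ax = plt.subplots(1)
    ax.imshow(a, aspect='equal')  # square pixels
    ax.set_xlabel('pixel')
    ax.set_ylabel('pixel')
    ax.figure.savefig('1.png')
    
    # Forward and inverse transforms between pixel and mm coordinates.
    # The factor of 10 pixels per mm is an assumption here; set it to
    # the sensor's actual pixel size.
    def px_to_mm(values):
        return values / 10
    
    def mm_to_px(values):
        return 10 * values
    
    # Secondary axes only re-label the parent's range, so they do not
    # alter its data limits the way twinx()/twiny() do.
    v2 = ax.secondary_yaxis('right', functions=(px_to_mm, mm_to_px))
    v2.set_yticks(np.linspace(0, 48, 12, endpoint=True))  # y range in mm
    v2.set_ylabel('mm')  # set_ylabel, not set_xlabel, for a y axis
    ax.figure.savefig('2.png')
    
    h2 = ax.secondary_xaxis('top', functions=(px_to_mm, mm_to_px))
    h2.set_xticks(np.linspace(0, 64, 8, endpoint=True))  # x range in mm
    h2.set_xlabel('mm')
    ax.figure.savefig('3.png')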
    

    [Image: resulting figure]
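To check that secondary axes leave the image's data limits alone, here is a minimal sketch. It uses a synthetic 480x640 array instead of frassino.png, the Agg backend, and the same 10-pixels-per-mm conversion as the code above (all assumptions for illustration); after a full draw, the parent Axes should still span the whole image and keep adjustable='box'.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend
import matplotlib.pyplot as plt
import numpy as np


def px_to_mm(values):
    return values / 10  # assumed 10 pixels per mm


def mm_to_px(values):
    return values * 10


def limits_with_secondary_axes():
    fig, ax = plt.subplots()
    # Synthetic stand-in image, 480 rows x 640 columns
    ax.imshow(np.zeros((480, 640)), aspect='equal')
    before = (ax.get_xlim(), ax.get_ylim())
    ax.secondary_yaxis('right', functions=(px_to_mm, mm_to_px)).set_ylabel('mm')
    ax.secondary_xaxis('top', functions=(px_to_mm, mm_to_px)).set_xlabel('mm')
    fig.canvas.draw()  # full draw, so aspect handling actually runs
    after = (ax.get_xlim(), ax.get_ylim())
    adjustable = ax.get_adjustable()
    plt.close(fig)
    return before, after, adjustable


before, after, adjustable = limits_with_secondary_axes()
print(before == after, adjustable)
```

Unlike the twinned case, nothing here switches the parent to 'datalim', so the full image stays visible.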