Search code examples
Tags: python, matplotlib, rectangles, imshow

pyplot.imshow for rectangles


I am currently using matplotlib.pyplot to visualize some 2D data:

from matplotlib import pyplot as plt
import numpy as np
# 3x3 data, one value per cell. np.matrix is no longer recommended by NumPy,
# so a regular 2-D ndarray is used instead.
A = np.array([[1, 2, 1], [3, 0, 3], [1, 2, 0]])
plt.imshow(A, interpolation="nearest")  # draws one square per array entry
plt.show()

Now I have moved the data from squares to rectangles, which means I have two additional arrays, for example:

# Cell edges along each axis; len(edges) == number of cells + 1 per axis.
grid_x = np.array([0.0, 1.0, 4.0, 5.0])  # points on the x-axis
grid_y = np.array([0.0, 2.5, 4.0, 5.0])  # points on the y-axis

Now I want a grid of rectangles:

  • upper-left corner: (grid_x[i], grid_y[j])
  • lower-right corner: (grid_x[i+1], grid_y[j+1])
  • data (color): A[i,j]

What is an easy way to plot the data on the new grid? imshow does not seem to be usable here. I looked at pcolormesh, but it's confusing with the grid given as 2D arrays: I would use two matrices like np.mgrid[0:5:0.5,0:5:0.5] for the regular grid and build something similar for the irregular one.

What is an easy way for visualization of the rectangles?


Solution

  • import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    import matplotlib.cm as cm
    from matplotlib.collections import PatchCollection
    import numpy as np
    
    # 4x3 data: one value per rectangle, rows along y and columns along x,
    # i.e. shape (len(grid_y0) - 1, len(grid_x0) - 1). A plain ndarray is used
    # because NumPy no longer recommends np.matrix.
    A = np.array([[1, 2, 1], [3, 0, 3], [1, 2, 0], [4, 1, 2]])

    # Cell edges along each axis; they need not be evenly spaced.
    grid_x0 = np.array([0.0, 1.0, 4.0, 6.7])
    grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4])

    # Lower-left corner of every cell, flattened row-major so that cell k
    # lines up with np.ravel(A)[k].
    grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0)
    grid_x2 = grid_x1[:-1, :-1].ravel()
    grid_y2 = grid_y1[:-1, :-1].ravel()
    # Width and height of every cell, in the same row-major order.
    cell_w, cell_h = np.meshgrid(np.diff(grid_x0), np.diff(grid_y0))
    widths = cell_w.ravel()
    heights = cell_h.ravel()

    fig = plt.figure()
    ax = fig.add_subplot(111)
    ptchs = [
        Rectangle((x0, y0), w, h)
        for x0, y0, w, h in zip(grid_x2, grid_y2, widths, heights)
    ]
    p = PatchCollection(ptchs, cmap=cm.viridis, alpha=0.4)
    # The data values are mapped to colors through the colormap.
    p.set_array(np.ravel(A))
    ax.add_collection(p)
    # Fit the view to the grid edges instead of hard-coded limits, so the plot
    # stays correct when the grid changes.
    ax.set_xlim(grid_x0[0], grid_x0[-1])
    ax.set_ylim(grid_y0[0], grid_y0[-1])
    # NOTE: ax.pcolormesh(grid_x0, grid_y0, A) draws the same rectangles
    # directly from the 1-D edge arrays.
    plt.show()
    

    [Image: the resulting plot]

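    Here is another way: rasterize the rectangles into an image with the help of an R-tree, then show it with imshow and a colorbar. Note that imshow labels the axes in pixel indices unless you map them back to data coordinates, e.g. with imshow's `extent` argument or by changing the x-ticks and y-ticks (there are many Stack Overflow Q&As about how to do that).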

    from rtree import index
    import matplotlib.pyplot as plt
    import numpy as np
    
    # Half-size of the query box placed around each sample point.
    eps = 1e-3

    # 4x3 data, one value per cell (rows along y, columns along x).
    A = np.array([[1, 2, 1], [3, 0, 3], [1, 2, 0], [4, 1, 2]])
    grid_x0 = np.array([0.0, 1.0, 4.0, 6.7])
    grid_y0 = np.array([0.0, 2.5, 4.0, 7.8, 12.4])

    # Lower-left (x0, y0) and upper-right (x1, y1) corners of every cell,
    # flattened row-major so that cell number m corresponds to A.flat[m].
    grid_x1, grid_y1 = np.meshgrid(grid_x0, grid_y0)
    grid_x2 = grid_x1[:-1, :-1].flat
    grid_y2 = grid_y1[:-1, :-1].flat
    grid_x3 = grid_x1[1:, 1:].flat
    grid_y3 = grid_y1[1:, 1:].flat

    fig = plt.figure()

    # Resolution of the raster image that approximates the rectangle grid.
    rows = 100
    cols = 200
    # Use the data's dtype: a hard-coded int8 would wrap values above 127
    # and truncate non-integer data.
    im = np.zeros((rows, cols), dtype=A.dtype)
    grid_j = np.linspace(grid_x0[0], grid_x0[-1], cols)  # sample x positions
    grid_i = np.linspace(grid_y0[0], grid_y0[-1], rows)  # sample y positions

    # Spatial index over the cell rectangles; the id of each entry is its
    # row-major cell number.
    idx = index.Index()
    for m, (x0, y0, x1, y1) in enumerate(zip(grid_x2, grid_y2, grid_x3, grid_y3)):
        idx.insert(m, (x0, y0, x1, y1))

    for r, yc in enumerate(grid_i):
        for c, xc in enumerate(grid_j):
            # On a cell border several cells intersect the box; whichever the
            # index reports first is used.
            ind = next(idx.intersection((xc - eps, yc - eps, xc + eps, yc + eps)))
            im[r, c] = A.flat[ind]

    # origin="lower" puts grid_y0[0] at the bottom (matching the rectangle
    # plot), and extent labels the axes in data coordinates instead of pixel
    # indices, so the ticks need no manual adjustment.
    plt.imshow(
        im,
        origin="lower",
        extent=(grid_x0[0], grid_x0[-1], grid_y0[0], grid_y0[-1]),
    )
    plt.colorbar()
    plt.show()
    

    [Image: the resulting plot]