
How do X, Y and Z in `plot_surface` work?


I am trying to figure out how to work with the input data for plot_surface (the 3D Axes method from matplotlib's mplot3d toolkit), as it is quite confusing. For a start, the y-axis does not represent height, as I am used to from geometry. There is some documentation for it, but it doesn't make sense to me.

The general format of ax.plot_surface() is below.

ax.plot_surface(X, Y, Z)

where X and Y are 2D arrays of x and y points and Z is a 2D array of heights.

I noticed that in examples the data is commonly built something like this:

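import numpy as np
import matplotlib.pyplot as plt

nx, ny = 11, 11  # number of grid points along x and y
tx = np.linspace(0, 1, nx)
ty = np.linspace(0, 1, ny)
X, Y = np.meshgrid(tx, ty)
Z = ((X**2-1)**2) # or some other generated test data

# plot_surface is a method of a 3D Axes, not a pyplot function
fig, ax = plt.subplots(figsize=(10,10), subplot_kw={"projection": "3d"})
ax.plot_surface(X, Y, Z)
plt.show()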

So, from what I could figure out, all the variables are 2D arrays where:

  • X is an array mapping x-coords against a row of y-coords.
  • Y is an array mapping y-coords against a column of x-coords.
  • Z is a 2D array containing all the height values, but I am not sure whether it is indexed x-first or y-first.

I'm not sure whether these assumptions are correct.

Does anyone know of a place where these things are described in a more understandable way, or could anyone explain/dumb this down for me a bit?


Solution

  • The 3 arrays X,Y,Z relate to each other and the resulting plot as follows:

    • X,Y,Z all need to have the same (2-dimensional) shape
    • At any index (i,j): X[i,j] specifies an x-coordinate, Y[i,j] specifies a y-coordinate, and Z[i,j] is the height of the surface over the xy-point (X[i,j], Y[i,j])
    • For X and Y arrays set up compatibly (e.g. with np.meshgrid) and a vectorized function f, f(X,Y) will produce a suitable Z-array
    • The plotter builds the surface by connecting points that neighbor each other within the array. Every "square" set of subarrays X[i:i+2, j:j+2], Y[i:i+2, j:j+2], Z[i:i+2, j:j+2] corresponds to a quadrilateral within the resulting mesh of the surface. The coordinates within X[i:i+2, j:j+2], Y[i:i+2, j:j+2] can be seen as the "shadow" of this quadrilateral on the xy-plane
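
A quick numerical check of these rules, assuming a small grid and an arbitrary example function f (neither comes from the question): the three arrays share a shape, entry (i, j) of Z is the height over (X[i, j], Y[i, j]), and an m x n grid yields (m - 1) * (n - 1) quadrilaterals.

```python
import numpy as np

def f(x, y):
    # Hypothetical vectorized example function, for illustration only
    return x**2 + 3*y

nx, ny = 4, 3
X, Y = np.meshgrid(np.linspace(0, 1, nx), np.linspace(0, 2, ny))
Z = f(X, Y)  # evaluated elementwise, so Z has the same shape as X and Y

# All three arrays share one 2-D shape
assert X.shape == Y.shape == Z.shape

# Entry (i, j) of Z is the height above the point (X[i, j], Y[i, j])
i, j = 2, 1
assert np.isclose(Z[i, j], f(X[i, j], Y[i, j]))

# Each 2x2 block of neighboring entries is one quadrilateral,
# so an m x n grid yields (m - 1) * (n - 1) faces
m, n = Z.shape
num_quads = (m - 1) * (n - 1)
print(X.shape, num_quads)  # (3, 4) 6
```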

    The meshgrid function is designed to exploit this in an intuitive fashion: with its default indexing="xy", moving "horizontally" within each of the arrays (along a row) corresponds to changing the x-coordinate while holding the y-coordinate constant. Conversely, moving "vertically" within the array (down a column) corresponds to changing the y-coordinate while holding the x-coordinate constant.
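
So, for the "x-first or y-first" question: with meshgrid's default indexing, the first index of X, Y and Z selects a row (a y-value) and the second selects a column (an x-value). A small check, with arbitrary grid sizes; it also shows that indexing="ij" simply produces the transposed grids:

```python
import numpy as np

tx = np.linspace(0, 1, 4)   # 4 x-values
ty = np.linspace(0, 1, 3)   # 3 y-values

# Default indexing="xy": shape is (len(ty), len(tx)), i.e. rows index y
X, Y = np.meshgrid(tx, ty)
print(X.shape)  # (3, 4)

# Along a row (axis=1): x changes, y is constant
row_x_changes = bool(np.all(np.diff(X, axis=1) > 0))
row_y_const = bool(np.all(np.diff(Y, axis=1) == 0))

# Down a column (axis=0): y changes, x is constant
col_y_changes = bool(np.all(np.diff(Y, axis=0) > 0))
col_x_const = bool(np.all(np.diff(X, axis=0) == 0))

# indexing="ij" swaps the roles: the result is the transpose of the "xy" grids
Xij, Yij = np.meshgrid(tx, ty, indexing="ij")
swapped = np.array_equal(Xij, X.T) and np.array_equal(Yij, Y.T)

print(row_x_changes, row_y_const, col_y_changes, col_x_const, swapped)
# True True True True True
```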

    While the correspondence between proximity in the array and proximity in space is important, this particular correspondence between movement within the array and movement in 3-D space is purely a convention, followed for the convenience and intuition of the user.

    One (possibly counterintuitive) result of all this is that it is possible to rearrange the X and Y arrays (and hence the resulting Z-array) without changing the resulting surface, as long as the 2x2 neighborhoods within the arrays are preserved. As a simple example, if you transpose both arrays

    X, Y = X.T, Y.T
    

    then the resulting figure will be the same (with Z transposed as well, or recomputed from the new X and Y), but horizontal movement in the array now corresponds to a change in the y-coordinate, and vice versa.
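
One way to see this concretely is to list each quadrilateral by its four (x, y) corners and compare the two orderings. The helper mesh_cells is hypothetical (not part of matplotlib); it only mimics the "neighboring 2x2 block" rule described above:

```python
import numpy as np

def mesh_cells(X, Y):
    """Collect every 2x2 block of neighboring entries as a set of (x, y) corners."""
    cells = set()
    m, n = X.shape
    for i in range(m - 1):
        for j in range(n - 1):
            xs = X[i:i + 2, j:j + 2].ravel()
            ys = Y[i:i + 2, j:j + 2].ravel()
            cells.add(frozenset(zip(xs.tolist(), ys.tolist())))
    return cells

X, Y = np.meshgrid(np.linspace(-1, 1, 5), np.linspace(-1, 1, 4))
original = mesh_cells(X, Y)
transposed = mesh_cells(X.T, Y.T)

print(len(original), original == transposed)  # 12 True
```

The cells come out in a different order, but as a set they are identical, which is why the plotted surface does not change.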

    Another result is that, as you suspected, it is possible to plot a surface over non-rectangular regions! Here's a quick demo.

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib import cm
    
    nx = 11
    ny = 11
    
    tx = np.linspace(-1, 1, nx)
    ty = np.linspace(-1, 1, ny)
    X, Y = np.meshgrid(tx, ty)
    
    #
    # Reduce X,Y to a diagonal strip (Z is evaluated from them below)
    #
    def diag_strip(M):
        return np.vstack([np.diag(M,k=-1),np.diag(M)[:-1],np.diag(M,k=1)])
    X = diag_strip(X)
    Y = diag_strip(Y)
    
    #
    # evaluate Z-array
    #
    Z = 1/(1 + X**2 + Y**2/2 + X*Y/2) 
    
    fig, ax = plt.subplots(1,2, figsize = (20,10), subplot_kw={"projection": "3d"})
    ax[0].plot_surface(X, Y, Z, cmap = cm.jet)
    ax[1].view_init(elev = 90, azim = 0)
    ax[1].plot_surface(X, Y, Z, cmap = cm.jet)
    

    If we comment out the third-to-last block (which reduces X,Y to coordinates over a diagonal strip), then we end up with the expected plot over [-1,1] x [-1,1].

    [Figure: surface over the full square [-1,1] x [-1,1]; left: default 3-D view, right: top-down view]

    The plot on the left is the standard surface plot that you would expect; the plot on the right is the same plot presented from a "bird's eye" view (the camera looks "down" the z-axis), with the x and y axes aligned with the horizontal and vertical directions of the image.

    On the other hand, running the code as is results in the following:

    [Figure: surface over the diagonal strip; left: default 3-D view, right: top-down view]

    As can be clearly seen in the figure on the right, the quadrilaterals of the mesh no longer have rectangular "shadows" on the xy-plane.

    The reduced X and Y arrays are as follows:

    [[-1.  -0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8]
     [-1.  -0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8]
     [-0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8  1. ]]
    
    [[-0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8  1. ]
     [-1.  -0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8]
     [-1.  -0.8 -0.6 -0.4 -0.2  0.   0.2  0.4  0.6  0.8]]
    

    The first row of each array corresponds to the parallel line "above" the true diagonal, the second corresponds to the diagonal of the rectangular grid, and the third corresponds to the parallel line "below" the true diagonal.
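
The "above / on / below" description can be checked directly: along each row of the reduced arrays, y - x is constant, equal to plus one grid spacing, zero, and minus one grid spacing respectively. A sketch reusing diag_strip from the demo:

```python
import numpy as np

nx = ny = 11
tx = np.linspace(-1, 1, nx)
ty = np.linspace(-1, 1, ny)
X, Y = np.meshgrid(tx, ty)

def diag_strip(M):
    # Same helper as in the demo above
    return np.vstack([np.diag(M, k=-1), np.diag(M)[:-1], np.diag(M, k=1)])

Xs, Ys = diag_strip(X), diag_strip(Y)
h = tx[1] - tx[0]  # grid spacing, 0.2 here

# y - x per row: +h (above y = x), 0 (on y = x), -h (below y = x)
offsets = Ys - Xs
print(np.allclose(offsets, [[h], [0.0], [-h]]))  # True
```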

    On the other hand, mixing up the X and Y values too much will disrupt the way the quadrilaterals of the mesh are constructed. For instance, changing the diag_strip lines to

    X = diag_strip(X).reshape([-1,3]).T
    Y = diag_strip(Y).reshape([-1,3]).T
    

    produces the following:

    [Figure: surface built from the reshaped X, Y arrays]

    It's marginally easier to see what's going on if we redraw this with a low-opacity surface plot together with a wireframe plot:

    [Figure: the same surface drawn translucent, with a wireframe overlay]
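
A sketch of that combined plot for the scrambled strip, assuming the reshaped arrays from the modified diag_strip lines and combining a low alpha with plot_wireframe; the figure size and view angle are copied from the earlier demo, and the colormap is passed by name:

```python
import numpy as np
import matplotlib.pyplot as plt

def diag_strip(M):
    return np.vstack([np.diag(M, k=-1), np.diag(M)[:-1], np.diag(M, k=1)])

tx = np.linspace(-1, 1, 11)
ty = np.linspace(-1, 1, 11)
X, Y = np.meshgrid(tx, ty)

# Scramble the strip exactly as in the modified diag_strip lines
X = diag_strip(X).reshape([-1, 3]).T
Y = diag_strip(Y).reshape([-1, 3]).T
Z = 1/(1 + X**2 + Y**2/2 + X*Y/2)

fig, ax = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={"projection": "3d"})
for a in ax:
    a.plot_surface(X, Y, Z, cmap="jet", alpha=0.1)  # low-opacity surface
    a.plot_wireframe(X, Y, Z)                       # mesh edges drawn on top
ax[1].view_init(elev=90, azim=0)                     # bird's-eye view
# call plt.show() to display the figure
```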

    Here's what the X and Y matrices look like now:

    [[-1.  -0.4  0.2  0.8 -0.6  0.   0.6 -0.6  0.   0.6]
     [-0.8 -0.2  0.4 -1.  -0.4  0.2  0.8 -0.4  0.2  0.8]
     [-0.6  0.   0.6 -0.8 -0.2  0.4 -0.8 -0.2  0.4  1. ]]
    
    [[-0.8 -0.2  0.4  1.  -0.6  0.   0.6 -0.8 -0.2  0.4]
     [-0.6  0.   0.6 -1.  -0.4  0.2  0.8 -0.6  0.   0.6]
     [-0.4  0.2  0.8 -0.8 -0.2  0.4 -1.  -0.4  0.2  0.8]]
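
To quantify how much the reshape scrambles neighborhoods, one can measure the xy-distance between horizontally adjacent entries before and after. The helper max_row_jump is hypothetical, written only for this check:

```python
import numpy as np

def diag_strip(M):
    return np.vstack([np.diag(M, k=-1), np.diag(M)[:-1], np.diag(M, k=1)])

X, Y = np.meshgrid(np.linspace(-1, 1, 11), np.linspace(-1, 1, 11))

def max_row_jump(X, Y):
    """Largest xy-distance between horizontally adjacent array entries."""
    return float(np.hypot(np.diff(X, axis=1), np.diff(Y, axis=1)).max())

strip = (diag_strip(X), diag_strip(Y))
scrambled = (diag_strip(X).reshape([-1, 3]).T, diag_strip(Y).reshape([-1, 3]).T)

print(round(max_row_jump(*strip), 3))      # 0.283
print(round(max_row_jump(*scrambled), 3))  # 2.126
```

In the strip, adjacent entries are one diagonal grid step apart; after the reshape, some "neighbors" in the array are on opposite sides of the region, so the quadrilaterals built from them cut across the plot.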
    

    Speaking of surfaces over non-rectangular regions, here are two ways to plot a surface over a circular region. One approach is to use a boolean mask to filter the points down to those whose (x,y) coordinates lie within the desired area. Note that setting an entry of the Z-array (or of the X-array or Y-array, for that matter) to np.nan will cause that point not to be plotted. Here is a quick demo:

    nx = 51
    ny = 51
    
    tx = np.linspace(-1, 1, nx)
    ty = np.linspace(-1, 1, ny)
    X, Y = np.meshgrid(tx, ty)
    
    Z = X**2 - Y**2
    
    fig, ax = plt.subplots(2,2, figsize = (20,20), subplot_kw={"projection": "3d"})
    ax[0,0].plot_surface(X, Y, Z, cmap = cm.jet)
    ax[0,1].view_init(elev = 90, azim = 0)
    ax[0,1].plot_surface(X, Y, Z, cmap = cm.jet)
    
    mask = (X**2 + Y**2) > 1
    Z[mask] = np.nan
    
    ax[1,0].plot_surface(X, Y, Z, cmap = cm.jet)
    ax[1,1].view_init(elev = 90, azim = 0)
    ax[1,1].plot_surface(X, Y, Z, cmap = cm.jet)
    

    The resulting plots:

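    [Figure: saddle surface Z = X**2 - Y**2; top row: full square grid, bottom row: masked to the unit disk; left: 3-D view, right: top-down view]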
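
Two details of the masking approach can be checked with numpy alone: the NaN entries are exactly the points outside the disk, and Z must be a floating-point array, since NaN cannot be stored in an integer array. A sketch using the same 51 x 51 grid:

```python
import numpy as np

X, Y = np.meshgrid(np.linspace(-1, 1, 51), np.linspace(-1, 1, 51))
Z = X**2 - Y**2            # float array, so it can hold NaN

mask = (X**2 + Y**2) > 1   # True outside the unit disk
Z[mask] = np.nan           # these entries will not be drawn

kept = ~np.isnan(Z)
print(int(kept.sum()), int(mask.sum()))            # points kept vs. dropped
print(bool(np.all(X[kept]**2 + Y[kept]**2 <= 1)))  # True

# NaN is a float value: an integer array cannot store it
Zi = np.zeros((2, 2), dtype=int)
try:
    Zi[0, 0] = np.nan
except ValueError as err:
    print("ValueError:", err)
```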

    This approach can easily be adapted to other kinds of regions. Unfortunately, it can lead to undesirable artifacts near the boundary, such as a jagged edge where the square grid cells meet the circle. An approach that I prefer for circular regions in particular is to set up the mesh using polar coordinates:

    nr = 11
    nth = 31
    
    r_vals = np.linspace(0, 1, nr)
    th_vals = np.linspace(0, 2*np.pi, nth)
    R, Th = np.meshgrid(r_vals, th_vals)
    X, Y = R*np.cos(Th), R*np.sin(Th)
    
    Z = X**2 - Y**2
    
    fig, ax = plt.subplots(1,2, figsize = (20,20), subplot_kw={"projection": "3d"})
    ax[0].plot_surface(X, Y, Z, cmap = cm.jet)
    ax[1].view_init(elev = 90, azim = 0)
    ax[1].plot_surface(X, Y, Z, cmap = cm.jet)
    

    The result:

    [Figure: saddle surface over a polar mesh on the unit disk; left: 3-D view, right: top-down view]

    Because of the way that the meshgrid function works, moving horizontally within the resulting X,Y,Z corresponds to a change in "r" (distance of the (x,y) coordinate from (0,0)) and moving vertically corresponds to a change in "theta" (counterclockwise angle from the positive x-axis).
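
The structure of the polar mesh can be checked numerically. This sketch reuses nr, nth, r_vals and th_vals from above and adds two observations: the r = 0 column collapses to the origin (so those cells are degenerate quadrilaterals, effectively triangles), and the theta = 0 and theta = 2*pi rows coincide, so the mesh wraps around with no gap.

```python
import numpy as np

nr, nth = 11, 31
r_vals = np.linspace(0, 1, nr)
th_vals = np.linspace(0, 2*np.pi, nth)
R, Th = np.meshgrid(r_vals, th_vals)   # shape (nth, nr): rows index theta
X, Y = R*np.cos(Th), R*np.sin(Th)

print(X.shape)  # (31, 11)

# Along a row (axis=1) only r changes; down a column (axis=0) only theta changes
assert np.all(np.diff(R, axis=1) > 0) and np.all(np.diff(R, axis=0) == 0)
assert np.all(np.diff(Th, axis=0) > 0) and np.all(np.diff(Th, axis=1) == 0)

# Column 0 (r = 0) collapses to the origin
print(np.allclose(X[:, 0], 0) and np.allclose(Y[:, 0], 0))  # True

# The first and last rows (theta = 0 and 2*pi) coincide, closing the disk
print(np.allclose(X[0], X[-1]) and np.allclose(Y[0], Y[-1]))  # True
```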

    Here's a script to plot the same thing with a translucent surface and wireframe:

    nr = 11
    nth = 31
    
    r_vals = np.linspace(0, 1, nr)
    th_vals = np.linspace(0, 2*np.pi, nth)
    R, Th = np.meshgrid(r_vals, th_vals)
    X,Y = R*np.cos(Th), R*np.sin(Th)
    
    Z = X**2 - Y**2
    
    fig, ax = plt.subplots(1,2, figsize = (20,20), subplot_kw={"projection": "3d"})
    ax[0].plot_surface(X, Y, Z, cmap = cm.jet, alpha = 0.1)
    ax[0].plot_wireframe(X, Y, Z)
    ax[1].view_init(elev = 90, azim = 0)
    ax[1].plot_surface(X, Y, Z, cmap = cm.jet, alpha = 0.1)
    ax[1].plot_wireframe(X, Y, Z)
    

    The result:

    [Figure: translucent surface with wireframe over the polar mesh; left: 3-D view, right: top-down view]