Search code examples
Tags: numpy, matplotlib, matrix, plot

Use matplotlib plot_surface to plot a matrix that contains NaN values


I would like to plot a matrix that contains a combination of float and NaN values. This is a 3D plot where X and Y are the matrix coordinates and Z is the value within the matrix.

NaN values should be ignored. It would be great if matplotlib would fill in the surface between the float values, but it's OK if it won't.

This is an adaptation of the code that I have tried thus far. It should plot the 3 data points that have been assigned manually, but instead, it produces an empty 3D plot.

# Minimal reproducer: plot_surface over a matrix that is mostly NaN.
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 -- registers the '3d' projection on older matplotlib versions
import matplotlib.pyplot as plt
import numpy

fig = plt.figure(figsize=(20,15))
ax = fig.add_subplot(111, projection='3d')

X0=0
Xmax=10
Y0=0
Ymax=10

# indexing='ij' makes Xfill vary along the first axis, matching the
# data_matrix[x, y] slicing below. The default 'xy' indexing would transpose
# the grid relative to the data, and its shape would not match when Xmax != Ymax.
Xfill,Yfill=numpy.meshgrid(range(X0,Xmax),range(Y0,Ymax),indexing='ij')

data_matrix=numpy.full(shape=[Xmax,Ymax],fill_value=numpy.nan)

data_matrix[5,5]=3
data_matrix[1,8]=6
data_matrix[7,2]=0.5

# NOTE: every surface face touching a NaN corner is left undrawn. Each of the
# three finite cells above is surrounded by NaNs, so no face survives and the
# 3D axes come out empty -- this is the behaviour the question asks about.
ax.plot_surface(Xfill,Yfill, data_matrix[X0:Xmax,Y0:Ymax],color='blue',rstride=1,cstride=1)


plt.show()

Solution

  • Below are two options I know of: one uses a scatter plot, and the other draws a surface through an arbitrary set of points.

    [Figure: the 3D plot produced by the code below]

    import matplotlib.pyplot as plt
    import numpy as np

    fig = plt.figure(figsize=(4, 4))
    ax = fig.add_subplot(111, projection='3d')

    X0 = 0
    Xmax = 10
    Y0 = 0
    Ymax = 10

    # indexing='ij' so Xfill/Yfill follow the data_matrix[x, y] convention
    # (first axis = x), the same convention used for x_valid/y_valid below.
    Xfill, Yfill = np.meshgrid(range(X0, Xmax), range(Y0, Ymax), indexing='ij')

    data_matrix = np.full(shape=[Xmax, Ymax], fill_value=np.nan)

    data_matrix[5, 5] = 3
    data_matrix[1, 8] = 6
    data_matrix[7, 2] = 0.5

    # Pull out the non-NaN datapoints. np.nonzero on the validity mask returns
    # the row (x) and column (y) indices of the finite cells in one pass.
    x_valid, y_valid = np.nonzero(~np.isnan(data_matrix))
    data_valid = data_matrix[x_valid, y_valid]

    # Scatter plot of individual points
    ax.scatter(x_valid, y_valid, data_valid, c='tab:red',
               s=60, label='scatter', depthshade=False)

    # Also works somewhat:
    # ax.scatter(Xfill, Yfill, data_matrix)

    # Overlay a surface plot that doesn't require a regular grid:
    # plot_trisurf triangulates the scattered (x, y) points, so it needs
    # at least three points that are not collinear.
    ax.plot_trisurf(x_valid, y_valid, data_valid,
                    cmap='jet', label='trisurf plot', alpha=0.7)
    

    Optional further formatting:

    # Some additional flourishes (uses ax, x_valid, y_valid and data_valid
    # from the code above).
    # roll=0 is the default and is omitted: the `roll` keyword only exists in
    # matplotlib >= 3.6, and passing it fails on older versions.
    ax.view_init(azim=20, elev=45)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('data')

    # Vertical lines from each point: one segment per point, from (x, y, 0)
    # up to (x, y, z).
    from mpl_toolkits.mplot3d.art3d import Line3DCollection
    lines = [((x, y, 0), (x, y, z))
             for x, y, z in zip(x_valid, y_valid, data_valid)]

    ax.add_collection(Line3DCollection(lines, linewidth=3,
                      color='tab:orange', label='vertical projection'))
    plt.gcf().legend()