Search code examples
pythonmatplotlibmatplotlib-3d

Plot a partially transparent plane in matplotlib


I want to plot a sequence of three colormaps in a 3D space, with a line crossing all the planes of the colormaps, as shown in the figure below.

https://i.sstatic.net/65yOib6B.png

To do that, I am using mpl.plot_surface to generate the planes and LinearSegmentedColormap to create a colormap that transitions from transparent to a specific color.

However, when I plot the figure, a gray grid appears on my plot. How can I remove it? Ideally, the blue shade would appear on a completely transparent plane, but a lighter color could also work.

Here is the code I used to generate the plot:

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

# Testing Data
sigma = 1.0
mu = np.linspace(0,2, 10)

x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)

Z = []
for m in mu:
    Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for i in [0, 5, -1]:
    cmap = LinearSegmentedColormap.from_list('custom_blue', [(1, 1, 1, 0), (0, 0, 1, 1)])
    wmap = cmap(Z[i]/Z[i].max())
    ax.plot_surface(mu[i] * np.ones(X.shape), X, Y,facecolors=wmap, alpha=1, antialiased=True, edgecolor='none')

loc_max_x = []
loc_max_y = []
for i in range(len(mu)):
    loc_x = np.where(Z[i] == Z[i].max())[0][0]
    loc_y = np.where(Z[i] == Z[i].max())[1][0]

    loc_max_x.append(loc_x)
    loc_max_y.append(loc_y)

ax.plot(mu, x[loc_max_x], y[loc_max_y], color='r')
ax.set_box_aspect((3.4, 1, 1))

plt.savefig('3dplot.png', dpi=300)
plt.show()

Solution

  • I think there's nothing you could have done better in matplotlib, great job!

    I think to solve your problem, it is better to change the library and approach your problem using plotly.

    Please see my code:

    import plotly.graph_objects as go
    import numpy as np
    
    
    # Testing Data
    sigma = 1.0
    mu = np.linspace(0, 2, 10)
    
    x = np.linspace(-5, 5, 100)
    y = np.linspace(-5, 5, 100)
    X, Y = np.meshgrid(x, y)
    
    Z = []
    for m in mu:
        Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))
    
    fig = go.Figure()
    
    colorscale = [[0, 'rgba(255, 255, 255, 0)'], [1, 'rgba(0, 0, 255, 1)']]  # colorscale = transparent to blue
    
    #plot the surfaces 
    for i in [0, 5, -1]:
        fig.add_trace(go.Surface(
            x=mu[i] * np.ones(X.shape), y=X, z=Y, surfacecolor=Z[i], 
            colorscale=colorscale, cmin=0, cmax=Z[i].max(),
            showscale=False, opacity=1))
    
    #plot the line crossing the surfaces
    loc_max_x = []
    loc_max_y = []
    for i in range(len(mu)):
        loc_x = np.where(Z[i] == Z[i].max())[0][0]
        loc_y = np.where(Z[i] == Z[i].max())[1][0]
        loc_max_x.append(loc_x)
        loc_max_y.append(loc_y)
    
    #add the line trace
    fig.add_trace(go.Scatter3d(
        x=mu, y=x[loc_max_x], z=y[loc_max_y], 
        mode='lines', line=dict(color='red', width=5)))
    
    fig.update_layout(scene_aspectmode='manual',
                      scene_aspectratio=dict(x=3.4, y=1, z=1),
                      scene=dict(xaxis_title='mu', yaxis_title='X', zaxis_title='Y'))
    
    fig.show()
    

    which results this plot: enter image description here