I want to plot a sequence of three colormaps in a 3D space, with a line crossing all the planes of the colormaps, as shown in the figure below.
https://i.sstatic.net/65yOib6B.png
To do that, I am using mpl.plot_surface to generate the planes and LinearSegmentedColormap to create a colormap that transitions from transparent to a specific color.
However, when I plot the figure, a gray grid appears on my plot. How can I remove it? Ideally, the blue shade would appear on a completely transparent plane, but a lighter color could also work.
Here is the code I used to generate the plot:
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LinearSegmentedColormap
# Testing Data
sigma = 1.0
mu = np.linspace(0,2, 10)
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = []
for m in mu:
Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for i in [0, 5, -1]:
cmap = LinearSegmentedColormap.from_list('custom_blue', [(1, 1, 1, 0), (0, 0, 1, 1)])
wmap = cmap(Z[i]/Z[i].max())
ax.plot_surface(mu[i] * np.ones(X.shape), X, Y,facecolors=wmap, alpha=1, antialiased=True, edgecolor='none')
loc_max_x = []
loc_max_y = []
for i in range(len(mu)):
loc_x = np.where(Z[i] == Z[i].max())[0][0]
loc_y = np.where(Z[i] == Z[i].max())[1][0]
loc_max_x.append(loc_x)
loc_max_y.append(loc_y)
ax.plot(mu, x[loc_max_x], y[loc_max_y], color='r')
ax.set_box_aspect((3.4, 1, 1))
plt.savefig('3dplot.png', dpi=300)
plt.show()
I think there's nothing you could have done better in matplotlib
, great job!
I think to solve your problem, it is better to change the library and approach your problem using plotly
.
Please see my code:
import plotly.graph_objects as go
import numpy as np
# Testing Data
sigma = 1.0
mu = np.linspace(0, 2, 10)
x = np.linspace(-5, 5, 100)
y = np.linspace(-5, 5, 100)
X, Y = np.meshgrid(x, y)
Z = []
for m in mu:
Z.append(np.exp(-((X - m)**2 + (Y - m)**2) / (2 * sigma**2)))
fig = go.Figure()
colorscale = [[0, 'rgba(255, 255, 255, 0)'], [1, 'rgba(0, 0, 255, 1)']] # colorscale = transparent to blue
#plot the surfaces
for i in [0, 5, -1]:
fig.add_trace(go.Surface(
x=mu[i] * np.ones(X.shape), y=X, z=Y, surfacecolor=Z[i],
colorscale=colorscale, cmin=0, cmax=Z[i].max(),
showscale=False, opacity=1))
#plot the line crossing the surfaces
loc_max_x = []
loc_max_y = []
for i in range(len(mu)):
loc_x = np.where(Z[i] == Z[i].max())[0][0]
loc_y = np.where(Z[i] == Z[i].max())[1][0]
loc_max_x.append(loc_x)
loc_max_y.append(loc_y)
#add the line trace
fig.add_trace(go.Scatter3d(
x=mu, y=x[loc_max_x], z=y[loc_max_y],
mode='lines', line=dict(color='red', width=5)))
fig.update_layout(scene_aspectmode='manual',
scene_aspectratio=dict(x=3.4, y=1, z=1),
scene=dict(xaxis_title='mu', yaxis_title='X', zaxis_title='Y'))
fig.show()