Is there a way to map the color scheme from one surface plot onto another?
For example, let's say I have:
surf_1 = ax.plot_surface(X, Y, Z, cmap='summer')
and
surf_2 = ax.plot_surface(X, Y, Z-Q, cmap='summer')
Is there a way to map the color scheme of the surface defined by Z-Q onto the surface defined by Z? In other words, I want to visualize surf_1, but with its surface taking on the colors defined by surf_2.
For context, I am trying to visualize, through color, the fluctuations of a parameter (Z) around a variable height (Q), where Q is not necessarily equal to 0.
EDIT: Is there a way I could extract the colors of surf_2 as an array and use them as the input colors for surf_1? Any suggestions would be much appreciated!
You can use a ScalarMappable to compute all the colors and pass them as facecolors to both surface plots. Here is runnable code that demonstrates the steps to achieve what you want.
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (only needed on Matplotlib < 3.2)
import numpy as np
fig, ax = plt.subplots(subplot_kw={'projection': '3d'})
fig.set_size_inches([10, 8])
# Make up data for 2 surfaces
X = np.logspace(0, np.log10(16), 50)
Y = np.linspace(3, 6, 50)
Z = np.linspace(-1, 1, 50)
# Convert to 2d arrays
Z = np.outer(Z.T, Z) # 50x50
X, Y = np.meshgrid(X, Y) # 50x50
# For clarity, the 2 surfaces are separated by some z shift
zshift = 80
Z_upper = Z*X + zshift
# Use a `ScalarMappable` to turn the values of Z_upper into RGBA colors
colormap = "summer"  # or 'inferno', 'plasma', 'viridis'
norm = plt.Normalize(vmin=Z_upper.min(), vmax=Z_upper.max())
scmap = plt.cm.ScalarMappable(norm=norm, cmap=colormap)
colors = scmap.to_rgba(Z_upper)  # RGBA array of shape (50, 50, 4)
# Upper surface
# Note: ax.plot_surface(X, Y, Z_upper, cmap=colormap) is almost equivalent;
# `shade=False` suppresses 3D shading so the colors are used as-is
ax.plot_surface(X, Y, Z_upper, facecolors=colors, shade=False)
# Lower surface: reuse the same `colors`,
# i.e. it takes its colors from the upper surface
ax.plot_surface(X, Y, Z, facecolors=colors, shade=False)
plt.show()
[Figure: the output plot]
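Applied to the question's own setup, the same idea does not need a second surface at all: compute the colors directly from Z - Q and pass them as facecolors when plotting Z. The sketch below assumes made-up X, Y, Z, Q arrays (hypothetical data standing in for the real ones) and a plain linear Normalize over the range of Z - Q; it also forces the non-interactive Agg backend so it runs headless.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove this line to show the figure interactively
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm, colors

# Hypothetical example data standing in for the question's X, Y, Z, Q
x = np.linspace(-2, 2, 40)
y = np.linspace(-2, 2, 40)
X, Y = np.meshgrid(x, y)
Q = 0.5 * X                                   # a sloped, non-zero reference height
Z = Q + 0.3 * np.sin(3 * X) * np.cos(3 * Y)   # Z fluctuates around Q

# Colors encode the fluctuation Z - Q, not Z itself
D = Z - Q
norm = colors.Normalize(vmin=D.min(), vmax=D.max())
cmap = plt.get_cmap("summer")
face_rgba = cmap(norm(D))                     # RGBA array, shape Z.shape + (4,)

fig, ax = plt.subplots(subplot_kw={"projection": "3d"})
# Draw the Z surface, but color it by Z - Q
surf_1 = ax.plot_surface(X, Y, Z, facecolors=face_rgba, shade=False)

# facecolors bypasses the colormap machinery, so the colorbar needs its own
# ScalarMappable built from the same norm and cmap
mappable = cm.ScalarMappable(norm=norm, cmap=cmap)
fig.colorbar(mappable, ax=ax, label="Z - Q")
```

Here face_rgba plays the role of the color array asked about in the EDIT. Since Z - Q is signed, a diverging colormap such as 'coolwarm' may make fluctuations above and below Q easier to tell apart.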