I am plotting a basic 3D scatterplot using code from another SO post (Matplotlib scatter plot legend (top answer)), but I want each point's opacity to depend on its 'depth', the way ax.scatter does with depthshade=True.
I had to use ax.plot because ax.scatter doesn't appear to work well with legends on 3D plots.
I'm wondering if there is a way to get a similar aesthetic with ax.plot.
Thanks!
It looks like you're out of luck here: plot does not seem to have the depthshade=True option. I think the problem is that plot does not let you set a different color (or alpha value) for each point the way scatter does, and, as far as I can tell, that per-point alpha is exactly how depthshade is applied.
One solution is to loop over all the points and set the colour one by one, using the mpl_toolkits.mplot3d.art3d.zalpha helper function to fade each colour according to its depth.
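import mpl_toolkits
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection on older Matplotlib
import matplotlib.pyplot as plt
import numpy as np

fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

n = 100
xs = np.random.rand(n)
ys = np.random.rand(n)
zs = np.random.rand(n)
color = [1, 0, 0, 1]  # RGBA: opaque red

# Unit vector from the centre of the plot towards the camera
azim = np.deg2rad(ax.azim)
elev = np.deg2rad(ax.elev)
normal = np.array([np.cos(azim)*np.cos(elev),
                   np.sin(azim)*np.cos(elev),
                   np.sin(elev)])

# Depth of each point along the viewing direction (smaller = closer to the camera)
ns = -np.dot(normal, [xs, ys, zs])
# zalpha lowers the alpha of the points that are further away
cs = mpl_toolkits.mplot3d.art3d.zalpha(color, ns)

# One plot call per point, so each marker gets its own colour/alpha
for i in range(xs.shape[0]):
    ax.plot([xs[i]], [ys[i]], [zs[i]], 'o', color=cs[i])

plt.show()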
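If art3d.zalpha is not available in your Matplotlib version, or you want control over how strong the fade is, the depth and alpha steps can be written with plain NumPy. This is a sketch of what I understand zalpha to do (min-max normalise the depths, then fade linearly down to 30% opacity); the helper names camera_direction and depth_to_alpha and the 0.3 floor are my own assumptions. The depth is measured in data coordinates, so it assumes the three axes have comparable ranges, as in the example above.

```python
import numpy as np


def camera_direction(elev_deg, azim_deg):
    """Unit vector from the centre of the axes towards the camera (hypothetical helper)."""
    elev, azim = np.deg2rad(elev_deg), np.deg2rad(azim_deg)
    return np.array([np.cos(elev) * np.cos(azim),
                     np.cos(elev) * np.sin(azim),
                     np.sin(elev)])


def depth_to_alpha(depth, base_alpha=1.0, min_alpha=0.3):
    """Fade linearly from base_alpha (nearest point) to base_alpha * min_alpha (farthest).

    Assumed to mirror art3d.zalpha; min_alpha=0.3 is an assumption, not a documented value.
    """
    depth = np.asarray(depth, dtype=float)
    span = depth.max() - depth.min()
    if span == 0:
        # All points at the same depth: nothing to fade
        return np.full(depth.shape, base_alpha)
    frac = (depth - depth.min()) / span  # 0 for the nearest point, 1 for the farthest
    return base_alpha * (1.0 - frac * (1.0 - min_alpha))


# Usage with Matplotlib's default 3D view (elev=30, azim=-60) and random points
rng = np.random.default_rng(0)
xyz = rng.random((3, 100))
depth = -camera_direction(30, -60) @ xyz  # smaller = closer to the camera
alphas = depth_to_alpha(depth)
```

Each point can then be drawn with color=(1, 0, 0, alphas[i]) in place of cs[i].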
One tricky point is that the camera angles ax.elev and ax.azim need to be used to work out the normal vector. Also, once you rotate the camera, the points will no longer be coloured correctly. To fix this, you could register an update event as follows, storing the lines returned by ax.plot (instead of using the plotting loop above) so their alpha can be changed later:
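# Plot each point as its own line (replacing the loop above) and keep the handles
points = []
for i in range(xs.shape[0]):
    points.append(ax.plot([xs[i]], [ys[i]], [zs[i]], 'o', color=cs[i]))

def Update(event):
    # Recompute the unit vector towards the camera for the current view
    azim = np.deg2rad(ax.azim)
    elev = np.deg2rad(ax.elev)
    normal = np.array([np.cos(azim)*np.cos(elev),
                       np.sin(azim)*np.cos(elev),
                       np.sin(elev)])
    ns = -np.dot(normal, [xs, ys, zs])
    cs = mpl_toolkits.mplot3d.art3d.zalpha(color, ns)
    # ax.plot returns a list, so p[0] is the line itself; only its alpha is changed.
    # Alphas set during a draw_event become visible on the next redraw.
    for i, p in enumerate(points):
        p[0].set_alpha(cs[i][3])

fig.canvas.mpl_connect('draw_event', Update)
plt.show()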
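Putting the pieces together, here is a self-contained sketch of the rotation-aware version. Two choices in it are mine rather than part of the approach above: it computes the fade with NumPy instead of art3d.zalpha, and it recomputes the alphas on motion_notify_event and then requests a redraw, because alphas set inside a draw_event handler only show up on the following draw. It assumes Axes3D's built-in rotation handler has already updated ax.elev / ax.azim by the time this callback runs; if it hasn't, the shading simply trails the view by one mouse movement. The helper names and the legend label are placeholders.

```python
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers '3d' on older Matplotlib)


def depth_alphas(xs, ys, zs, elev, azim, base_alpha=1.0, min_alpha=0.3):
    """Per-point alpha: base_alpha for the nearest point down to base_alpha*min_alpha for the farthest."""
    e, a = np.deg2rad(elev), np.deg2rad(azim)
    toward_camera = np.array([np.cos(e) * np.cos(a), np.cos(e) * np.sin(a), np.sin(e)])
    depth = -toward_camera @ np.vstack([xs, ys, zs])  # smaller = closer to the camera
    span = depth.max() - depth.min()
    if span == 0:
        return np.full(depth.shape, base_alpha)
    return base_alpha * (1.0 - (depth - depth.min()) / span * (1.0 - min_alpha))


rng = np.random.default_rng(0)
xs, ys, zs = rng.random((3, 100))

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

# One Line2D per point, so every marker can carry its own alpha.
points = [ax.plot([x], [y], [z], "o", color="red")[0] for x, y, z in zip(xs, ys, zs)]
# Placeholder label: one legend entry for the whole set. The legend is built before
# any alpha is set, so its marker copies the still-unset (opaque) alpha.
points[0].set_label("random points")
ax.legend()


def refresh_alphas():
    """Re-apply depth-based alphas for the current ax.elev / ax.azim."""
    for line, alpha in zip(points, depth_alphas(xs, ys, zs, ax.elev, ax.azim)):
        line.set_alpha(alpha)


def on_move(event):
    # Assumes Axes3D's own rotation handler has already updated the view angles.
    refresh_alphas()
    fig.canvas.draw_idle()


refresh_alphas()
fig.canvas.mpl_connect("motion_notify_event", on_move)
plt.show()
```

Redrawing on every mouse move is wasteful when the view isn't changing, but it keeps the callback simple; checking event.inaxes first would be one way to cut that down.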