Search code examples
pythonmatplotlibmplot3d

How can I connect two points in 3D scatter plot with arrow?


I am trying to connect two points in a 3D scatter plot with an arrow. I tried using quiver with the following code

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Reproduction script from the question: scatter three points in a 3D axes,
# label each with its coordinates, and try to join point 0 and point 2 with
# a quiver arrow.
fig = plt.figure(figsize=(20, 15))
ax = fig.add_subplot(111, projection='3d')

# Coordinates of the three points, stored per axis (index i is point i).
xs1 = [40,50,34]
ys1 = [30,30,30]
zs1 = [98,46,63]

ax.scatter(xs1, ys1, zs1, s=100, c='g', marker='o')

# Annotate every point with its "(x,y,z)" coordinates.
ax.text(xs1[0], ys1[0], zs1[0], '(%s,%s,%s)' % (str(xs1[0]), str(ys1[0]), str(zs1[0])), size=30, zorder=5, color='k')
ax.text(xs1[1], ys1[1], zs1[1], '(%s,%s,%s)' % (str(xs1[1]), str(ys1[1]), str(zs1[1])), size=30, zorder=5, color='k')
ax.text(xs1[2], ys1[2], zs1[2], '(%s,%s,%s)' % (str(xs1[2]), str(ys1[2]), str(zs1[2])), size=30, zorder=5, color='k')


# Quiver anchored at point 0; the direction components are computed as
# (point 0 - point 2) and the arrow size is the hand-tuned `length=30`.
# NOTE(review): how this renders (which end sits on point 0, whether the
# direction is normalized) depends on the `pivot`/`normalize` defaults of the
# installed matplotlib version - confirm against the Axes3D.quiver docs.
ax.quiver(xs1[0], ys1[0], zs1[0], (xs1[0]-xs1[2]), (ys1[0]-ys1[2]), (zs1[0]-zs1[2]), length=30)


ax.set_xlabel('X', fontsize=30, labelpad=20)
ax.set_ylabel('Y', fontsize=30, labelpad=20)
ax.set_zlabel('Z', fontsize=30, labelpad=20)

plt.show()
# An explicit canvas redraw and event flush are requested after plt.show().
fig.canvas.draw()
fig.canvas.flush_events()

Which looks something like this

enter image description here

The only way I was able to connect the two dots is by increasing the length of quiver, but if the length is more than what is required, it just passes through the other point.

Is there any better way to connect them without adjusting the length?


Solution

  • Construct your own class based on FancyArrowPatch to draw the arrow in 3D:

    import matplotlib.pyplot as plt
    from mpl_toolkits.mplot3d import Axes3D
    from mpl_toolkits.mplot3d import proj3d
    from matplotlib.patches import FancyArrowPatch
    
    class Arrow3D(FancyArrowPatch):
        """A 2D ``FancyArrowPatch`` whose end points are given in 3D data space.

        The two 3D end points are projected to 2D with the projection matrix
        of the owning 3D axes each time the figure is rendered, so the arrow
        follows the points when the view is rotated.

        Parameters
        ----------
        xs, ys, zs : sequence of two numbers
            Coordinates of the arrow's tail (index 0) and head (index 1).
        *args, **kwargs
            Forwarded unchanged to ``FancyArrowPatch`` (``arrowstyle``,
            ``color``, ``lw``, ``mutation_scale``, ...).
        """

        def __init__(self, xs, ys, zs, *args, **kwargs):
            # Placeholder 2D positions; the real ones are computed at draw time.
            super().__init__((0, 0), (0, 0), *args, **kwargs)
            self._verts3d = xs, ys, zs

        def _project(self):
            """Update the 2D end points from the 3D ones; return projected depths."""
            xs3d, ys3d, zs3d = self._verts3d
            # Use the axes' own projection matrix: ``renderer.M`` is not
            # available on newer matplotlib releases.
            xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M)
            self.set_positions((xs[0], ys[0]), (xs[1], ys[1]))
            return zs

        def do_3d_projection(self, renderer=None):
            """Project the arrow and return its depth for z-ordering.

            Newer ``Axes3D.draw`` implementations call this hook on every patch
            they own; without it the arrow raises ``AttributeError``.
            ``renderer`` is accepted (and ignored) for releases that still
            pass it.
            """
            return min(self._project())

        def draw(self, renderer):
            # Kept for older matplotlib releases that never call
            # ``do_3d_projection`` on artists added with ``add_artist``.
            self._project()
            super().draw(renderer)
    
    # Demo: scatter three points, label them, and join point 0 to point 2
    # with an Arrow3D.
    fig = plt.figure(figsize=(20, 15))
    ax = fig.add_subplot(111, projection='3d')

    # Per-axis coordinate lists; index i across the three lists is point i.
    xs1 = [40,50,34]
    ys1 = [30,30,30]
    zs1 = [98,46,63]

    ax.scatter(xs1, ys1, zs1, s=100, c='g', marker='o')

    # Label each point with its coordinates, in point order.
    for px, py, pz in zip(xs1, ys1, zs1):
        label = '(%s,%s,%s)' % (str(px), str(py), str(pz))
        ax.text(px, py, pz, label, size=30, zorder=5, color='k')

    # Arrow from point 0 (tail) to point 2 (head).
    tail, head = 0, 2
    arw = Arrow3D(
        [xs1[tail], xs1[head]],
        [ys1[tail], ys1[head]],
        [zs1[tail], zs1[head]],
        arrowstyle="->",
        color="purple",
        lw=3,
        mutation_scale=25,
    )
    ax.add_artist(arw)

    # Axis captions, applied in X, Y, Z order.
    for set_label, caption in (
        (ax.set_xlabel, 'X'),
        (ax.set_ylabel, 'Y'),
        (ax.set_zlabel, 'Z'),
    ):
        set_label(caption, fontsize=30, labelpad=20)

    plt.show()
    

    enter image description here

    The above code is based on this answer; read it for more details.