Search code examples
python matplotlib subplot matplotlib-3d

How to connect points between 3d subplots


I am trying to draw a line connecting a point on one 3D subplot to a point on another 3D subplot. In 2D this is easy to do using ConnectionPatch. I've tried to mimic the Arrow3D class from here, without luck.

At this point I'd be happy with even just a workaround. As an example, in the figure generated by the code below I want to connect the two green dots.

def cylinder(r, n):
    '''
    Build the surface points of a unit-height cylinder whose radius
    follows the profile r.

    INPUTS:  r - radii, one per ring (scalar, row vector or column vector)
             n - number of angular segments; each ring holds n+1 points
                 because the angle runs from 0 to 2*pi inclusive

    OUTPUTS: x,y,z - 2D arrays with one row per radius and n+1 columns;
                     z is evenly spaced from 0 to 1 across the rows
    '''

    # work with the radii as a column so they broadcast against the angles
    radii = np.atleast_2d(r)
    n_rows, n_cols = radii.shape
    if n_rows < n_cols:
        radii = radii.T

    # one circle per radius
    theta = np.linspace(0, 2*np.pi, n+1)
    x = radii * np.cos(theta)
    y = radii * np.sin(theta)

    # each ring sits at a constant height between 0 and 1
    heights = np.linspace(0, 1, len(radii)).reshape(-1, 1)
    z = np.repeat(heights, n+1, axis=1)

    return x, y, z


#---------------------------------------
# 3D example
#---------------------------------------
fig = plt.figure()

# Upper panel: cyan rings of a cylinder tapering from radius 2 to radius 1.
ax = fig.add_subplot(2, 1, 1, projection='3d')
x, y, z = cylinder(np.linspace(2, 1, num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax.plot(ring_x, ring_y, ring_z, 'c')
# Green marker on the first ring (radius 2, z = 0).
ax.plot([2], [0], [0], 'go')

# Lower panel: red rings of a cone widening from radius 0 to radius 1.
ax2 = fig.add_subplot(2, 1, 2, projection='3d')
x, y, z = cylinder(np.linspace(0, 1, num=10), 40)
for ring_x, ring_y, ring_z in zip(x, y, z):
    ax2.plot(ring_x, ring_y, ring_z, 'r')
# Green marker on the last ring (radius 1, z = 1).
ax2.plot([1], [0], [1], 'go')

plt.show()

Solution

  • I was trying to solve a very similar problem just tonight! Some of the code may be unnecessary but it will give you the main idea... ...I hope

    Inspiration from: http://hackmap.blogspot.com.au/2008/06/pylab-matplotlib-imagemap.html and many other varied sources over the last two hours...

    #! /usr/bin/env python
    #
    # Connect a point in one 3D axes to a point in another 3D axes with a
    # straight line drawn in figure coordinates.

    import numpy as np
    import matplotlib.pyplot as plt
    # Kept on purpose: Matplotlib releases before 3.2 need this import to
    # register projection='3d'.
    from mpl_toolkits.mplot3d import Axes3D
    from mpl_toolkits.mplot3d import proj3d
    import matplotlib
    # Import the submodule explicitly instead of relying on pyplot having
    # loaded it as a side effect.
    import matplotlib.lines


    def figure_coords(ax, px, py, pz):
        """Return the figure-fraction (x, y) at which ``ax`` shows a 3D point.

        ax         -- the 3D axes that contains the point
        px, py, pz -- data coordinates of the point in ``ax``

        The result is only valid for the view and layout ``ax`` has at the
        moment of the call.
        """
        # 3D data coordinates -> the axes' projected 2D data space
        u, v, _ = proj3d.proj_transform(px, py, pz, ax.get_proj())
        # projected 2D data space -> display pixels
        pixels = ax.transData.transform((u, v))
        # display pixels -> fraction of the figure (0..1 in both directions)
        return ax.figure.transFigure.inverted().transform(pixels)


    N = 50
    x = np.random.rand(N)
    y = np.random.rand(N)
    z = np.random.rand(N)

    # indices of the two points to join
    p1 = 10
    p2 = 20

    fig = plt.figure()

    # full-figure background axes; it is the backdrop once the figure and
    # 3D axes patches are hidden below (the line itself belongs to the figure)
    ax0 = plt.axes([0.,0.,1.,1.])
    ax0.set_xlim(0,1)
    ax0.set_ylim(0,1)

    # first scatter plot, with one highlighted point of interest
    ax1 = plt.axes([0.05,0.05,0.9,0.425], projection='3d')
    ax1.scatter(x, y, z)
    ax1.scatter(x[p1], y[p1], z[p1], s=100.)

    # second scatter plot (same data), with another highlighted point
    ax2 = plt.axes([0.05,0.475,0.9,0.425], projection='3d')
    ax2.scatter(x, y, z)
    ax2.scatter(x[p2], y[p2], z[p2], s=100.)

    # hide the figure and 3D axes backgrounds so they do not cover each other
    for item in [fig, ax1, ax2]:
        item.patch.set_visible(False)

    # Render once before reading any transform: Matplotlib may adjust the
    # position of a 3D axes when it is drawn, and coordinates computed before
    # that would not match what ends up on screen.
    fig.canvas.draw()

    x1, y1 = figure_coords(ax1, x[p1], y[p1], z[p1])
    x2, y2 = figure_coords(ax2, x[p2], y[p2], z[p2])

    # Register the line with the figure instead of overwriting fig.lines, so
    # any existing figure-level artists are preserved.
    # NOTE: the endpoints are computed once; the line does not follow the
    # points if an axes is rotated or the window is resized afterwards.
    line = matplotlib.lines.Line2D((x1, x2), (y1, y2),
                                   transform=fig.transFigure)
    fig.add_artist(line)

    plt.show()
    

    success