Search code examples
Tags: python, matplotlib, scatter-plot, mplot3d

Stacked 2D plots with interconnections in Matplotlib


I need to visualize some complex multivariate datasets, and the preferred approach is a variant of parallel-axis (parallel-coordinates) visualization using stacked 2D plots: each plot maps one degree of freedom/model parameter, and data points belonging to the same data set should be interconnected across the different plots. I am attaching a conceptual sketch. How could I implement this in matplotlib?

[Image: conceptual sketch of the stacked 2D plots with interconnections]


Solution

  • To give a rough idea, here is one possible solution in matplotlib using Axes3D:

    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle, PathPatch
    import mpl_toolkits.mplot3d.art3d as art3d
    
    
    x = np.array([1, 2, 3])
    y = np.array([2, 3, 1])
    z = np.array([1, 1, 1])

    # Height and colour of each stacked plane, bottom to top, plus the (x, y)
    # coordinates plotted on it (the middle plane shows the data transposed).
    levels = [0.4, 0.9, 1.6]
    colors = ["r", "b", "g"]
    plane_xy = [(x, y), (y, x), (x, y)]

    fig = plt.figure(figsize=(6, 6))
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')

    # plot the points; only `c` is passed, because also giving `facecolor`
    # just makes matplotlib warn that `c` takes precedence
    for (px, py), level, color in zip(plane_xy, levels, colors):
        ax.scatter(px, py, z * level, c=color, s=60)

    # plot connection lines: join data point i through all planes
    for i in (0, 2):
        ax.plot([px[i] for px, _ in plane_xy],
                [py[i] for _, py in plane_xy],
                levels, color="k")

    # plot planes: a translucent 4x4 square at each level
    for level, color in zip(levels, colors):
        p = Rectangle((0, 0), 4, 4, color=color, alpha=0.2)
        ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=level, zdir="z")

    ax.set_aspect('equal')
    ax.view_init(13, -63)
    ax.set_xlim3d([0, 4])
    ax.set_ylim3d([0, 4])
    ax.set_zlim3d([0, 2])

    # __file__ is undefined when run interactively (REPL/notebook); fall back
    # to a fixed file name there instead of failing with NameError
    plt.savefig(globals().get("__file__", "stacked_planes") + ".png")
    plt.show()
    

    [Image: output of the Axes3D script above]


    Update

    Creating three separate axes is possible: add the axes and make the upper ones transparent (ax2.patch.set_alpha(0.)). Then turn off the grid (ax.grid(False)) and hide the panes and lines that are not needed.
    However, I don't know how to draw a connection between the axes. The 2D approach, matplotlib.patches.ConnectionPatch, does not work out of the box for 3D axes.

    from mpl_toolkits.mplot3d import Axes3D
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.patches import Rectangle
    import mpl_toolkits.mplot3d.art3d as art3d
    
    
    x = np.array([1, 2, 3])
    y = np.array([2, 3, 1])
    z = np.array([0, 0, 0])

    fig = plt.figure(figsize=(6, 6))

    # Three full-width 3D axes, each shifted up by 0.24 figure heights. The
    # upper two get a transparent background so the lower ones show through.
    ax = fig.add_axes([0, 0, 1, 1], projection='3d')
    ax2 = fig.add_axes([0.0, 0.24, 1, 1], projection='3d')
    ax2.patch.set_alpha(0.)
    ax3 = fig.add_axes([0.0, 0.48, 1, 1], projection='3d')
    ax3.patch.set_alpha(0.)

    # Per-axes settings, bottom to top:
    # (axes, point x coords, point y coords, colour, plane/limit size, elevation)
    # Each axes plots the data at its own scale, hence the different sizes.
    layers = [
        (ax, x, y, "r", 4, 13),
        (ax2, y * 4, x * 4, "b", 16, 10),
        (ax3, x * 100, y * 100, "g", 400, 8),
    ]

    for layer_ax, px, py, color, size, elev in layers:
        # plot the points; only `c` is passed, because also giving
        # `facecolor` just makes matplotlib warn that `c` takes precedence
        layer_ax.scatter(px, py, z, c=color, s=60)

        # translucent plane covering the axes' x/y range at z=0
        p = Rectangle((0, 0), size, size, color=color, alpha=0.2)
        layer_ax.add_patch(p)
        art3d.pathpatch_2d_to_3d(p, z=0, zdir="z")

        layer_ax.set_aspect('equal')
        layer_ax.view_init(elev, -63)
        layer_ax.set_xlim3d([0, size])
        layer_ax.set_ylim3d([0, size])
        layer_ax.set_zlim3d([0, 2])
        layer_ax.grid(False)
    
    def axinvisible(ax):
        """Hide the background panes of a 3D axes and strip its z axis.

        The x, y and z panes are all made invisible. For the z axis, the axis
        line, tick lines and tick labels are hidden as well, so only the x/y
        axes of each stacked plot remain.

        Parameters
        ----------
        ax : 3D axes (created with ``projection='3d'``)
            The axes to modify in place.
        """
        # Axes3D.w_xaxis/w_yaxis/w_zaxis were deprecated in matplotlib 3.1 and
        # removed in 3.8; xaxis/yaxis/zaxis are the supported attributes.
        for axis in (ax.zaxis, ax.xaxis, ax.yaxis):
            axis.pane.set_visible(False)

        zaxis = ax.zaxis
        zaxis.line.set_visible(False)
        for artist in list(zaxis.get_ticklines()) + list(zaxis.get_ticklabels()):
            artist.set_visible(False)
    
    
    axinvisible(ax)
    axinvisible(ax2)
    axinvisible(ax3)

    # A ConnectionPatch with coords='data' on a 3D axes uses that axes' 2D
    # transData, which expects *projected* coordinates, not (x, y, z) data.
    # Raw 3D data coordinates therefore land far outside the plot. Project
    # each 3D point with the axes' own projection matrix first.
    from matplotlib.patches import ConnectionPatch
    from mpl_toolkits.mplot3d import proj3d

    # Draw once so limits/aspect are settled; the projection depends on them.
    fig.canvas.draw()

    def projected(axes3d, px, py, pz):
        """Return the projected 2D data coordinates of (px, py, pz) in axes3d."""
        x2, y2, _ = proj3d.proj_transform(px, py, pz, axes3d.get_proj())
        return x2, y2

    # Connect the same data point across the three stacked axes. Each axes
    # plots it at its own scale, matching the scatter calls made earlier.
    # NOTE: the lines are computed for the initial view only; rotating the
    # axes interactively does not update them.
    for i in range(len(x)):
        chain = [
            (ax, (x[i], y[i], z[i])),
            (ax2, (y[i] * 4, x[i] * 4, z[i])),
            (ax3, (x[i] * 100, y[i] * 100, z[i])),
        ]
        for (ax_a, pt_a), (ax_b, pt_b) in zip(chain, chain[1:]):
            con = ConnectionPatch(xyA=projected(ax_a, *pt_a),
                                  xyB=projected(ax_b, *pt_b),
                                  coordsA='data', coordsB='data',
                                  axesA=ax_a, axesB=ax_b,
                                  arrowstyle='-', color='k', clip_on=False)
            # Attach to the figure so the line is drawn above all axes and
            # is not clipped to a single axes' bounding box.
            fig.add_artist(con)

    plt.show()
    

    [Image: output of the three-axes script above]