Search code examples
python-3.x3dmayavimayavi.mlabweighted-graph

Mayavi : setting pipeline.tube radius


I'm plotting a 3D network using Mayavi,

# Excerpt: x, y, z (node coordinates) and graph are defined elsewhere.
edge_size = 0.2  # one radius shared by every tube
pts = mlab.points3d(x, y, z,
                    scale_mode='none',
                    scale_factor=0.1)

# Attach the graph's edges to the point dataset as line cells, then wrap
# those lines in tubes of a single, scalar radius.
pts.mlab_source.dataset.lines = np.array(graph.edges())
tube = mlab.pipeline.tube(pts, tube_radius=edge_size)

I want to change edge/tube radius. So I tried

tube = mlab.pipeline.tube(pts, tube_radius=listofedgeradius)

I get an error that says,

traits.trait_errors.TraitError: The 'tube_radius' trait of a TubeFactory instance must be a float

From the error, I understand a list cannot be assigned to tube_radius. In this case, I am not sure how to assign a different radius to each edge.

Any suggestions on how to assign edge weights/edge radius will be helpful.

EDIT: Complete working example

import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from mayavi import mlab

def main(edge_color=(0.8, 0.8, 0.8), edge_size=0.02):
    """Draw a small path graph in 2D with matplotlib, then in 3D with Mayavi.

    Args:
        edge_color: RGB tuple handed to the Mayavi surface that renders
            the edge tubes.
        edge_size: Single tube radius shared by every edge.
    """
    t = [1, 2, 3, 4, 5]
    h = [2, 3, 4, 5, 6]

    # nx.OrderedGraph was removed in NetworkX 3.0; fall back to the plain,
    # dict-backed nx.Graph when it is not available.
    graph_cls = getattr(nx, "OrderedGraph", nx.Graph)
    G = graph_cls()
    G.add_edges_from(zip(t, h))
    nx.draw(G)
    plt.show()

    graph_pos = nx.spring_layout(G, dim=3)

    # numpy array of x,y,z positions in sorted node order
    nodes = sorted(G)
    xyz = np.array([graph_pos[v] for v in nodes])

    # Line cells refer to points by their 0-based row in xyz, not by node
    # label, so translate every edge endpoint before handing it over.
    point_index = {node: row for row, node in enumerate(nodes)}
    edge_rows = np.array(
        [[point_index[u], point_index[v]] for u, v in G.edges()]
    )

    mlab.figure(1)
    mlab.clf()
    pts = mlab.points3d(xyz[:, 0], xyz[:, 1], xyz[:, 2])
    pts.mlab_source.dataset.lines = edge_rows
    tube = mlab.pipeline.tube(pts, tube_radius=edge_size)
    mlab.pipeline.surface(tube, color=edge_color)

    mlab.show()  # interactive window

main()

New edge weights to be added in the expected output:

   listofedgeradius = [1, 2, 3, 4, 5]
   tube = mlab.pipeline.tube(pts, tube_radius=listofedgeradius)

Solution

  • It seems to me that you can't plot multiple tubes with different diameters at once. So one solution is to plot them one after another:

    import networkx as nx
    import matplotlib.pyplot as plt
    import numpy as np
    from mayavi import mlab
    
    def main(edge_color=(0.8, 0.8, 0.8)):
        """Plot a 7-edge graph in 3D, one Mayavi tube per edge.

        Each edge gets its own two-point dataset and its own tube filter, so
        every tube can carry a different radius.

        Args:
            edge_color: RGB tuple handed to the surface of every tube.
        """
        t = [1, 2, 4, 4, 5, 3, 5]
        h = [2, 3, 6, 5, 6, 4, 1]
        listofedgeradius = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]) * 0.1

        # nx.OrderedGraph was removed in NetworkX 3.0; fall back to the plain,
        # dict-backed nx.Graph when it is not available.
        graph_cls = getattr(nx, "OrderedGraph", nx.Graph)
        G = graph_cls()

        # Store the radius on the edge itself. G.edges() does not iterate in
        # the order the edges were listed above, so pairing radii by position
        # would put them on the wrong edges.
        for tail, head, radius in zip(t, h, listofedgeradius):
            G.add_edge(tail, head, radius=radius)

        graph_pos = nx.spring_layout(G, dim=3)
        print(graph_pos)

        for i1, i2, radius in G.edges(data="radius"):
            # graph_pos is a dictionary keyed by node
            edge_xyz = np.vstack((graph_pos[i1], graph_pos[i2]))

            pts = mlab.points3d(edge_xyz[:, 0], edge_xyz[:, 1], edge_xyz[:, 2])

            # This dataset holds only the two endpoints, so the line always
            # joins point 0 to point 1.
            pts.mlab_source.dataset.lines = np.array([[0, 1]])

            tube = mlab.pipeline.tube(pts, tube_radius=float(radius))

            mlab.pipeline.surface(tube, color=edge_color)

        mlab.gcf().scene.parallel_projection = True

        mlab.show()  # interactive window
    
    main()
    

    Here is a larger example with 100 edges (image below) and one caveat of this solution becomes obvious: the for loop is slow.

    import networkx as nx
    import matplotlib.pyplot as plt
    import numpy as np
    from mayavi import mlab
    
    def main(edge_color=(0.8, 0.8, 0.8)):
        """Plot up to 100 random edges in 3D, one Mayavi tube per edge.

        Endpoints are drawn at random from 0..99, so duplicate pairs collapse
        and the graph may hold fewer than ``n`` edges. Every edge gets a
        random radius and its own tube filter, which is what makes this
        version slow.

        Args:
            edge_color: RGB tuple handed to the surface of every tube.
        """
        n = 100

        t = np.random.randint(100, size=n)
        h = np.random.randint(100, size=n)

        # nx.OrderedGraph was removed in NetworkX 3.0; fall back to the plain,
        # dict-backed nx.Graph when it is not available.
        graph_cls = getattr(nx, "OrderedGraph", nx.Graph)
        G = graph_cls()
        G.add_edges_from(zip(t, h))

        graph_pos = nx.spring_layout(G, dim=3)
        print(graph_pos)

        listofedgeradius = np.random.rand(n) * 0.01

        # zip stops at the shorter input, i.e. at the number of edges actually
        # present in the graph.
        for i, ((i1, i2), radius) in enumerate(zip(G.edges(), listofedgeradius)):

            print(i)  # progress indicator

            # graph_pos is a dictionary keyed by node
            edge_xyz = np.vstack((graph_pos[i1], graph_pos[i2]))

            pts = mlab.points3d(edge_xyz[:, 0], edge_xyz[:, 1], edge_xyz[:, 2])

            # This dataset holds only the two endpoints, so the line always
            # joins point 0 to point 1.
            pts.mlab_source.dataset.lines = np.array([[0, 1]])

            tube = mlab.pipeline.tube(pts, tube_radius=float(radius))

            mlab.pipeline.surface(tube, color=edge_color)

        mlab.gcf().scene.parallel_projection = True

        mlab.show()  # interactive window
    
    main()
    

    enter image description here

    Inspired by this, this and this I put together a first example that works well for large graphs (I tried up to 5000 edges). There is still a for loop, but it is not used for plotting, only for gathering the data in numpy arrays, so it's not that bad.

    import networkx as nx
    import matplotlib.pyplot as plt
    import numpy as np
    from mayavi import mlab
    
    def main(edge_color=(0.8, 0.8, 0.8)):
        """Plot up to 5000 random edges as variable-radius tubes in one call.

        All edge endpoints are gathered into flat coordinate arrays and drawn
        with a single ``mlab.plot3d`` call; the per-point scalar carries the
        radius of the edge the point belongs to.

        Args:
            edge_color: Not used by this function; kept so the signature
                matches the other examples.
        """
        n = 5000

        t = np.random.randint(100, size=n)
        h = np.random.randint(100, size=n)

        # nx.OrderedGraph was removed in NetworkX 3.0; fall back to the plain,
        # dict-backed nx.Graph when it is not available.
        graph_cls = getattr(nx, "OrderedGraph", nx.Graph)
        G = graph_cls()
        G.add_edges_from(zip(t, h))

        graph_pos = nx.spring_layout(G, dim=3)
        print(graph_pos)

        listofedgeradius = np.random.rand(n) * 0.01

        # One (start, end) coordinate pair per edge: shape (n_edges, 2, 3).
        # graph_pos is a dictionary keyed by node.
        endpoints = np.array(
            [(graph_pos[i1], graph_pos[i2]) for i1, i2 in G.edges()]
        ).reshape(-1, 2, 3)
        n_edges = len(endpoints)

        # Flatten to one entry per point; the two points of an edge stay
        # adjacent (rows 2k and 2k + 1).
        x = endpoints[:, :, 0].ravel()
        y = endpoints[:, :, 1].ravel()
        z = endpoints[:, :, 2].ravel()

        # Both endpoints of an edge share that edge's radius.
        s = np.repeat(listofedgeradius[:n_edges], 2)

        # Integer connectivity: edge k joins point 2k to point 2k + 1.
        connections = np.arange(2 * n_edges).reshape(-1, 2)

        src = mlab.plot3d(x, y, z, s)
        print(src)
        print(src.parent)
        print(src.parent.parent)

        # NOTE(review): presumably this makes the tube filter take each
        # point's scalar as the tube radius - confirm against the VTK
        # vtkTubeFilter documentation.
        src.parent.parent.filter.vary_radius = 'vary_radius_by_absolute_scalar'

        # Replace the default connectivity of plot3d with one segment per edge.
        src.mlab_source.dataset.lines = connections

        # The stripper filter cleans up connected lines. Nothing is attached
        # to its output in this example.
        mlab.pipeline.stripper(src)

        mlab.gcf().scene.parallel_projection = True

        # And choose a nice view
        mlab.view(33.6, 106, 5.5, [0, 0, .05])
        mlab.roll(125)
        mlab.show()
    
    main()
    

    enter image description here