Tags: python, animation, matplotlib, scatter3d

Python: Animated 3D Scatterplot gets slow


My program plots the positions of the particles in my file for every time step. Unfortunately, it gets slower and slower, even though I use matplotlib.animation. Where is the bottleneck?

My data file for two particles looks like the following:

#     x   y   z
# t1  1   2   4
#     4   1   3
# t2  4   0   4
#     3   2   9
# t3  ...
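Assuming the real file contains only the three whitespace-separated columns x, y, z on each line (the `#` and `t1` markers above are just annotations), with numP consecutive lines per time step, the whole file can be loaded into one array of shape (timesteps, numP, 3). A minimal NumPy sketch; the helper name `load_positions` and the inline sample are hypothetical:

```python
import io

import numpy as np


def load_positions(source, num_p, dim=3):
    """Load whitespace-separated x y z rows and group them by time step.

    Hypothetical helper: assumes exactly num_p consecutive rows per time step
    and no label column. Returns an array of shape (timesteps, num_p, dim).
    """
    flat = np.loadtxt(source, ndmin=2)  # shape (timesteps * num_p, dim)
    if flat.shape[0] % num_p != 0:
        raise ValueError("row count is not a multiple of num_p")
    return flat.reshape(-1, num_p, dim)


# Hypothetical sample: two particles, two time steps (values from the listing above)
sample = io.StringIO("1 2 4\n4 1 3\n4 0 4\n3 2 9\n")
positions = load_positions(sample, num_p=2)
print(positions.shape)  # (2, 2, 3)
print(positions[1])     # both particles at t2
```

With this layout, `positions[i]` is exactly the (numP, 3) block that `animate(i)` slices out with `idx0:idx1`.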

My script:

import numpy as np                          
import matplotlib.pyplot as plt            
from mpl_toolkits.mplot3d import Axes3D
import mpl_toolkits.mplot3d.axes3d as p3
import matplotlib.animation as animation

# Number of particles
numP = 2
# Dimensions
DIM = 3
timesteps = 2000

with open('//home//data.dat', 'r') as fp:
    particleData = []
    for line in fp:
        line = line.split()
        particleData.append(line)

x = [float(item[0]) for item in particleData]
y = [float(item[1]) for item in particleData]
z = [float(item[2]) for item in particleData]      

# Attaching 3D axis to the figure
fig = plt.figure()
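ax = fig.add_subplot(projection='3d')  # p3.Axes3D(fig) is deprecated; newer Matplotlib no longer attaches it to the figure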

# Setting the axes properties
border = 1
ax.set_xlim3d([-border, border])
ax.set_ylim3d([-border, border])
ax.set_zlim3d([-border, border])


def animate(i):
    global x, y, z, numP
    #ax.clear()
    ax.set_xlim3d([-border, border])
    ax.set_ylim3d([-border, border])
    ax.set_zlim3d([-border, border])
    idx0 = i*numP
    idx1 = numP*(i+1)
    ax.scatter(x[idx0:idx1],y[idx0:idx1],z[idx0:idx1])

ani = animation.FuncAnimation(fig, animate, frames=timesteps, interval=1, blit=False, repeat=False)
plt.show()
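One likely cause worth checking: with `ax.clear()` commented out, every call to `ax.scatter` in `animate` adds a new collection to the axes, so frame i has to redraw all i earlier scatters as well. The sketch below, which assumes a recent Matplotlib and uses the non-interactive Agg backend so it runs without a display, counts the collections in that pattern and compares it with creating one scatter up front and only moving its points. Note that `_offsets3d` is a private attribute of Matplotlib's 3D scatter collection, so that part is an assumption that may not hold in every version; the random `frames` data is hypothetical.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers "3d" on older Matplotlib

rng = np.random.default_rng(0)
num_p, steps = 2, 5
frames = rng.uniform(-1, 1, size=(steps, num_p, 3))  # hypothetical (steps, num_p, 3) data

# Pattern from the question: a new scatter every frame, nothing removed
fig1 = plt.figure()
ax1 = fig1.add_subplot(projection="3d")
growing = []
for pts in frames:
    ax1.scatter(pts[:, 0], pts[:, 1], pts[:, 2])
    growing.append(len(ax1.collections))

# Alternative: one scatter created once, its 3D offsets replaced each frame
fig2 = plt.figure()
ax2 = fig2.add_subplot(projection="3d")
scat = ax2.scatter(frames[0, :, 0], frames[0, :, 1], frames[0, :, 2])
constant = []
for pts in frames:
    scat._offsets3d = (pts[:, 0], pts[:, 1], pts[:, 2])  # private attribute (assumption)
    fig2.canvas.draw()  # redraw with the moved points
    constant.append(len(ax2.collections))

print(growing)   # one more collection per frame
print(constant)  # stays at a single collection
```

In `animate(i)` this would mean creating `scat` once before `FuncAnimation` and assigning `scat._offsets3d = (x[idx0:idx1], y[idx0:idx1], z[idx0:idx1])` instead of calling `ax.scatter`.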

Solution

  • I would suggest using pyqtgraph in this case. Quoting the docs:

    Its primary goals are 1) to provide fast, interactive graphics for displaying data (plots, video, etc.) and 2) to provide tools to aid in rapid application development (for example, property trees such as used in Qt Designer).

    After installing it, you can browse the bundled examples:

    import pyqtgraph.examples
    pyqtgraph.examples.run()
    

    This small snippet generates 1000 random points, displays them in a 3D scatter plot, and animates them by continuously updating their opacity, similar to the 3D scatter plot example in pyqtgraph.examples:

    from pyqtgraph.Qt import QtCore
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl
    import numpy as np
    
    app = pg.mkQApp()  # create (or reuse) the QApplication
    w = gl.GLViewWidget()
    w.show()
    g = gl.GLGridItem()
    w.addItem(g)
    
    # generate random integer points in [-10, 10), z-axis non-negative
    pos = np.random.randint(-10, 10, size=(1000, 3))
    pos[:, 2] = np.abs(pos[:, 2])
    
    sp2 = gl.GLScatterPlotItem(pos=pos)
    w.addItem(sp2)
    
    # RGBA colors; alpha ramps from 0 to 0.99 over the first 100 points
    color = np.zeros((pos.shape[0], 4), dtype=np.float32)
    color[:, 0] = 1
    color[:, 1] = 0
    color[:, 2] = 0.5
    color[0:100, 3] = np.arange(0, 100) / 100.
    
    def update():
        ## shift the opacity gradient by one point and push it to the item
        global color
        color = np.roll(color, 1, axis=0)
        sp2.setData(color=color)
    
    t = QtCore.QTimer()
    t.timeout.connect(update)
    t.start(50)  # update every 50 ms
    
    
    ## Start Qt event loop unless running in interactive mode.
    if __name__ == '__main__':
        import sys
        if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'):
            pg.exec()
    

    [GIF: the example above running, to give an idea of the performance]
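The animation in the snippet above comes entirely from `np.roll`: every timer tick shifts all rows of the color array down by one (wrapping the last row to the front), so the 100-row alpha ramp slides along the list of points. A NumPy-only sketch of a single tick, using the same array layout:

```python
import numpy as np

n_points = 1000
color = np.zeros((n_points, 4), dtype=np.float32)
color[:, 0] = 1
color[:, 2] = 0.5
color[0:100, 3] = np.arange(0, 100) / 100.  # alpha 0.00 .. 0.99 on the first 100 points

rolled = np.roll(color, 1, axis=0)  # what update() does on every timer tick

# Row i of `rolled` is row i-1 of `color` (row 0 wraps around from the end),
# so the most opaque point moves from index 99 to index 100.
print(int(np.argmax(color[:, 3])), int(np.argmax(rolled[:, 3])))  # 99 100
```

All 1000 points stay in the item the whole time; only their alpha values move, which is why the whole array is sent to `setData(color=...)` on every tick.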

    EDIT:

    Displaying multiple points at every single time step is a bit tricky, since gl.GLScatterPlotItem only accepts (N,3) arrays as point locations. You could instead build a dictionary of GLScatterPlotItems, each holding all time steps of one specific point, and adapt the update function accordingly. In the example below, pos is a (100,10,3) array holding 100 time steps for each of 10 points. I increased the timer interval to 1000 ms for a slower animation.

    from pyqtgraph.Qt import QtCore
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl
    import numpy as np
    
    app = pg.mkQApp()  # create (or reuse) the QApplication
    w = gl.GLViewWidget()
    w.show()
    g = gl.GLGridItem()
    w.addItem(g)
    
    # 100 time steps for each of 10 points, z-axis non-negative
    pos = np.random.randint(-10, 10, size=(100, 10, 3))
    pos[:, :, 2] = np.abs(pos[:, :, 2])
    
    # one scatter item per point, each holding all 100 time steps of that point
    ScatterPlotItems = {}
    for point in np.arange(10):
        ScatterPlotItems[point] = gl.GLScatterPlotItem(pos=pos[:, point, :])
        w.addItem(ScatterPlotItems[point])
    
    # RGBA per (time step, point); a 5-step alpha trail (0.2 .. 1.0) on the first time steps
    color = np.zeros((pos.shape[0], 10, 4), dtype=np.float32)
    color[:, :, 0] = 1
    color[:, :, 1] = 0
    color[:, :, 2] = 0.5
    color[0:5, :, 3] = np.tile(np.arange(1, 6) / 5., (10, 1)).T
    
    def update():
        ## push the current colors, then move the trail one time step forward
        global color
        for point in np.arange(10):
            ScatterPlotItems[point].setData(color=color[:, point, :])
        color = np.roll(color, 1, axis=0)
    
    t = QtCore.QTimer()
    t.timeout.connect(update)
    t.start(1000)  # one time step per second
    
    
    ## Start Qt event loop unless running in interactive mode.
    if __name__ == '__main__':
        import sys
        if (sys.flags.interactive != 1) or not hasattr(QtCore, 'PYQT_VERSION'):
            pg.exec()
    

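In the second example, `pos[:, point, :]` is the whole trajectory of one point (all 100 time steps), which is why each point gets its own GLScatterPlotItem, while `pos[t]` would be the positions of all points at a single time step. Both are (N, 3) arrays. The initial alpha values form a 5-step trail (0.2 to 1.0) on the first five time steps of every point, and rolling along axis 0 moves that trail forward by one time step per tick. A NumPy-only sketch checking these shapes:

```python
import numpy as np

n_steps, n_points = 100, 10
pos = np.random.randint(-10, 10, size=(n_steps, n_points, 3))

trajectory = pos[:, 3, :]  # all time steps of point 3 -> shape (100, 3)
snapshot = pos[42]         # all points at time step 42 -> shape (10, 3)

color = np.zeros((n_steps, n_points, 4), dtype=np.float32)
trail = np.tile(np.arange(1, 6) / 5., (n_points, 1)).T  # shape (5, 10); row k holds (k+1)/5
color[0:5, :, 3] = trail

rolled = np.roll(color, 1, axis=0)  # one timer tick

print(trajectory.shape, snapshot.shape, trail.shape)  # (100, 3) (10, 3) (5, 10)
# The fully opaque head of the trail moves from time step 4 to time step 5
print(int(np.argmax(color[:, 0, 3])), int(np.argmax(rolled[:, 0, 3])))  # 4 5
```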
    Keep in mind that in these examples all points are shown in the scatter plot at all times; only the color opacity (the 4th column of the color array) is updated at every time step to produce the animation. You could also try updating the point positions instead of the colors, which may give better performance.
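Following that suggestion, here is a sketch that keeps a single GLScatterPlotItem and replaces its `pos` on every tick with the (numP, 3) positions of the next time step, so each frame only draws numP points. It assumes the positions are already in an array of shape (timesteps, numP, 3) and a recent pyqtgraph that provides `pg.mkQApp()` and `pg.exec()`; the names `FrameCycler` and `run_viewer` are hypothetical. The frame bookkeeping is kept apart from the Qt code so it can be checked without opening a window.

```python
import numpy as np


class FrameCycler:
    """Hypothetical helper: returns consecutive (num_p, 3) frames, looping at the end."""

    def __init__(self, frames):
        self.frames = np.asarray(frames, dtype=np.float32)  # (timesteps, num_p, 3)
        self.index = 0

    def next_frame(self):
        pos = self.frames[self.index]
        self.index = (self.index + 1) % len(self.frames)
        return pos


def run_viewer(frames, interval_ms=50):
    """Hypothetical viewer: one scatter item whose positions change every tick."""
    # Qt/OpenGL imports live here so FrameCycler can be used without a GUI
    import pyqtgraph as pg
    import pyqtgraph.opengl as gl
    from pyqtgraph.Qt import QtCore

    pg.mkQApp()  # create (or reuse) the QApplication
    view = gl.GLViewWidget()
    view.show()
    view.addItem(gl.GLGridItem())

    cycler = FrameCycler(frames)
    scatter = gl.GLScatterPlotItem(pos=cycler.next_frame())
    view.addItem(scatter)

    def update():
        # Only the positions change; the same item is reused every frame
        scatter.setData(pos=cycler.next_frame())

    timer = QtCore.QTimer()
    timer.timeout.connect(update)
    timer.start(interval_ms)
    pg.exec()  # blocks until the window is closed


# Usage (opens a window): run_viewer(np.random.uniform(-10, 10, size=(2000, 2, 3)))
```

Compared with the color-based examples, the item only ever holds the current time step, so the amount of data sent per tick stays at numP points no matter how many time steps the file contains.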