
How to animate a scatter plot with variable number of points?


I'm trying to animate a scatter plot but with a variable number of points at each iteration of my animation.

Animating a scatter plot has been addressed before (e.g., here and here). However, those answers always assume a fixed number of points. For example, if Axes3D is used, setting axes3d.scatter._offsets3d won't work if the number of points differs between iterations of FuncAnimation.
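For comparison, in the 2D case the fixed-size assumption is not built into matplotlib itself: `PathCollection.set_offsets` accepts an (N, 2) array for any N. The sketch below assumes a uniform marker color and size (per-point `c` or `s` arrays would not be resized automatically) and uses the headless Agg backend only so it runs without a display; `update_offsets` is a hypothetical helper name.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
fig, ax = plt.subplots()
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

# Create the scatter once, with 5 points and a single uniform color/size
scat = ax.scatter(*rng.random((5, 2)).T)


def update_offsets(frame):
    """Hypothetical update function: swap in an (N, 2) array of any length N."""
    scat.set_offsets(frame)
    return (scat,)


shapes = []
for n in (5, 20, 3):
    update_offsets(rng.random((n, 2)))
    fig.canvas.draw()  # render the frame to confirm the new point count draws
    shapes.append(scat.get_offsets().shape)

print(shapes)  # [(5, 2), (20, 2), (3, 2)]
```

Because this variant never clears the Axes, it should also be usable with `blit=True`, as long as the axis limits stay fixed.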

How can I animate a scatter plot when each animation iteration contains a different number of points?


Solution

  • If your goal is to create an animated plot that contains a variable number of points on each frame of the animation, one option is to clear the Axes and draw a brand-new scatter plot on every frame, as the following code does:

    # == Import Required Libraries =================================================
    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.animation import FuncAnimation
    from IPython.display import HTML
    
    
    # ATTENTION: Uncomment this line if you're running in a Jupyter Notebook
    # %matplotlib notebook
    
    
    def next_frame(total_frame_count: int = 10):
        """
        Generate random 2D frames of varying dimensions.
    
        This function serves as a generator that yields 2D numpy arrays with random
        values in [0, 1). Each array has shape (n, 2), where the number of rows n
        is drawn at random from 5 to 49, inclusive. The generator keeps yielding
        arrays until total_frame_count reaches zero.
    
        Parameters
        ----------
        total_frame_count : int, optional
            The number of frames to be generated. The default value is 10.
    
        Yields
        ------
        numpy.ndarray
            2D numpy array of shape (n, 2) with random values, where n ranges
            from 5 to 49, inclusive. In other words, the number of points in
            each frame of the animation varies between 5 and 49.
    
        Examples
        --------
        Use this function in a for-loop to generate and process frames:
    
        >>> frame_generator = next_frame(3)
        >>> for frame in frame_generator:
        ...     print(frame.shape)  # row counts are random
        (30, 2)
        (12, 2)
        (48, 2)
    
        Notes
        -----
        This function can be used to generate frames for an animation iteratively.
        """
        while total_frame_count > 0:
            yield np.random.rand(np.random.randint(5, 50), 2)
            total_frame_count -= 1
    
    
    def update(frame):
        """
        Update a scatter plot with new data.
    
        This function clears the global 'ax' Axes object (removing the previous
        scatter plot), resets the plot limits, and then creates a new scatter plot
        from the provided 2D frame. The 'ax' Axes object must be pre-defined.
    
        Parameters
        ----------
        frame : array_like
            A 2D array where each row represents a point in the scatter plot.
            The first column represents the x-values, and the second column
            represents the y-values.
    
        Returns
        -------
        tuple of PathCollection
            A one-element tuple holding the new scatter plot. FuncAnimation
            requires the update function to return a sequence of artists when
            blit=True.
    
        Raises
        ------
        NameError
            If 'ax' isn't defined in the scope where this function is called.
    
        Examples
        --------
        This function can be used in animation generation:
    
        >>> import matplotlib.pyplot as plt
        >>> import matplotlib.animation as animation
        >>> fig, ax = plt.subplots()
        >>> ani = animation.FuncAnimation(fig, update, frames=next_frame(3))
        >>> plt.show()
        """
    
        # Clear the previous frame (this removes the old scatter plot)
        ax.clear()
    
        # Set the limits of your plot again, since ax.clear() resets them
        # NOTE: You might want to dynamically set these limits based on the new frame
        #       values that you're plotting.
        ax.set_xlim(0, 1)
        ax.set_ylim(0, 1)
    
        # Plot the new scatter plot
        scat = ax.scatter(frame[:, 0], frame[:, 1])
    
        # Return a sequence of artists, as FuncAnimation expects
        return (scat,)
    
    
    # == Create the Animation ======================================================
    # Create a figure and an Axes object
    fig, ax = plt.subplots()
    
    # Set the initial limits of your plot
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    
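    # Create the animation object.
    # blit=False because update() clears and redraws the whole Axes every frame;
    # save_count=10 matches the default number of frames yielded by next_frame.
    ani = FuncAnimation(fig, update, frames=next_frame, blit=False, save_count=10)
    
    # Convert the animation to HTML5 video (this requires ffmpeg to be installed)
    video = ani.to_html5_video()
    
    # Display the video (this only renders inside a Jupyter/IPython session)
    HTML(video)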
    

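    Note that the axis limits can also be recomputed on each frame, for example from `frame.min()` and `frame.max()` inside `update`, since `update` sets them again after every `ax.clear()` anyway.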
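The same clear-and-redraw idea should carry over to the Axes3D case from the question, because each frame creates a brand-new 3D scatter instead of mutating `_offsets3d`. A minimal sketch, assuming random (n, 3) frames and the headless Agg backend (used only so it runs without a display); `update_3d` is a hypothetical name, and it could be passed to `FuncAnimation` with `blit=False` in the same way as `update`:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend, assumed only so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(1)
fig = plt.figure()
ax = fig.add_subplot(projection="3d")  # an Axes3D, as in the question


def update_3d(frame):
    """Hypothetical 3D counterpart of update(): clear, reset limits, redraw."""
    ax.clear()  # drop the previous 3D scatter entirely
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_zlim(0, 1)
    # A new scatter is created each frame, so its size never has to match the last one
    scat = ax.scatter(frame[:, 0], frame[:, 1], frame[:, 2])
    return (scat,)


counts = []
for n in (8, 30, 4):
    (scat,) = update_3d(rng.random((n, 3)))
    fig.canvas.draw()  # render the frame, as FuncAnimation would
    counts.append(scat.get_offsets().shape[0])

print(counts)
```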