Search code examples
python matplotlib matplotlib-animation

Update the color of current rectangle in matplot animation


I want to animate the Bubble Sort algorithm with FuncAnimation from matplotlib. My idea is to generate a random list of integers and initialize a bar chart plot with the values. Then I apply the update_fig function to the values yielded by my bubble_sort generator function. So far so good: I am able to animate the whole Bubble Sort algorithm, and at the end my array of numbers is sorted.

But for better readability I want to change the color of the current rectangle. For example: All rects should be blue, the current one should be red.

Here is my code:

import random
from matplotlib import pyplot as plt
from matplotlib import animation

def update_fig(a, rects, iteration):
    """
    Redraw every bar for one animation frame and advance the operation counter.

    Relies on the module-level ``text`` artist to display the counter.

    :param a: the array in its current (partially sorted) state
    :param rects: the bar rectangles, paired positionally with the values in ``a``
    :param iteration: one-element list holding the running count; incremented in place
    :return: None
    """
    for bar, height in zip(rects, a):
        bar.set_height(height)
        # Every bar gets the same colour here; nothing marks the "current" element.
        bar.set_color('grey')

    iteration[0] += 1
    count = iteration[0]
    text.set_text(f'# of operations {count}')


def bubble_sort(a):
    """
    In-place Bubble Sort, written as a generator.

    The list is yielded once after every comparison (whether or not a swap
    happened), so a consumer can redraw after each step. The very same list
    object is yielded each time; it is never copied.

    :param a: shuffled array; it is reordered in place
    :return: generator yielding ``a`` after each comparison
    """
    for finished in range(len(a)):
        # The last `finished` slots already hold their final values, so each
        # pass only has to walk the remaining unsorted prefix.
        unsorted_end = len(a) - finished - 1
        for left in range(unsorted_end):
            right = left + 1
            if a[left] > a[right]:
                a[left], a[right] = a[right], a[left]
            yield a

# Input data: the integers 0..99 in random order.
a = list(range(0,100))
random.shuffle(a)
# bubble_sort is a generator; it is handed to FuncAnimation below as `frames`.
generator = bubble_sort(a)
fig, ax = plt.subplots()
ax.set_title('Bubble Sort')
# One bar per element; update_fig rewrites the bar heights on every frame.
bar_rects = ax.bar(range(len(a)), a, align='edge')
# Axis limits are hard-coded for the 100-element input above; the y limit
# leaves a little headroom above the largest value.
ax.set_xlim(0, 100)
ax.set_ylim(0, int(1.07 * 100))

# Counter label; update_fig reads this module-level name to show the count.
text = ax.text(0.02, 0.95, "", transform=ax.transAxes)
# One-element list so update_fig can increment the counter in place.
iteration = [0]
# update_fig receives each value produced by `generator` as its first argument,
# followed by the extra arguments listed in `fargs`.
# NOTE(review): for 100 elements bubble_sort yields 100*99/2 = 4950 frames,
# more than save_count=1500 — presumably only relevant if the animation is
# saved rather than shown; confirm against the FuncAnimation documentation.
animate = animation.FuncAnimation(fig,
                                  func=update_fig,
                                  fargs=(bar_rects, iteration),
                                  frames=generator,
                                  interval=1,
                                  repeat=False,
                                  save_count=1500
                                  )
plt.show()

The result should be something like this.

I already tried to set rect.set_color('orange') in my update_fig function but this changes the color of all rectangles within the plot at the beginning.

Any ideas how to do this?


Solution

  • An approach could be to store the current value into a global variable, which can be checked during coloring. And, while we're at it, also make iteration a global variable.

    import random
    from matplotlib import pyplot as plt
    from matplotlib import animation
    
    # Running count of frames drawn; incremented by update_fig.
    iteration = 0
    # Value of the element to highlight; set by bubble_sort before each yield
    # and compared against every bar's value in update_fig.
    current_val = None
    
    def update_fig(a, rects):
        """
        Redraw every bar for one animation frame, highlighting the current value.

        Reads the module-level ``current_val`` and ``text``, and increments the
        module-level ``iteration`` counter.

        :param a: the array in its current (partially sorted) state
        :param rects: the bar rectangles, paired positionally with the values in ``a``
        :return: None
        """
        global iteration, current_val
        for bar, height in zip(rects, a):
            bar.set_height(height)
            # The highlight is chosen by comparing values, not positions, so any
            # bar whose value equals current_val is drawn in the highlight colour.
            if height == current_val:
                colour = 'tomato'
            else:
                colour = 'grey'
            bar.set_color(colour)

        iteration += 1
        text.set_text(f'# of operations {iteration}')
    
    def bubble_sort(a):
        """
        In-place Bubble Sort, written as a generator.

        After every comparison the module-level ``current_val`` is set to the
        value now sitting in the right-hand slot of the compared pair, and the
        list is yielded. The very same list object is yielded each time.

        :param a: shuffled array; it is reordered in place
        :return: generator yielding ``a`` after each comparison
        """
        global current_val
        for finished in range(len(a)):
            # The last `finished` slots already hold their final values, so each
            # pass only has to walk the remaining unsorted prefix.
            unsorted_end = len(a) - finished - 1
            for left in range(unsorted_end):
                right = left + 1
                if a[left] > a[right]:
                    a[left], a[right] = a[right], a[left]
                # After the (possible) swap the larger of the pair is on the right.
                current_val = a[right]
                yield a
    
    # Input data: the integers 0..99 in random order.
    a = list(range(0, 100))
    random.shuffle(a)
    # bubble_sort is a generator; it is handed to FuncAnimation below as `frames`.
    generator = bubble_sort(a)
    fig, ax = plt.subplots()
    ax.set_title('Bubble Sort')
    # One bar per element; update_fig rewrites heights and colours on every frame.
    bar_rects = ax.bar(range(len(a)), a, align='edge')
    # Axis limits are hard-coded for the 100-element input above; the y limit
    # leaves a little headroom above the largest value.
    ax.set_xlim(0, 100)
    ax.set_ylim(0, int(1.07 * 100))
    
    # Counter label; update_fig reads this module-level name to show the count.
    text = ax.text(0.02, 0.95, "", transform=ax.transAxes)
    # update_fig receives each value produced by `generator` as its first
    # argument, followed by the extra arguments listed in `fargs`.
    # NOTE(review): for 100 elements bubble_sort yields 100*99/2 = 4950 frames,
    # more than save_count=1500 — presumably only relevant if the animation is
    # saved rather than shown; confirm against the FuncAnimation documentation.
    animate = animation.FuncAnimation(fig,
                                      func=update_fig,
                                      fargs=(bar_rects,),
                                      frames=generator,
                                      interval=1,
                                      repeat=False,
                                      save_count=1500
                                      )
    plt.show()