Search code examples
pythonanimationplotly

Plotly animated subplots with px.imshow and go.Scatter


I am trying to create a figure showing image "reconstruction" as function of number of PCs. I want to animate this to show the original image, the cumulative image (over PCs 1,...,i) and the parts that still remain to be "reconstructed". Together with that I want to show the distance between the original and reconstructed image as a function of the number of PCs.

I managed to create the figure below, which animates the scatter plot at the bottom and also the images at the top.

enter image description here

The problem is that once the animation begins the two images on the right "disappear" and I think they appear under the "Original Image"

enter image description here

This is the code I have (creation of animation frames with all 3 images and scatters, and then formation of figure):

import numpy as np
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import plotly.io as pio
from sklearn.decomposition import PCA

pio.templates["custom"] = go.layout.Template(
    layout=go.Layout(
        margin=dict(l=20, r=20, t=40, b=0)
    )
)
pio.templates.default = "simple_white+custom"


class AnimationButtons():
    def play_scatter(frame_duration = 500, transition_duration = 300):
        return dict(label="Play", method="animate", args=
                    [None, {"frame": {"duration": frame_duration, "redraw": False},
                            "fromcurrent": True, "transition": {"duration": transition_duration, "easing": "quadratic-in-out"}}])
    
    def play(frame_duration = 1000, transition_duration = 0):
        return dict(label="Play", method="animate", args=
                    [None, {"frame": {"duration": frame_duration, "redraw": True},
                            "mode":"immediate",
                            "fromcurrent": True, "transition": {"duration": transition_duration, "easing": "linear"}}])
    
    def pause():
        return dict(label="Pause", method="animate", args=
                    [[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate", "transition": {"duration": 0}}])

pca = PCA(n_components=15).fit(X.reshape((X.shape[0], -1)))
pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))

img, loadings = X[1], pca.transform(X[1].reshape(-1, 1)).T


reconstructed, distortion, frames = np.zeros_like(X[0]), [], []
for i in range(len(pca.components_)):
    # Reconstruct image using the first i principal components
    reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
    distortion.append(np.sum((img - reconstructed) ** 2))    

    # Append animation frame every 5'th reconstruction
    if i % 2 == 0 or i == pca.n_components_-1:
        frames.append(go.Frame(
            data = [px.imshow(img, binary_string=True).data[0],
                    px.imshow((img - reconstructed).copy(), binary_string=True).data[0],
                    px.imshow(reconstructed.copy(), binary_string=True).data[0],
                    go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion)],
            traces = [0, 1, 2, 3],
            layout = go.Layout(title=rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")))


fig = make_subplots(rows=2, cols=3, 
                    subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                    specs=[[{}, {}, {}], [{"colspan": 3}, None, None]], row_heights=[500, 200],)
fig.add_traces(data=frames[0]["data"], rows = [1,1,1,2], cols = [1,2,3,1])
fig.update(frames=frames)

fig.update_layout(title=frames[0]["layout"]["title"],
                  xaxis4=dict(range=[0, 50], autorange=False),
                  yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                  margin = dict(t = 100),
                  width=800,
                  updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
fig.show()

I tried finding similar questions but wasn't able to find anything that would work for the showing of both px.imshow and go.Scatter with subplots and animation.

The data X are the MNIST digits images after centering. Here is a numpy array with one such image: (X.shape=(16,5,5) - 16 images of 5x5 - animation only on first image)

X=np.array( [[[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]],

 [[ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
    0.00000000e+00],
  [ 0.00000000e+00, 0.00000000e+00, 0.00000000e+00, 0.00000000e+00,
   -1.04166667e-06],
  [ 0.00000000e+00, 0.00000000e+00,-4.16666667e-06,-2.73437500e-06,
   -2.71484375e-05],
  [ 0.00000000e+00, 0.00000000e+00,-1.26302083e-05,-2.28515625e-05,
   -4.69401042e-05],
  [ 0.00000000e+00,-2.47395833e-06,-2.03776042e-05,-5.60546875e-05,
   -3.15950521e-04]]] )

Placed the above code in a Jupyter notebook on GitHub


Solution

  • Similar to what jayvessea suggested, I ended up playing with the structure of the px.imshow. I first created the px.imshow with both facets and animation, and then added to it both the scatter plot and the desired layout

    pca = PCA(n_components=50).fit(X.reshape((X.shape[0], -1)))
    pcs = pca.components_.reshape((-1, X.shape[1], X.shape[2]))
    
    img, loadings = X[150], pca.transform(X[150].reshape(-1, 1)).T
    
    reconstructed, distortion, images, scatters, titles = np.zeros_like(X[0]), [], [], [], []
    for i in range(len(pca.components_)):
        # Reconstruct image using the first i principal components
        reconstructed += loadings[i].reshape(img.shape) * pca.components_[i].reshape(img.shape)
        distortion.append(np.sum((img - reconstructed) ** 2))    
    
        # Append animation frame every other reconstruction
        if i % 2 == 0 or i == pca.n_components_-1:
            images.append([img.copy(), reconstructed.copy(), (img - reconstructed).copy()])
            scatters.append(go.Scatter(x=list(range(1, len(distortion)+1)), y=distortion, name=3, xaxis="x4", yaxis="y4", marker_color="black"))
            titles.append(rf"$\text{{ Image Reconstruction - Number of PCs: {i+1} }}$")
    
    
            
    # Create figure on the basis of the animated facetted imshow figure
    fig = px.imshow(np.array(images), facet_col=1, animation_frame=0, binary_string=True)
    for i, (scatter, title) in enumerate(zip(*[scatters, titles])):
        fig["frames"][i]["data"] += (scatter, )
        fig["frames"][i]["traces"] = [0,1,2,3]
        fig["frames"][i]["layout"]["title"] = title 
    fig.add_traces(data=fig["frames"][0]["data"][-1])
    
    # Create "template" figure to transfer layout onto the `fig` figure
    layout = make_subplots(rows=2, cols=3, 
                           subplot_titles=["Original Image", "Reconstructed Image", "Remaining Reconstruction", "Distortion Level"],
                           specs=[[{"type":"Image"}, {"type":"Image"}, {"type":"Image"}], [{"type":"xy","colspan": 3}, None, None]], row_heights=[500, 200],)
    
    layout.update_layout(title=titles[0],
                         xaxis4=dict(range=[0, 50], autorange=False),
                         yaxis4=dict(range=[0, max(distortion)+1], autorange=False),
                         margin = dict(t = 100), width=800,
                         updatemenus=[dict(type="buttons", buttons=[AnimationButtons.play(), AnimationButtons.pause()])])
    
    fig["layout"] = layout["layout"]
    fig
    

    It is not a very elegant solution but it is a sufficient workaround.

    enter image description here