Search code examples
Tags: python, animation, plotly, plotly-python

Plotly: Animated 3D surface plots


I want to make a 3D animation using Plotly's Surface trace.

However, I run into two issues: (1) When I press Play, the figure only starts updating at the second frame. (2) All the previous frames remain visible as well, but I want to see only one frame at a time.

What do I need to change? Below is a minimal example that reproduces both issues.

import plotly.graph_objects as go
from plotly.graph_objs import *
import numpy as np

# Synthetic input: 10 frames, each a 10x10 grid of values drawn uniformly
# from [0, 1) by np.random.rand.
data =  np.random.rand(10, 10,10)
# Labels shown in each frame's title ("Frame: 0" ... "Frame: 9").
fr = np.arange(10)

# Transparent paper and plot backgrounds.
layout = go.Layout(
    paper_bgcolor='rgba(0,0,0,0)',
    plot_bgcolor='rgba(0,0,0,0)'
)



fig = go.Figure(layout=layout)

# One go.Frame per slice of `data`, collected for the animation.
frames = []

for i in range(len(data)):
    z = data[i]
    sh_0, sh_1 = z.shape
    # NOTE(review): Plotly's Surface pairs x with z's columns and y with its
    # rows, so x/y look swapped here; harmless only because the grid is square.
    x, y = np.linspace(0, 1, sh_0), np.linspace(0, 1, sh_1)
    
    trace = go.Surface(
        z=z,
        x=x,
        y=y,
        opacity=1,
        colorscale="Viridis",
        colorbar=dict(title="Counts"),
        # Pin the color range to [0, 1] (the range of np.random.rand) so all
        # frames share one color mapping.
        cmin=0,
        cmax=1
    )
    
    frame = go.Frame(data=[trace], layout=go.Layout(title=f"Frame: {fr[i]}"))
    frames.append(frame)
    
    # BUG (issue 2): every frame's surface is added to the figure, so all of
    # them are drawn on top of each other. Only the first frame's trace should
    # be added, after the loop (see the solution below).
    fig.add_trace(trace)
    # Re-assigned on every iteration; only the final assignment (all frames)
    # has a lasting effect, so this belongs after the loop.
    fig.frames = frames

# Fixed 800x800 canvas with explicit margins.
fig.update_layout(
    autosize=False,
    width=800,
    height=800,
    margin=dict(l=65, r=50, b=65, t=90)
)

# Scale factor applied uniformly to the scene's aspect ratio.
zoom = 1.35

fig.update_layout(
    scene={
        "xaxis": {"nticks": 20},
        "zaxis": {"nticks": 4},
        "camera_eye": {"x": 0.1, "y": 0.4, "z": 0.25},
        "aspectratio": {"x": 0.4 * zoom, "y": 0.4 * zoom, "z": 0.25 * zoom}
    }
)

# A single "Play" button that steps through the frames at 500 ms per frame
# with no transition, starting from the current frame and redrawing each one.
fig.update_layout(
    updatemenus=[
        dict(
            type='buttons',
            buttons=[
                dict(
                    label='Play',
                    method='animate',
                    args=[None, {'frame': {'duration': 500, 'redraw': True}, 'fromcurrent': True, 'transition': {'duration': 0}}]
                )
            ]
        )
    ]
)

fig.show()

Solution

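  • You need to remove fig.add_trace(trace) from the for loop; otherwise all traces are drawn on top of each other. Each frame's data=[trace] replaces only the figure's first trace, so the other surfaces added in the loop stay visible. Initially, you only need to add (i.e., display) the trace of the first frame: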

    # Build one animation frame per 2-D slice of `data`, but add only the
    # first frame's surface to the figure. Adding every trace inside the loop
    # (as in the question) stacks all surfaces on top of each other, because
    # each frame's `data=[trace]` replaces only the figure's first trace.
    frames = []
    for z, label in zip(data, fr):
        n_rows, n_cols = z.shape
        # Surface pairs x with z's columns and y with its rows; sizing them
        # this way keeps the grid correct even when it is not square.
        x = np.linspace(0, 1, n_cols)
        y = np.linspace(0, 1, n_rows)

        trace = go.Surface(
            z=z,
            x=x,
            y=y,
            opacity=1,
            colorscale="Viridis",
            colorbar=dict(title="Counts"),
            # Fixed color range so every frame uses the same color mapping.
            cmin=0,
            cmax=1,
        )

        frames.append(
            go.Frame(data=[trace], layout=go.Layout(title=f"Frame: {label}"))
        )

    # Display only the first frame's surface initially (no fig.add_trace in
    # the loop); the Play button swaps in the remaining frames.
    fig.add_trace(frames[0].data[0])
    fig.frames = frames