Tags: python, plotly, plotly-python, scatter3d

Python Plotly: How to add an image to a 3D scatter plot


I am trying to visualize multiple 2D trajectories (x, y) in a 3D scatter plot where the z-axis is time.

import numpy as np
import pandas as pd
import plotly.express as px

# Sample data: 3 trajectories
t = np.linspace(0, 10, 200)
df = pd.concat([pd.DataFrame({'x': 900 * (1 + np.cos(t + 5 * i)), 'y': 400 * (1 + np.sin(t)), 't': t, 'id': f'id000{i}'}) for i in [0, 1, 2]])
# 3d scatter plot
fig = px.scatter_3d(df, x='x', y='y', z='t', color='id', )
fig.update_traces(marker=dict(size=2))
fig.show()

[Image: original 3D scatter plot]

I have a .png map image that is 2000x1000 pixels. The (x, y) coordinates of the trajectories correspond to pixel locations in the map.
I would like to show the map image on the "floor" of the 3D scatter plot.

I have tried to add the image with this code:

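# scipy.misc.imread was removed in SciPy 1.2, so read the PNG with Pillow instead
import numpy as np
from PIL import Image

img = np.asarray(Image.open('images/map_bg.png'))
fig2 = px.imshow(img)
fig.add_trace(fig2.data[0])
fig.show()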

But the result shows the image as an independent plot in the background: [Image: bad result]

Instead, I want the image on the "floor" of the scatter plot, moving with the plot when I rotate or zoom. Here is a mock-up: [Image: mock-up]

Additional note: there can be any number of trajectories, and for my application it is important that each trajectory is automatically plotted in a different color. I am using plotly.express, but I can use other Plotly modules as long as these requirements are met.


Solution

  • I've run into the same situation, where I wanted to use an image as the bottom surface of a 3D scatter plot. With help from two other posts, I was able to create the following 3D scatter plot:

    [Image: resulting 3D scatter plot with the image as its floor]

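    I've used plotly.graph_objects (go) in my example, so the code differs a little from the OP's. The key idea: go.Surface colors each vertex by a single scalar (surfacecolor), not by an RGB triple, so the image is first quantized to a fixed color palette, and that palette is passed to the surface as its colorscale.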

    import numpy as np
    import pandas as pd
    import plotly.graph_objects as go
    from PIL import Image
    from scipy import datasets  # SciPy >= 1.10 (needs pooch); on older SciPy use scipy.misc.face()

    # Sample RGB image of shape (height, width, 3)
    im = datasets.face()
    im_height, im_width = im.shape[:2]

    # Quantize the image to the "web" palette so each pixel becomes a palette index ...
    eight_bit_img = Image.fromarray(im).convert('P', palette=Image.Palette.WEB, dither=Image.Dither.NONE)
    color_idx = np.asarray(eight_bit_img)  # (height, width) array of palette indices
    # ... and build a colorscale whose i-th stop is palette color i
    idx_to_color = np.array(eight_bit_img.getpalette()).reshape((-1, 3))
    n_colors = len(idx_to_color)
    colorscale = [[i / (n_colors - 1), "rgb({}, {}, {})".format(*rgb)] for i, rgb in enumerate(idx_to_color)]

    # Sample data: 3 trajectories
    t = np.linspace(0, 10, 200)
    df = pd.concat([pd.DataFrame({'x': 400 * (1 + np.cos(t + 5 * i)), 'y': 400 * (1 + np.sin(t)), 't': t, 'id': f'id000{i}'}) for i in [0, 1, 2]])

    # Floor grid, one vertex per pixel. With 1-D x and y, go.Surface places
    # z[row, col] and surfacecolor[row, col] at (x[col], y[row]), so z must be (height, width).
    x = np.arange(im_width)
    y = np.arange(im_height)
    z = np.zeros((im_height, im_width))

    fig = go.Figure()

    # 3D scatter plot of the trajectories, colored by time
    fig.add_trace(go.Scatter3d(
        x=df['x'],
        y=df['y'],
        z=df['t'],
        mode='markers',  # the Scatter3d default also joins all points with one line
        marker=dict(
            color=df['t'],
            size=4,
        ),
    ))

    # The image as a flat surface at z = 0
    fig.add_trace(go.Surface(
        x=x, y=y, z=z,
        surfacecolor=color_idx,
        cmin=0,
        cmax=n_colors - 1,  # with this range, palette index i lands exactly on stop i
        colorscale=colorscale,
        showscale=False,
        lighting_diffuse=1,
        lighting_ambient=1,
        lighting_fresnel=1,
        lighting_roughness=1,
        lighting_specular=0.5,
    ))

    fig.update_layout(
        title="My 3D scatter plot",
        width=800,
        height=800,
        scene=dict(
            xaxis_visible=True,
            yaxis_visible=True,
            zaxis_visible=True,
            xaxis_title="X",
            yaxis_title="Y",
            zaxis_title="Z",
        ),
    )

    fig.show()