Tags: dataframe, plotly, subplot

Plot rows of df on Plotly


The df is like this:

          X             Y            Label
0  [16, 37, 38]  [7968, 4650, 3615]   0.7
1  [29, 37, 12]  [4321, 4650, 1223]   0.8
2  [12, 2, 445]  [1264, 3456, 2112]   0.9

This should plot three lines on the same plot, with the labels treated as a continuous variable. What is the fastest and simplest way to plot it using plotly?


Solution

  • Taking "This should plot three lines on the same plot" as the requirement. (This is inconsistent with the separate wish for one subplot per row of the df.)

    This is a simple case of creating a trace for each row, using explode (https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.explode.html) to turn each list into the x and y values.

    import pandas as pd
    import plotly.graph_objects as go
    
    df = pd.DataFrame(
        {
            "X": [[16, 37, 38], [29, 37, 12], [12, 2, 445]],
            "Y": [[7968, 4650, 3615], [4321, 4650, 1223], [1264, 3456, 2112]],
            "Label": [0.7, 0.8, 0.9],
        }
    )
    
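    # One trace per row: grouping by the index yields one-row frames, and
    # explode() unpacks the list held in each cell into separate values
    fig = go.Figure(
        [
            go.Scatter(
                x=r["X"].explode(), y=r["Y"].explode(), name=str(r["Label"].values[0])
            )
            for _, r in df.groupby(df.index)
        ]
    )
    fig.show()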
    

    With a continuous color defined by Label:

    import pandas as pd
    import plotly.graph_objects as go
    from plotly.colors import sample_colorscale
    import plotly.express as px
    
    df = pd.DataFrame(
        {
            "X": [[16, 37, 38], [29, 37, 12], [12, 2, 445]],
            "Y": [[7968, 4650, 3615], [4321, 4650, 1223], [1264, 3456, 2112]],
            "Label": [0.1, 0.5, 0.9],
        }
    )
    
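    # Dummy one-point scatter: it only exists so that px creates a continuous
    # colorbar. range_color=[0, 1] matches the range sample_colorscale samples.
    fig = px.scatter(
        x=[0], y=[0], color=[0.5], range_color=[0, 1], color_continuous_scale="YlGnBu"
    )
    fig.update_traces(marker_opacity=0)  # hide the dummy point

    # One line per row, colored by sampling the same colorscale at the row's Label.
    # The Label is passed in a list to mark it as a sample position, since an
    # integer samplepoints argument is interpreted as a number of points.
    fig.add_traces(
        [
            go.Scatter(
                x=r["X"].explode(),
                y=r["Y"].explode(),
                name=str(r["Label"].values[0]),
                line_color=sample_colorscale("YlGnBu", [r["Label"].values[0]])[0],
                showlegend=False,
            )
            for _, r in df.groupby(df.index)
        ]
    )

    fig.show()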