Search code examples
Tags: python, for-loop, plotly

Create plotly pie subplots using for loop


I am trying to create pie subplots using a for loop.

from plotly.subplots import make_subplots
import plotly.graph_objects as go
import math

# Question code, kept verbatim: it reproduces the error quoted below.
# make_subplots is called without `specs`, so each cell gets the default
# "xy" (Cartesian) subplot type, and a Pie trace cannot be placed in an
# "xy" subplot.
# Assumes `df` is defined elsewhere with "Year", "Commodity" and
# "TradeValue(Us$)" columns (the only columns this code reads).
fig = make_subplots(rows=7, cols=2)
# One pie per distinct year, visited in ascending year order.
for i,year in enumerate(sorted(df.Year.unique())):
    # Total trade value per commodity for this year, keeping the 10 largest.
    grouped=df[df.Year==year].groupby('Commodity')['TradeValue(Us$)'].sum().reset_index().sort_values('TradeValue(Us$)',ascending=False).head(10)

    # Cells are filled column by column: i = 0..6 go down column 1 and
    # i = 7..13 go down column 2. A 15th year would ask for column 3, which
    # this 7 x 2 grid does not have.
    # NOTE(review): domain=dict(x=[1,1]) is a zero-width domain; it appears to
    # be superseded by the row/col placement -- confirm it can be dropped.
    fig.add_trace(go.Pie(values=grouped['TradeValue(Us$)'],labels=grouped['Commodity'],domain=dict(x=[1,1])),row=i%7+1,col=math.floor(i/7+1))
            
    # If re-enabled inside the loop, update_layout would replace the single
    # figure-level title on every iteration, leaving only the last year's text.
    #fig.update_layout(height=1800, width=1000, title_text="Commodity share in "+str(year))
    #fig.update_xaxes(showticklabels=False)
fig.show()

I get this error: `Trace type 'pie' is not compatible with subplot type 'xy' at grid position (1, 1)`


Solution

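  • You need to define the type of each subplot. `make_subplots` creates `"xy"` (Cartesian) subplots by default, but pie traces must be placed in `"domain"` subplots: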

    # A 7 x 2 grid where every cell is declared as a "domain" subplot, the
    # only subplot type that accepts pie traces (each row gets its own dicts).
    fig = make_subplots(
        rows=7, cols=2,
        specs=[[{"type": "domain"}, {"type": "domain"}] for _row in range(7)],
    )
    

    Full MWE (minimal working example)

    Using your code, with a generated DataFrame that has compatible columns:

    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    import math
    import pandas as pd
    import numpy as np
    
    # Reproducible sample data with the question's columns: 11 years
    # (2012-2022), 5 rows per year, random commodity and trade value.
    rng = np.random.default_rng(0)
    df = pd.DataFrame(
        {
            "Year": np.repeat(np.arange(2012, 2023), 5),
            "Commodity": rng.choice(["Gold", "Silver", "Wheat"], 55),
            "TradeValue(Us$)": rng.uniform(100, 10**5, 55),
        }
    )

    years = sorted(df.Year.unique())

    # Derive the grid from the data so there is always a cell for every year.
    # A hard-coded 7 x 2 grid would ask for a nonexistent third column past
    # 14 years. max(1, ...) keeps make_subplots valid when there is no data.
    n_cols = 2
    n_rows = max(1, math.ceil(len(years) / n_cols))

    # Pie traces can only be placed in "domain" subplots. Without `specs`,
    # every cell defaults to "xy", which raises
    # "Trace type 'pie' is not compatible with subplot type 'xy'".
    fig = make_subplots(
        rows=n_rows,
        cols=n_cols,
        specs=[[{"type": "domain"} for _ in range(n_cols)] for _ in range(n_rows)],
    )

    for i, year in enumerate(years):
        # Top 10 commodities by total trade value for this year.
        grouped = (
            df[df.Year == year]
            .groupby("Commodity")["TradeValue(Us$)"]
            .sum()
            .reset_index()
            .sort_values("TradeValue(Us$)", ascending=False)
            .head(10)
        )

        fig.add_trace(
            go.Pie(
                values=grouped["TradeValue(Us$)"],
                labels=grouped["Commodity"],
                # Per-pie title. Calling fig.update_layout(title_text=...) inside
                # the loop would only overwrite one figure-level title each time.
                title=dict(text=f"Commodity share in {year}", position="top center"),
            ),
            # Fill column by column: the first n_rows years go down column 1,
            # the rest down column 2. row/col also set the pie's domain, so no
            # explicit `domain=` argument is needed.
            row=i % n_rows + 1,
            col=i // n_rows + 1,
        )

    fig.show()