
How to make a Plotly animated chart display all categories (not only the ones present in the first frame)


To illustrate my problem, I use the sample animated chart from the Plotly website:

https://plotly.com/python/animations/

Original code:

```python
import plotly.express as px

df = px.data.gapminder()
px.scatter(df, x="gdpPercap", y="lifeExp", animation_frame="year", animation_group="country",
           size="pop", color="continent", hover_name="country",
           log_x=True, size_max=55, range_x=[100,100000], range_y=[25,90])
```

It displays all the countries, but when I remove the data for European countries before 1970, Europe no longer appears in the legend.

It's even worse when I remove data for Asian countries: the frame range changes even though the remaining continents have complete data.

How can I force the chart to include a continent in the legend whenever it appears in at least one frame?

```python
import plotly.express as px

df = px.data.gapminder()
df = df.drop(df[(df.continent == 'Europe') & (df.year < 1970)].index)  # here I remove data
fig = px.scatter(df, x="gdpPercap", y="lifeExp", animation_frame="year", animation_group="country",
                 size="pop", color="continent", hover_name="country",
                 log_x=True, size_max=55, range_x=[100,100000], range_y=[25,90])
fig.show()
```
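One way to see why Europe drops out of the legend is to compare the categories in the first frame with those in the whole dataset. As far as I can tell, Plotly Express builds the figure's legend entries from the traces of the first animation frame, so a continent absent from that frame is absent from the legend. The helper `categories_missing_from_first_frame` below is hypothetical, it assumes frames are ordered by ascending `year` (true for gapminder), and the toy DataFrame is a stand-in, not real gapminder data.

```python
import pandas as pd


def categories_missing_from_first_frame(df, frame_col, category_col):
    """Return categories that occur somewhere in `df` but not in the first frame.

    Assumes frames are ordered by ascending value of `frame_col`.
    """
    first_frame = df[frame_col].min()
    in_first = set(df.loc[df[frame_col] == first_frame, category_col])
    return sorted(set(df[category_col]) - in_first)


# Toy stand-in for gapminder after dropping European rows before 1970.
toy = pd.DataFrame({
    "year": [1952, 1952, 1972, 1972, 1972],
    "country": ["Kenya", "Japan", "Kenya", "Japan", "France"],
    "continent": ["Africa", "Asia", "Africa", "Asia", "Europe"],
})

print(categories_missing_from_first_frame(toy, "year", "continent"))  # ['Europe']
```

Run against the modified gapminder `df` above, the same call with `"year"` and `"continent"` should likewise report `['Europe']`.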



Solution

  • Plotly animations have a limitation: they expect the dimensions to be consistent across frames, i.e. every row and every category mapped to `color` should be present in every frame. You can achieve this by joining the data onto every combination of the key columns, so that all combinations are present. Taking your sample code as a starting point, use pandas to enforce this consistency before calling `px.scatter()`:

    import pandas as pd
    import plotly.express as px

    df = px.data.gapminder()
    df = df.drop(df[(df.continent == 'Europe') & (df.year < 1970)].index)  # here I remove data

    # ensure all combinations of attributes are present
    key_cols = ["continent", "year", "country"]

    # Cartesian product of the key values, left-joined with the original data;
    # combinations that did not exist get NaN in the measurement columns
    df2 = pd.MultiIndex.from_product(
        [df[col].unique() for col in key_cols], names=key_cols
    ).to_frame(index=False).merge(df, on=key_cols, how="left")

    # an insignificantly small marker size for the padded rows
    df2["pop"] = df2["pop"].fillna(.001)

    fig = px.scatter(df2, x="gdpPercap", y="lifeExp", animation_frame="year", animation_group="country",
                     size="pop", color="continent", hover_name="country",
                     log_x=True, size_max=55, range_x=[100,100000], range_y=[25,90])
    fig.show()
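The product above spans `continent` × `year` × `country`, so it also creates rows such as France under Asia; those rows have NaN coordinates and draw nothing, but they multiply the row count. A possible refinement, sketched here assuming pandas 1.2 or later (for `how="cross"`), repeats only the (continent, country) pairs that actually occur across every year. The function name `pad_frames` is hypothetical and the toy values are placeholders, not gapminder data.

```python
import pandas as pd


def pad_frames(df, frame_col, group_cols, fill=None):
    """Give every existing combination of `group_cols` a row in every frame.

    Only combinations already present in `df` are repeated, so no country is
    paired with a continent it does not belong to. New rows get NaN in the
    other columns, except for columns listed in the optional `fill` dict.
    """
    pairs = df[group_cols].drop_duplicates()
    frames = pd.DataFrame({frame_col: sorted(df[frame_col].unique())})
    grid = pairs.merge(frames, how="cross")  # cross join needs pandas >= 1.2
    padded = grid.merge(df, on=group_cols + [frame_col], how="left")
    if fill:
        padded = padded.fillna(fill)
    return padded


# Placeholder numbers; France has no 1952 row, mimicking the dropped European data.
toy = pd.DataFrame({
    "year": [1952, 1952, 1972, 1972, 1972],
    "country": ["Kenya", "Japan", "Kenya", "Japan", "France"],
    "continent": ["Africa", "Asia", "Africa", "Asia", "Europe"],
    "lifeExp": [1.0, 2.0, 3.0, 4.0, 5.0],
    "pop": [10.0, 20.0, 30.0, 40.0, 50.0],
})

padded = pad_frames(toy, "year", ["continent", "country"], fill={"pop": 0.001})
print(len(padded))  # 6: three continent/country pairs x two years
print(padded[padded.year == 1952])
```

With the answer's setup, `pad_frames(df, "year", ["continent", "country"], fill={"pop": .001})` could replace the `df2` construction, and the result would be passed to `px.scatter` unchanged.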