
Annotations varying by subplot with plotly-express?


Suppose I want to make a faceted plot like this, in which each facet has its own y scale:

import plotly.express as px

fig = px.scatter(px.data.iris(), x='sepal_length', y='sepal_width', facet_col='species')


def update(y):
    y.update(matches=None)
    y.showticklabels=True

fig.for_each_yaxis(update)


Now suppose I want to add some annotations whose positions vary with the faceted variable, and I have them in a dataframe, df_text, with columns species, x, y and label.

If I was using plotnine/ggplot I could do it like this:

ggplot(df_iris, aes(x='sepal_length', y='sepal_width')) + geom_point() + facet_wrap("~species", scales="free_y") + geom_text(aes(x='x', y='y', label='label'), data=df_text) 


Is it possible to do this in plotly? I got pretty bogged down mucking around with subplots and annotations. I know you can add an annotation to a specific subplot, but you have to know its row and column number to do that, and I'm not sure how to map the facet variable (species) to the subplot row/column indexes.

Thanks :)


Solution

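  • I'm not sure if this is the best way to do it, but you can try the following: build a second, text-only px.scatter figure from df_text with the same facet_col, then copy its traces into the first figure. Each copied trace keeps the axis references (x, x2, x3) of its facet, so as long as the species appear in the same order in both dataframes (they do here), each label lands in the matching subplot.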

    import plotly.express as px
    import pandas as pd
    
    # One row per facet value: where to put each label
    df_text = pd.DataFrame({"species":["setosa", "versicolor", "virginica"],
                            "x": [7, 7, 5],
                            "y": [3, 2, 3.5],
                            "label":["label1", "label2", "label3"]})
    
    fig = px.scatter(px.data.iris(),
                     x='sepal_length',
                     y='sepal_width',
                     facet_col='species')
    
    # Here are your annotations: a text-only figure faceted the same way
    text_traces = (px.scatter(df_text,
                              x="x",
                              y="y",
                              text="label",
                              facet_col='species')
                   .update_traces(mode="text")
                   .data)
    
    def update(y):
        # Give each facet its own y scale and show its tick labels
        y.update(matches=None, showticklabels=True)
        
    fig.for_each_yaxis(update)
    
    # The copied traces keep their xaxis/yaxis references (x, x2, x3),
    # so each label is drawn in the subplot of the same species
    fig.add_traces(text_traces)
    
    fig.show()
    
