
How to add rectangles and text annotations in Plotly python?


Matplotlib has plt.Rectangle() to create a colored rectangle and ax.text() to place text on each rectangle that is added.
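For reference, a minimal sketch of the Matplotlib approach described above, for a single row of boxes. The colors and the one-unit box size are arbitrary choices for illustration, and the Agg backend is only there so the sketch runs without a display:

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt

# One row taken from the sample data; colors are arbitrary for this sketch
names = ["A-M", "B&M", "B&Q", "BLOG"]
colors = ["dodgerblue", "darkturquoise", "aquamarine", "lightgreen"]

fig, ax = plt.subplots()
for i, (name, color) in enumerate(zip(names, colors)):
    # plt.Rectangle takes the lower-left corner, then width and height
    ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color, edgecolor="black"))
    # ax.text places a label in data coordinates; center it inside the box
    ax.text(i + 0.5, 0.5, name, ha="center", va="center")

ax.set_xlim(0, len(names))
ax.set_ylim(0, 1)
ax.axis("off")
```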


Sample data

data_dict = {"Trace": [["A-M", "B&M", "B&Q", "BLOG", "BYPAS", "CIM"],
                       ["B&M", "B&Q", "BLOG", "BYPAS"],
                       ["BLOG", "BYPAS", "CIM"],
                       ["A-M", "B&M", "B&Q", "BLOG"]],
             "Percentage": [28.09, 32, 0.98, 18.68]}

acronym = {"A-M": "Alternating Maximization",
           "B&M": "Business & Management",
           "B&Q": "Batch-And-Queue",
           "BLOG": "Buy Locally Owned Group",
           "BYPAS": "Bypass",
           "CIM": "Common Information Model"
           }

Does Plotly support adding rectangles to a plot? How can I create a "Trace Explorer" kind of plot in Plotly?


Solution

  • To get legend entries that are linked to the rectangles you draw, you'll need to draw each rectangle as a go.Scatter trace. Annotations won't work because they don't have corresponding legend entries.

    Each rectangle is drawn as a go.Scatter trace through five (x, y) points: the four corners, plus the starting corner repeated at the end to close the outline. With fill='toself', the enclosed area is filled with a color mapped from the rectangle's name. Since several rectangles share the same name, we use legend groups and show only one legend entry per name to avoid duplicates.
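The two ideas in that paragraph (a closed five-point outline, and one legend entry per name) can be separated out as plain helpers. rectangle_path and legend_flags are hypothetical names used only for this sketch, not Plotly API:

```python
def rectangle_path(cx, cy, width=1.0, height=1.0):
    """Return x and y lists tracing a rectangle centered at (cx, cy).

    Five points: the four corners, then the first corner again so the
    outline is closed, which is the area fill='toself' fills.
    """
    left, right = cx - width / 2, cx + width / 2
    bottom, top = cy - height / 2, cy + height / 2
    xs = [left, right, right, left, left]
    ys = [bottom, bottom, top, top, bottom]
    return xs, ys


def legend_flags(rows):
    """Return a showlegend flag per box: True only the first time a name appears."""
    seen = set()
    flags = []
    for row in rows:
        row_flags = []
        for name in row:
            row_flags.append(name not in seen)
            seen.add(name)
        flags.append(row_flags)
    return flags


xs, ys = rectangle_path(2, -1)
print(xs)  # [1.5, 2.5, 2.5, 1.5, 1.5]
print(ys)  # [-1.5, -1.5, -0.5, -0.5, -1.5]
```

Passing each flag as showlegend, together with the same legendgroup for every box of a given name, yields a single legend entry per name that still toggles all of the boxes in that group.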

    There are a few other formatting details: the padding between rows, the width and height of the boxes, and a fixed x-axis range so that selecting and deselecting traces doesn't resize the plot (Plotly's default behavior, which I assume isn't desirable here).
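The layout arithmetic can be sketched on its own. layout_positions is a hypothetical helper; it stacks rows downward from y = 0, whereas the code in this answer starts one row higher, which only shifts the hidden y-axis:

```python
def layout_positions(rows, height=1.0, row_padding=0.25, width=1.0, x_padding=0.25):
    """Hypothetical helper: row center y-values and a fixed x-axis range.

    Each row sits (height + row_padding) below the previous one. The x range
    covers the longest row plus one extra column for the percentage box.
    """
    row_centers = [-i * (height + row_padding) for i in range(len(rows))]
    n_columns = max(len(row) for row in rows) + 1
    x_range = [-width + x_padding, n_columns - x_padding]
    return row_centers, x_range


rows = [["A-M", "B&M", "B&Q", "BLOG", "BYPAS", "CIM"],
        ["B&M", "B&Q", "BLOG", "BYPAS"],
        ["BLOG", "BYPAS", "CIM"],
        ["A-M", "B&M", "B&Q", "BLOG"]]
centers, x_range = layout_positions(rows)
print(centers)  # [0.0, -1.25, -2.5, -3.75] (first entry prints as -0.0)
print(x_range)  # [-0.75, 6.75]
```

Passing the result to fig.update_xaxes(range=x_range) pins the axis, so hiding a legend group no longer rescales the plot.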

    import plotly.graph_objects as go
    
    data_dict = {"Trace": [["A-M", "B&M", "B&Q", "BLOG", "BYPAS", "CIM"],
                           ["B&M", "B&Q", "BLOG", "BYPAS"],
                           ["BLOG", "BYPAS", "CIM"],
                           ["A-M", "B&M", "B&Q", "BLOG"]],
                 "Percentage": [28.09, 32, 0.98, 18.68]}
    
    acronym = {"A-M": "Alternating Maximization",
               "B&M": "Business & Management",
               "B&Q": "Batch-And-Queue",
               "BLOG": "Buy Locally Owned Group",
               "BYPAS": "Bypass",
               "CIM": "Common Information Model"
               }
    
    color_map = {"A-M": "DodgerBlue",
                 "B&M": "DarkTurquoise",
                 "B&Q": "Aquamarine",
                 "BLOG": "LightGreen",
                 "BYPAS": "Khaki",
                 "CIM": "Tomato"
                 }
    
    check_legend_entry = {key:False for key in acronym.keys()}
    
    fig = go.Figure()
    
    ## x-axis length is the longest row of categories + 1 for the percentage boxes
    xaxis_length = max([len(trace_list) for trace_list in data_dict['Trace']]) + 1
    width, height = 1, 1
    y_row_padding = width/4
    xaxis_padding = width/4
    
    ## draw all of the rectangles by iterating through each trace,
    ## plotting from the upper left to the lower right; box centers are
    ## 1 unit apart in x and (1 + y_row_padding) apart in y, with y decreasing by row
    for row_number, trace_list in enumerate(data_dict['Trace']):
    
        ## offset each row vertically so there is y-padding between rows
        y_pos = (row_number-1)*(1+y_row_padding)
        for x_pos, name in enumerate(trace_list):
    
            ## show a legend entry only for the first rectangle with a given name
            ## to avoid duplicate legend entries for the same type of rectangle
            showlegend = not check_legend_entry[name]
            check_legend_entry[name] = True
            
            fig.add_trace(go.Scatter(
                x=[x_pos-width/2, x_pos+width/2, x_pos+width/2, x_pos-width/2, x_pos-width/2],
                y=[-y_pos-height/2, -y_pos-height/2, -y_pos+height/2, -y_pos+height/2, -y_pos-height/2],
                mode='lines',
                name=acronym[name],
                meta=[name],
                hovertemplate='%{meta[0]}<extra></extra>',
                legendgroup=acronym[name],
                line=dict(color="black"),
                fill='toself',
                fillcolor=color_map[name],
                showlegend=showlegend
            ))
    
            ## add the text in the center of each rectangle
            ## skip hoverinfo since the rectangle itself already has hoverinfo
            fig.add_trace(go.Scatter(
                x=[x_pos],
                y=[-y_pos],
                mode='text',
                legendgroup=acronym[name],
                text=[name],
                hoverinfo='skip',
                textposition="middle center",
                showlegend=False
            ))
    
    ## add the percentage boxes
    for row_number, percentage in enumerate(data_dict['Percentage']):
        y_pos = (row_number-1)*(1+y_row_padding)
        x_pos = max([len(trace_list) for trace_list in data_dict['Trace']]) + width/4
        fig.add_trace(go.Scatter(
            x=[x_pos-width/2, x_pos+width/2, x_pos+width/2, x_pos-width/2, x_pos-width/2],
            y=[-y_pos-height/2, -y_pos-height/2, -y_pos+height/2, -y_pos+height/2, -y_pos-height/2],
            mode='lines',
            line=dict(width=0),
            fill='toself',
            fillcolor='darkgrey',
            showlegend=False
        ))
        fig.add_trace(go.Scatter(
            x=[x_pos],
            y=[-y_pos],
            mode='text',
            text=[f"{percentage}%"],
            textfont=dict(color="white"),
            hoverinfo='skip',
            textposition="middle center",
            showlegend=False
        ))
    
    ## fix the x-axis range so the plot doesn't resize when traces are hidden via the legend
    fig.update_xaxes(range=[-width+xaxis_padding, xaxis_length-xaxis_padding])
    fig.update_layout(template='simple_white')
    fig.update_yaxes(visible=False)
    fig.show()
    


    NOTE: I realize you did not ask for the ability to select or deselect traces by clicking the legend, but I don't believe this could be disabled in plotly-python at the time of writing (see this open issue). Because each rectangle and its label share a legend group, clicking a legend entry hides or shows every box with that name.