
Change the grid color or change spacing of cells for Plotly imshow


I have generated a Plotly imshow figure, but I feel like it's a bit hard to read for my data since I have numerous categories.

I would like to either change the grid colours to help define the cells or, preferably, add space between the cells to make the plot easier to read (white space between the colours so they are not right beside each other).

I tried to use

fig.update_xaxes(showgrid=True, 
                 gridwidth=1,
                 gridcolor='blue')

but it has no visible effect, and neither does

fig.update_xaxes(visible=True)

I have no idea how I would add white space between the cells (I'm not sure it's possible, but I thought I'd ask).

My full code is below

import pandas as pd
import plotly.express as px

# Sample data
df = pd.DataFrame({'cat':['A','B','C'],
                   'att1':[1,0.55,0.15],
                   'att2':[0.55,0.3,0.55],
                   'att3':[0.55,0.55,0.17]
                  })
df = df.set_index('cat')

# Graphing
fig = px.imshow(df, text_auto=False, height=800, width=900,
                color_continuous_scale=px.colors.diverging.Picnic,
                aspect='auto')

fig.update_layout(xaxis={'side': 'top'})

fig.update_xaxes(tickangle=-45,
                 showgrid=True, 
                 gridwidth=1,
                 gridcolor='blue')

fig.update(layout_coloraxis_showscale=False)
fig.update_xaxes(visible=True)
fig.show()


Solution

  • A rather inelegant solution would be to use px.scatter instead of px.imshow. This requires reshaping your input dataframe so that att1, att2 and att3 are stacked into a single column (long format):

    df = pd.DataFrame({'cat':['A','B','C','A','B','C','A','B','C'],
                       'position' : [1, 1, 1, 2, 2, 2, 3, 3, 3],
                       'att':[1,0.55,0.15,0.55,0.3,0.55,0.55,0.55,0.17],
                      }) 
    

    You may then call scatter and update the traces to use square symbols instead of circles:

    import pandas as pd
    import plotly.express as px

    # df is the long-format frame defined above
    fig = px.scatter(data_frame=df, x='position', y='cat', color='att',
                     color_continuous_scale=px.colors.diverging.Picnic,
                     width=800, height=800)

    # Draw each point as a large square with a dark outline so neighbouring cells stay separated
    fig.update_traces(marker=dict(size=160, symbol='square',
                                  line=dict(width=2, color='DarkSlateGrey')))
    fig.show()
    

    This snippet returns:

    [Figure: output of the scatter snippet]
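The long-format frame used in the scatter workaround does not have to be typed by hand; it can be derived from the question's wide `df` with pandas `DataFrame.melt`. This is a minimal sketch, assuming the attribute columns are named `att1` to `att3` and that `position` should be the integer suffix of each column name (the name `long_df` is just illustrative).

```python
import pandas as pd

# Wide frame from the question: one row per category, one column per attribute
df = pd.DataFrame({'cat': ['A', 'B', 'C'],
                   'att1': [1, 0.55, 0.15],
                   'att2': [0.55, 0.3, 0.55],
                   'att3': [0.55, 0.55, 0.17]}).set_index('cat')

# Stack the attribute columns into a single 'att' column; melt emits all
# rows for att1 first, then att2, then att3
long_df = df.reset_index().melt(id_vars='cat', var_name='position',
                                value_name='att')

# Assumption: turn the column names 'att1'..'att3' into integer positions 1..3
long_df['position'] = long_df['position'].str.replace('att', '', regex=False).astype(int)

print(long_df)
```

The resulting frame has the same rows, in the same order, as the hand-written one, so it can be passed as `data_frame` to `px.scatter` unchanged.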