
Converting an array of RGB colours to a px.imshow heatmap


I currently have a DF of different RGB colour values in this format:

          Protein_ID1  Protein_ID2  Protein_ID3
Module1    [R, G, B]    [R, G, B]    [R, G, B]    
Module2    [R, G, B]    [R, G, B]    [R, G, B]    
Module3    [R, G, B]    [R, G, B]    [R, G, B] 

I would like to display this with px.imshow as a heat map, with each cell's colour corresponding to its RGB value.

When I do:

fig = px.imshow(df)
fig.update_layout(
                  xaxis=dict(
                             rangeslider=dict(visible=True)
                            )
                 )
fig.write_html(results_file)

The results file is blank. Based on the first example here, I converted my DataFrame to an array as below, but still had no luck:

array = df.to_numpy()
fig = px.imshow(array, x = df.columns, y = df.index)
fig.update_layout(
                  xaxis=dict(
                             rangeslider=dict(visible=True)
                            )
                 )
fig.write_html(results_file)

Can anyone shed some light on the correct way to approach this?

Thanks! Tim


Solution

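    • The core of the problem is getting the input to px.imshow() right: for an RGB image it needs to be a 3D NumPy array of shape (rows, columns, 3) with dtype uint8.
    • Each DataFrame cell holds a whole [R, G, B] list, so df.to_numpy() alone gives a 2D array of list objects rather than a 3D numeric array. Hence, extract the values from the data frame and restructure them to meet these input requirements.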
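To see why the restructuring is needed, compare the array shapes on a small hand-made DataFrame in the question's layout (the values and labels below are made up for illustration):

```python
import numpy as np
import pandas as pd

# A small DataFrame shaped like the one in the question: every cell holds an [R, G, B] list
df = pd.DataFrame(
    [[[255, 0, 0], [0, 255, 0]], [[0, 0, 255], [255, 255, 255]]],
    columns=["Protein_ID1", "Protein_ID2"],
    index=["Module1", "Module2"],
)

# to_numpy() keeps each list as a single Python object, so the array is 2D with dtype object
flat = df.to_numpy()
print(flat.shape, flat.dtype)  # (2, 2) object

# Going through nested Python lists lets NumPy see the third (colour channel) dimension
img = np.array(df.values.tolist(), dtype=np.uint8)
print(img.shape, img.dtype)  # (2, 2, 3) uint8
```

Only the second array has the (rows, columns, 3) layout that px.imshow reads as an RGB image.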
    import pandas as pd
    import numpy as np
    import plotly.express as px
    
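    # simulate a data frame where each cell holds an [R, G, B] list
    # (randint's upper bound is exclusive, so 256 gives values 0-255)
    df = pd.DataFrame(
        np.random.randint(0, 256, [10, 10, 3]).tolist(),
        columns=[f"Protein_ID{i}" for i in range(10)],
        index=[f"Module{i}" for i in range(10)],
    )
    
    # nested lists -> 3D uint8 array of shape (rows, columns, 3), which px.imshow renders as RGB
    px.imshow(np.array(df.values.tolist(), dtype=np.uint8)).show()
    
    # to_markdown() needs the optional tabulate package
    print(df.iloc[0:3, 0:3].to_markdown())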
    

    Sample data:

                 Protein_ID0      Protein_ID1      Protein_ID2
    Module0    [232, 78, 62]   [96, 105, 104]    [138, 63, 46]
    Module1    [143, 49, 25]   [190, 70, 138]   [77, 170, 155]
    Module2     [16, 209, 3]   [153, 215, 47]  [216, 246, 121]


    With the DataFrame's column and index names as axis labels:

    
    px.imshow(np.array(df.values.tolist(), dtype=np.uint8)).update_layout(
        # tickmode="array" places each ticktext entry at the matching tickvals (pixel index)
        xaxis={"tickmode": "array", "tickvals": list(range(10)), "ticktext": list(df.columns)},
        yaxis={"tickmode": "array", "tickvals": list(range(10)), "ticktext": list(df.index)},
    ).show()