
How to use plotly.express.imshow facet_row argument?


Consider the following code:

import pandas as pd
import numpy as np
import plotly.express as px

# Create the index for the data frame
x = np.linspace(-1,1, 6)
y = np.linspace(-1,1,6)
n_channel = [1, 2, 3, 4]

xx, yy = np.meshgrid(x, y)

zzz = np.random.randn(len(y)*len(n_channel),len(x))

df = pd.DataFrame(
    zzz,
    columns = pd.Index(x, name='x (m)'),
    index = pd.MultiIndex.from_product([y, n_channel], names=['y (m)', 'n_channel']),
)

print(df)

fig = px.imshow(
    df.reset_index('n_channel'),
    facet_col = 'n_channel',
)
fig.write_html(
    'plot.html',
    include_plotlyjs = 'cdn',
)

The data frame looks like this:

x (m)                -1.0      -0.6      -0.2       0.2       0.6       1.0
y (m) n_channel                                                            
-1.0  1         -0.492584  0.599464  0.097405 -0.177793 -0.027311  1.468527
      2          0.202147  0.449809 -2.047460 -1.392223  0.245228  1.220419
      3          0.139111 -0.699596  1.754103 -0.141732 -1.494373 -0.003184
      4          0.124390  0.245113 -0.031949  1.938560  1.418563 -0.787295
-0.6  1          1.112547  0.307750 -1.206242 -0.739546  0.038905 -0.923485
      2         -0.900733 -1.094717  0.770876 -1.973305  2.677651  3.072124
      3         -0.279864 -1.341024  2.750811 -1.401604  0.929714  0.658087
      4         -1.038905 -1.038625  0.112878  1.112139 -0.799305 -0.934813
-0.2  1          0.332704  1.321129  0.241799 -1.100657 -0.927649 -1.928624
      2         -0.576210  0.257960 -0.196699 -0.245751  0.575648 -0.703353
      3         -0.549881 -1.208282  0.959120  1.852333  1.452697 -0.562802
      4         -0.433256 -0.339644 -1.636592 -1.022501 -0.614497  1.085253
 0.2  1          0.378474 -0.829495 -1.313322 -0.654698 -0.644115  2.175938
      2          0.567393 -0.340301  1.304942  0.197879  0.309288 -0.126187
      3          0.209954  0.161299 -0.362754 -0.328356 -0.106934 -0.238329
      4         -0.284447 -0.367920 -0.275830 -0.776649  0.656279  0.056389
 0.6  1          1.174153 -1.112658  1.245117 -0.395144  0.471050  0.165074
      2         -0.220246  1.063194  0.292873  0.266250 -0.175274  0.225985
      3          0.301462  0.737581  0.271691  0.936558  1.007112  1.857389
      4         -0.689441  3.369569  0.675700  0.077706  0.152062 -0.533258
 1.0  1          0.732183  0.041873  1.156681  0.841262 -0.984433  1.313900
      2          0.157533  0.723356 -0.786721  0.150939  0.164049 -0.351816
      3         -0.390037 -1.513096  0.255813 -1.365759  0.570145  1.630885
      4          0.318037 -1.103191  1.472340 -0.218038  0.990673 -1.565340

and I expect it to produce 4 heatmaps, each of them similar to this one:

[Image: heatmap of a single channel]

but instead I get AttributeError: 'DataFrame' object has no attribute 'dims'. If I do this instead:

for n in n_channel:
    fig = px.imshow(
        df.query(f'n_channel=={n}').reset_index('n_channel', drop=True),
    )
    fig.write_html(
        f'plot_{n}.html',
        include_plotlyjs = 'cdn',
    )

This produces the 4 plots, but as separate figures and (of course) without shared axes.

Is it possible to use the facet_row argument to facet by one column (here n_channel), as can be done with, e.g., px.scatter?


Solution

  • To create a faceted graph with imshow, you need 3D data; you then specify which of its three axes to facet along. See this reference. Here the 3D data is created by stacking one 2D sample array per channel. I also change the facet titles to match the user's data.

    import numpy as np
    import plotly.express as px

    # Coordinates and channels, as in the question
    x = np.linspace(-1, 1, 6)
    y = np.linspace(-1, 1, 6)
    n_channel = [1, 2, 3, 4]

    # Stack one (y, x) array per channel into a 3D array of shape (4, 6, 6)
    all_arrays = np.stack([np.random.randn(len(y), len(x)) for _ in n_channel])

    # Facet along axis 0 (the channel axis); x and y set the heatmap coordinates
    fig = px.imshow(
        all_arrays,
        facet_col=0,
        x=x,
        y=y,
        labels=dict(x='x (m)', y='y (m)'),
    )

    # Default facet titles show the axis index (0-3); show the channel numbers instead
    for k, n in enumerate(n_channel):
        fig.layout.annotations[k].update(text=f'n_channel={n}')

    fig.show()
    
