Search code examples
Tags: python, pandas, plotly, plotly-python

How to create subplots from each column in a pandas dataframe


I have a DataFrame `df` with 36 columns. These columns are plotted onto a single Plotly chart and saved in HTML format using the code below.

import plotly.offline as py
import plotly.io as pio

# Builds one trace dict (x = index, y = column values, name = column label) per
# column and plots them all on a single chart written to new_file_path.
# NOTE(review): y-values come from `df` but the loop iterates
# `trend_data.columns`; presumably these are the same frame - confirm.
# NOTE(review): offline.plot(..., filename=...) is given the output path itself,
# and its return value is then passed to pio.write_html with no target file;
# this wrapping looks redundant - verify against the installed plotly version.
pio.write_html(py.offline.plot([{
'x': df.index,
'y': df[col],
'name': col
}for col in trend_data.columns], filename=new_file_path))

I want to iterate through each column and create a subplot for each one. I have tried:

from plotly.subplots import make_subplots

# NOTE(review): df.columns is an attribute (a pandas Index), not a method, so
# calling df.columns() will likely fail; list(df.columns) is presumably intended.
sub_titles = df.columns()
fig = make_subplots(rows=6, cols=6, start_cell="bottom-left", subplot_titles=sub_titles)
# NOTE(review): add_trace is given a column-name string here rather than a trace
# object (e.g. go.Scatter) with row=/col= targets; likely a source of the error.
for i in df.columns:
    fig.add_trace(i)

I created 6 rows and 6 columns, since that gives 36 subplots, and tried to use the column headers as subplot titles, but I get a ValueError saying it expected a 2D list of dictionaries.

Also, I have tried to add subplot titles with:

# list(df) yields the DataFrame's column labels.
sub_titles = list(df)
# NOTE(review): the titles keyword used with make_subplots elsewhere in this
# question is `subplot_titles`, not `sub_titles`; also `py` was imported as
# plotly.offline, which presumably does not expose `subplots` - likely the
# sources of this error.
fig = py.subplots.make_subplots(rows=6, cols=6, sub_titles=sub_titles)

This also returns an error. Any help is appreciated.


Solution

  • Plot:

    enter image description here

    Code:

    # imports
    import math

    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np

    # data: one random-walk column per subplot
    np.random.seed(123)
    frame_rows = 50
    n_plots = 36
    # range(n_plots) yields exactly n_plots columns (V_0 .. V_35);
    # range(n_plots + 1) would build an extra column that never gets plotted.
    frame_columns = ['V_' + str(e) for e in range(n_plots)]
    df = pd.DataFrame(np.random.uniform(-10, 10, size=(frame_rows, len(frame_columns))),
                      index=pd.date_range('1/1/2020', periods=frame_rows),
                      columns=frame_columns)
    df = df.cumsum() + 100
    df.iloc[0] = 100

    # plotly setup: fixed number of grid columns, and just enough rows to hold
    # every DataFrame column (6 x 6 for 36 columns). Each subplot is titled
    # with its column name; titles are assigned row by row from the top-left,
    # matching the trace placement below.
    plot_cols = 6
    plot_rows = math.ceil(len(df.columns) / plot_cols)
    fig = make_subplots(rows=plot_rows, cols=plot_cols,
                        subplot_titles=list(df.columns))

    # add traces: the k-th column (0-based) goes to grid cell
    # (k // plot_cols + 1, k % plot_cols + 1). Iterating over the columns
    # rather than over grid cells avoids an IndexError when there are fewer
    # columns than cells.
    for k, col_name in enumerate(df.columns):
        grid_row, grid_col = divmod(k, plot_cols)
        fig.add_trace(go.Scatter(x=df.index, y=df[col_name].values,
                                 name=col_name,
                                 mode='lines'),
                      row=grid_row + 1,
                      col=grid_col + 1)

    # Format and show fig (200 px per grid row -> 1200 px for a 6 x 6 grid)
    fig.update_layout(height=200 * plot_rows, width=1200)
    fig.show()
    

Addition: single-column solution (one subplot per row, 36 rows x 1 column):

    Code:

    # imports
    from plotly.subplots import make_subplots
    import plotly.graph_objects as go
    import pandas as pd
    import numpy as np

    # data: 36 random-walk columns (V_1 .. V_36). Steps are drawn from
    # uniform(-8, 10), so the walks drift upward on average.
    np.random.seed(123)
    frame_rows = 50
    frame_columns = ['V_' + str(e) for e in range(1, 37)]
    df = pd.DataFrame(np.random.uniform(-8, 10, size=(frame_rows, len(frame_columns))),
                      index=pd.date_range('1/1/2020', periods=frame_rows),
                      columns=frame_columns)
    df = df.cumsum() + 100
    df.iloc[0] = 100

    # plotly setup: one full-width subplot per DataFrame column, stacked
    # vertically and titled with the column name. The row count follows the
    # data instead of being hard-coded, and no `insets` are passed. An inset
    # would add an extra, empty axis pair on top of the first subplot.
    n_plots = len(df.columns)
    row_height = 333  # pixels per subplot (~12000 px total for 36 rows)
    fig = make_subplots(rows=n_plots, cols=1, subplot_titles=list(df.columns))

    # add traces: the k-th column (1-based) goes into row k
    for row, col_name in enumerate(df.columns, start=1):
        fig.add_trace(go.Scatter(x=df.index, y=df[col_name].values,
                                 name=col_name,
                                 mode='lines'),
                      row=row,
                      col=1)

    # Format and show fig; total height scales with the number of subplots
    fig.update_layout(height=row_height * n_plots, width=1200)

    fig.show()
    

    Plots:

    enter image description here