I wanted to know if there is an easier way to put a linear regression line on a plotly subplot. The code below does not seem efficient, and it makes it difficult to add annotations for the linear trendlines, which I want placed on the graph. It is also hard to set axes and titles with this code.
I was wondering whether I could create a go.Figure and somehow place it in the subplot. I tried that, but plotly only lets me put the figure's data on the subplot rather than the Figure itself, so I lose the title, axis, and trendline information. In addition, the trendline is hidden because the scatter plot is drawn on top of it. I tried reordering the traces with fig.data = (fig.data[1], fig.data[0]), but that did not work.
Basically, I want to know if there is a more efficient way to put a trendline on the scatter plots than the one I used, so that it is easier to set axes, set the graph size, create legends, etc., since what I coded is difficult to work with.
```python
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

sheets_dict = pd.ExcelFile('10.05.22_EMS172LabReport1.xlsx')
sheets_list = np.array(sheets_dict.sheet_names[2:])
fig = make_subplots(rows=7, cols=1)
i = 0
for name in sheets_list:
    df = sheets_dict.parse(name)
    df.columns = df.columns.str.replace(' ', '')  # strip spaces from headers
    df = df.drop(df.index[0])                     # drop the first row
    slope, y_int = np.polyfit(df.CURR1, df.VOLT1, 1)
    LR = "Linear Fit: {:,.3e}x + {:,.3e}".format(slope, y_int)
    # root-mean-square of the residuals
    rmse = np.sqrt(np.mean((slope*df.CURR1 + y_int - df.VOLT1)**2))
    df['Best Fit'] = slope*df.CURR1 + y_int
    i += 1
    fig.add_trace(
        go.Scatter(name='Best Fit Line' + " ± {:,.3e}V".format(rmse),
                   x=df['CURR1'], y=df['Best Fit'],
                   mode='lines', line_color='red', line_width=2),
        row=i, col=1)
    fig.add_trace(
        go.Scatter(name='Voltage', x=df['CURR1'], y=df['VOLT1'], mode='markers'),
        row=i, col=1)
# fig.data = (fig.data[1],fig.data[0])
fig.show()
```
Trendlines are built into plotly.express, with extensive options (see the plotly.express trendline documentation). It is possible to build a subplot from the data of a plotly.express figure, but I created the subplot with graph objects so that I could reuse your current code.
Since you did not provide your data, I used plotly's built-in example data, px.data.stocks(). It is a data frame of the stock prices of several companies, each relative to its first value (labelled "rate of change" below), and I added a trend line to each company's series.
As for the figure, I set the height explicitly, because six stacked subplots are cramped at the default height. Axis titles are added to an individual subplot by passing its row and col; if you need axis titles on every subplot, add one for each row. To customize the legend, I used legend groups: one group for the trendlines and one for the rate of change. As an example of annotations, each subplot gets a text label placed at x=0.1 and at the height of that series' maximum value.
```python
import plotly.express as px
import plotly.graph_objects as go
import numpy as np

df = px.data.stocks()
df.head()
```
```
   date        GOOG      AAPL      AMZN      FB        NFLX      MSFT
0  2018-01-01  1.000000  1.000000  1.000000  1.000000  1.000000  1.000000
1  2018-01-08  1.018172  1.011943  1.061881  0.959968  1.053526  1.015988
2  2018-01-15  1.032008  1.019771  1.053240  0.970243  1.049860  1.020524
3  2018-01-22  1.066783  0.980057  1.140676  1.016858  1.307681  1.066561
4  2018-01-29  1.008773  0.917143  1.163374  1.018357  1.273537  1.040708
```
```python
from plotly.subplots import make_subplots

fig = make_subplots(rows=6, cols=1, subplot_titles=df.columns[1:].tolist())
for i, c in enumerate(df.columns[1:]):
    dff = df[[c]].copy()
    slope, y_int = np.polyfit(dff.index, dff[c], 1)
    LR = "Linear Fit: {:,.3e}x + {:,.3e}".format(slope, y_int)
    dff['Best Fit'] = slope*dff.index + y_int
    # root-mean-square of the residuals
    rmse = np.sqrt(np.mean((dff['Best Fit'] - dff[c])**2))
    # data first, trendline second: later traces are drawn on top
    fig.add_trace(go.Scatter(
        x=dff.index,
        y=dff[c],
        legendgroup='group2',
        legendgrouptitle_text='Rate of change',
        mode='markers+lines', name=c), row=i+1, col=1)
    fig.add_trace(go.Scatter(
        name='Best Fit Line' + " ± {:,.3e}".format(rmse),
        x=dff.index,
        y=dff['Best Fit'],
        mode='lines',
        line_color='blue',
        line_width=2,
        legendgroup='group1',
        legendgrouptitle_text='Trendline'), row=i+1, col=1)
    # row/col point the annotation at this subplot's axes
    fig.add_annotation(x=0.1,
                       y=dff[c].max(),
                       text=LR,
                       showarrow=False,
                       xanchor='left',
                       yshift=5, row=i+1, col=1)

fig.update_layout(autosize=True, height=800, title_text="Stock and Trendline")
fig.update_xaxes(title_text="index", row=6, col=1)
fig.update_yaxes(title_text="Rate of change", row=3, col=1)
fig.show()
```