I plotted multiple subplots using seaborn's lmplot, and I want to add an x=y line to each of them. Can you help me solve this problem?
My code:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_theme(style="white")
sns.lmplot(data=data, x='Target', y='Predicted', hue="Type", col='Model', height=5, legend=False, palette=dict(Train="g", Test="m"))
plt.plot([data.iloc[:,0].min(), data.iloc[:,0].max()], [data.iloc[:,0].min(), data.iloc[:,0].max()], "--", label="Perfect model")
plt.legend(loc='upper left')
plt.show()
The plt.plot() call you are using only adds the line to the current axes, which is the last subplot. To add the line to every subplot, you need to take the axes from the FacetGrid that lmplot() returns and draw a line on each of them. As I don't have your data, I used the standard penguins dataset to show this. Hope this helps...
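import matplotlib.pyplot as plt
import seaborn as sns

data = sns.load_dataset('penguins')  ## My data
sns.set_theme(style="white")
g = sns.lmplot(data=data, x='bill_length_mm', y='bill_depth_mm', hue="species", col="sex", height=5, legend=False, palette=dict(Adelie="g", Chinstrap="m", Gentoo='orange'))
for ax in g.axes.flat:  ## g.axes holds every subplot; draw the line on each one
    ## Bill length and depth have different scales, so this draws the diagonal of the data range.
    ## For Target vs. Predicted, use the same min/max on both axes to get a true x=y line.
    ax.plot([data['bill_length_mm'].min(), data['bill_length_mm'].max()], [data['bill_depth_mm'].min(), data['bill_depth_mm'].max()], "--", label="Perfect model")
plt.legend(loc='upper left')  ## Adds the legend to the last subplot only
plt.show()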
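Applied to the setup from the question, a minimal sketch might look like the following. Since your `data` isn't available, it builds a hypothetical DataFrame with random values but the same column names (`Target`, `Predicted`, `Type`, `Model`) and `Type` levels (`Train`, `Test`); the model names "Model A" and "Model B" are made up. It uses one shared min/max for both coordinates, so the dashed line really is y = x on every facet.

```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Hypothetical stand-in for the question's `data`: same columns, random values.
rng = np.random.default_rng(0)
n = 40
target = rng.uniform(0, 10, size=4 * n)
data = pd.DataFrame({
    "Target": target,
    "Predicted": target + rng.normal(0, 1, size=4 * n),
    "Type": np.tile(np.repeat(["Train", "Test"], n), 2),  # n Train + n Test per model
    "Model": np.repeat(["Model A", "Model B"], 2 * n),     # hypothetical model names
})

sns.set_theme(style="white")
g = sns.lmplot(data=data, x="Target", y="Predicted", hue="Type", col="Model",
               height=5, legend=False, palette=dict(Train="g", Test="m"))

# One shared range for both axes, so the dashed line is exactly y = x.
lo = min(data["Target"].min(), data["Predicted"].min())
hi = max(data["Target"].max(), data["Predicted"].max())
for ax in g.axes.flat:  # one Axes per Model column
    ax.plot([lo, hi], [lo, hi], "--", color="k", label="Perfect model")
    ax.legend(loc="upper left")  # legend on every facet, not just the last one
plt.show()
```

If you'd rather have the line span the whole visible area regardless of the data range, matplotlib 3.3 and later also provide `ax.axline((0, 0), slope=1)`, which draws an infinite line through the given point.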