I want to plot multiple confusion matrices in a single figure with a single colorbar and shared x- and y-axes. Here is the code I have tried so far:
# Calculate the confusion matrices
def _confusion_matrix(predicted, actual):
    """Cross-tabulate actual labels (rows) against predicted labels (columns).

    Both inputs are stacked position-by-position into a two-column frame, so
    they must have the same length; any index they carry is ignored.
    """
    pairs = pd.DataFrame(
        np.vstack([predicted, actual]).T,
        columns=['predicted_class', 'actual_class'],
    )
    return pd.crosstab(
        pairs['actual_class'],
        pairs['predicted_class'],
        rownames=['Actual'],
        colnames=['Predicted'],
    )


actual_class = df_binary["Observed"]
CF_mod1 = _confusion_matrix(df_binary["Model1"], actual_class)
CF_mod2 = _confusion_matrix(df_binary["Model2"], actual_class)
# Model3 was missing from the original snippet even though the plotting code
# below uses CF_mod3; without it the plot fails with a NameError.
CF_mod3 = _confusion_matrix(df_binary["Model3"], actual_class)
CF_mod4 = _confusion_matrix(df_binary["Model4"], actual_class)
CF_mod5 = _confusion_matrix(df_binary["Model5"], actual_class)
CF_mod6 = _confusion_matrix(df_binary["Model6"], actual_class)
I then plotted these matrices using the following code:
# Draw the six confusion matrices on a 2x3 grid, one heatmap per subplot.
# Each heatmap still gets its own colorbar and its own axis labels here.
fig = plt.figure(figsize=(6, 3), dpi=300)
fig.subplots_adjust(hspace=0.8, wspace=0.6)
matrices = (CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6)
for position, matrix in enumerate(matrices, start=1):
    # add_subplot makes the new axes current; passing it explicitly to
    # heatmap targets the same axes the implicit plt.gca() lookup would.
    ax = fig.add_subplot(2, 3, position)
    sns.heatmap(matrix, cmap='Blues', annot=True, fmt='d', ax=ax)
plt.show()
My expected output is something like the following
Now how can I have only one single colorbar with a shared x- and y-axis?
Data
Model1,Model2,Model3,Model4,Model5,Model6,Observed
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,Yes,No,Yes,Yes
No,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,No,No,No,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,No,Yes,Yes,Yes,No,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
Yes,Yes,Yes,Yes,Yes,Yes,Yes
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,Yes,No,No,No,Yes,No
No,Yes,No,No,No,Yes,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,Yes,No,Yes,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Yes,Yes,Yes,Yes,Yes,Yes,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
No,No,No,No,No,No,No
Following the logic from this answer, you can loop through the subplots (created with `sharex=True` and `sharey=True` to remove the ticks from plots not on the edges), plot the data, and remove the ylabel if the plot is not in the first column and/or the xlabel if it is not in the last row. To ensure a consistent color scale across all plots, a global `vmin` and `vmax` are computed before the loop.
nrows = 2
ncols = 3
fig, axes = plt.subplots(nrows, ncols, sharex=True, sharey=True)
# Dedicated axes for the single shared colorbar, given as
# [left, bottom, width, height] in figure-fraction coordinates.
cbar_ax = fig.add_axes([0.91, 0.3, 0.03, 0.4])
data = [CF_mod1, CF_mod2, CF_mod3, CF_mod4, CF_mod5, CF_mod6]
# get global min and max to enforce the same colorscale in all plots
vmin = min(d.min().min() for d in data)
vmax = max(d.max().max() for d in data)
for i, (ax, d) in enumerate(zip(axes.flat, data)):
    # Only the first heatmap draws a colorbar (into cbar_ax); because every
    # heatmap uses the same vmin/vmax, that one bar applies to all panels.
    sns.heatmap(d, ax=ax, annot=True,
                vmin=vmin, vmax=vmax,
                cmap="Blues", cbar=(i == 0), cbar_ax=None if i else cbar_ax)
    # remove ylabel if not in the first column
    if i % ncols:
        ax.set_ylabel("")
    # remove xlabel if not in the last row
    if i // ncols + 1 != nrows:
        ax.set_xlabel("")
fig.show()
Result:
For the axes labels, you can also use suplabels and remove the individual axes labels.
# Variant: no per-axes labels; one figure-level label per direction instead.
for ax, d in zip(axes.flat, data):
    sns.heatmap(d, ax=ax, annot=True,
                vmin=vmin, vmax=vmax,
                cmap="Blues", cbar=False)
    # clear the "Predicted"/"Actual" labels seaborn takes from the crosstab
    ax.set_xlabel("")
    ax.set_ylabel("")
# Figure-level labels are set once, after all panels are drawn
# (Figure.supxlabel / supylabel require matplotlib >= 3.4).
fig.supxlabel("Predicted")
fig.supylabel("Actual")
Result:
Edit: To put a title above each plot, simply add `ax.set_title` to the loop.
# Variant: shared colorbar plus a per-panel title.
for i, (ax, d) in enumerate(zip(axes.flat, data)):
    # Only the first heatmap draws the colorbar, into the shared cbar_ax.
    sns.heatmap(d, ax=ax, annot=True,
                vmin=vmin, vmax=vmax,
                cmap="Blues", cbar=(i == 0), cbar_ax=None if i else cbar_ax)
    ax.set_xlabel("")
    ax.set_ylabel("")
    # i is zero-based, so the panels are titled "Model 1" ... "Model 6"
    ax.set_title(f"Model {i+1}")
Result:
Edit: To automate the titles, use the dataframe columns.
# zip stops at the shortest input, so the trailing "Observed" column of
# df_binary is never used as a title. This relies on the model columns
# appearing in the same order as the matrices in `data`.
for i, (ax, d, title) in enumerate(zip(axes.flat, data, df_binary.columns)):
    ...  # placeholder: same heatmap and label handling as in the loop above
    ax.set_title(title)