I am trying to plot a correlation matrix with the values shown in the plot. Instead of a single correlation coefficient, I want to show a range (the confidence interval) on each tile. To do this, I first plot the matrix with a colorbar using matshow, and then individually write the low and high interval values on the plot, using the horizontal alignment options to position them around the tile center. Here is the snippet of that code:
import matplotlib.pyplot as plt
import numpy as np

# fig, ax, data, data2 and data3 (the two interval bounds), the colormap GnRd
# and the font dict hfont are defined earlier in my script.
cax = ax.matshow(data, interpolation='nearest', cmap=GnRd, vmin=-1, vmax=1, alpha=1)
fig.colorbar(cax, ticks=[-1, 0, 1], shrink=0.8)
# ha='left' anchors the text's left edge at the tile center (right half of the tile)
for (i, j), z in np.ndenumerate(data2):
    if i >= j and z > -0.5 and z < 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='left', va='center', size=28, color='black', **hfont)
    if z < -0.5:
        ax.text(j, i, '{:0.2f}'.format(z), ha='left', va='center', size=28, color='black', fontweight='bold', **hfont)
    if z == 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', size=28, color='black', fontweight='bold', **hfont)
# ha='right' anchors the text's right edge at the tile center (left half of the tile)
for (i, j), z in np.ndenumerate(data3):
    if i >= j and z > -0.5 and z < 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='right', va='center', size=28, color='black', **hfont)
    if z < -0.5:
        ax.text(j, i, '{:0.2f}'.format(z), ha='right', va='center', size=28, color='black', fontweight='bold', **hfont)
plt.show()
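As a self-contained sketch of the labeling step, assuming synthetic bound arrays `low`/`high` in place of `data2`/`data3` and the stock `RdYlGn` colormap in place of the custom `GnRd`, both bounds can go into one centered two-line label per tile. This avoids placing two separate texts at the same point with opposite `ha` values:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical stand-ins for data2/data3: symmetric lower and upper bounds
# with 1.0 on the diagonal, like a correlation matrix.
rng = np.random.default_rng(0)
n = 4
base = rng.uniform(-0.9, 0.9, (n, n))
base = (base + base.T) / 2
low = np.clip(base - 0.1, -1, 1)
high = np.clip(base + 0.1, -1, 1)
np.fill_diagonal(low, 1.0)
np.fill_diagonal(high, 1.0)

fig, ax = plt.subplots()
# Colored by the interval midpoint only so the tiles have something to show.
im = ax.matshow((low + high) / 2, cmap="RdYlGn", vmin=-1, vmax=1)
fig.colorbar(im, ticks=[-1, 0, 1], shrink=0.8)

labels = []
for (i, j), lo in np.ndenumerate(low):
    if i < j:
        continue  # label only the lower triangle and the diagonal
    hi = high[i, j]
    # One centered label per tile with the bound pair on two lines;
    # the diagonal is exactly 1, so it is printed once.
    text = f"{lo:.2f}" if i == j else f"{lo:.2f}\n{hi:.2f}"
    weight = "bold" if (lo < -0.5 or i == j) else "normal"
    labels.append(ax.text(j, i, text, ha="center", va="center", fontweight=weight))
```

Because each label is centered on its tile, the font size is the only thing that has to be tuned to the tile size.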
And here is my plot (image not available):
The problem with this plot is that the tile coloring is not correct, and I am not sure how to color a tile for a range of values. There is also wasted space in the tiles, which could be reduced by making the tiles rectangular, but I believe matshow doesn't have that option. I found a few workarounds based on drawing objects, which would probably make my life more complicated. Any help will be appreciated.
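One way to color a tile by a range rather than a single number, close to the object-drawing workarounds mentioned in the question, is to split every tile along its diagonal into two triangles and color each half by one bound. The sketch below assumes hypothetical `low`/`high` arrays and uses a `PolyCollection`, so both halves share one colormap and one colorbar:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.collections import PolyCollection
from matplotlib.colors import Normalize

# Hypothetical lower/upper confidence bounds for a 3x3 correlation matrix.
low = np.array([[1.0, 0.2, -0.7],
                [0.2, 1.0, 0.1],
                [-0.7, 0.1, 1.0]])
high = np.array([[1.0, 0.5, -0.4],
                 [0.5, 1.0, 0.4],
                 [-0.4, 0.4, 1.0]])
n = low.shape[0]

polys, values = [], []
for i in range(n):          # row index (row 0 drawn at the top)
    for j in range(n):      # column index
        x0, x1 = j - 0.5, j + 0.5
        y0, y1 = i - 0.5, i + 0.5  # y0 is the top edge once the y axis is inverted
        # Lower-left triangle shows the lower bound...
        polys.append([(x0, y0), (x0, y1), (x1, y1)])
        values.append(low[i, j])
        # ...upper-right triangle shows the upper bound.
        polys.append([(x0, y0), (x1, y0), (x1, y1)])
        values.append(high[i, j])

coll = PolyCollection(polys, cmap="RdYlGn", norm=Normalize(vmin=-1, vmax=1),
                      edgecolors="face")
coll.set_array(np.asarray(values))

fig, ax = plt.subplots()
ax.add_collection(coll)
ax.set_xlim(-0.5, n - 0.5)
ax.set_ylim(n - 0.5, -0.5)  # put row 0 at the top, as matshow does
fig.colorbar(coll, ax=ax, ticks=[-1, 0, 1], shrink=0.8)
```

The numbers can still be written on top with `ax.text(j, i, ...)`, since the tile centers sit at the same integer coordinates as in matshow.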
I came up with an answer to my own question. Setting aspect='auto' in matshow made the tiles rectangular. For the coloring issue, I decided to simply color each tile by the average of its two bounds:
import matplotlib.pyplot as plt
import numpy as np

# fig, ax, data2, data3, GnRd and hfont are defined earlier, as in the question.
data = (data2 + data3) / 2  # color each tile by the average of the two bounds
# aspect='auto' lets the tiles stretch into rectangles instead of staying square
cax = ax.matshow(data, interpolation='nearest', cmap=GnRd, vmin=-1, vmax=1, alpha=1, aspect='auto')
fig.colorbar(cax, ticks=[-1, 0, 1], shrink=0.8)
for (i, j), z in np.ndenumerate(data2):
    if i >= j and z > -0.5 and z < 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='left', va='center', size=28, color='black', **hfont)
    if z < -0.5:
        ax.text(j, i, '{:0.2f}'.format(z), ha='left', va='center', size=28, color='black', fontweight='bold', **hfont)
    if z == 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='center', va='center', size=28, color='black', fontweight='bold', **hfont)
for (i, j), z in np.ndenumerate(data3):
    if i >= j and z > -0.5 and z < 1:
        ax.text(j, i, '{:0.2f}'.format(z), ha='right', va='center', size=28, color='black', **hfont)
    if z < -0.5:
        ax.text(j, i, '{:0.2f}'.format(z), ha='right', va='center', size=28, color='black', fontweight='bold', **hfont)
plt.show()
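A minimal self-contained sketch of this fix, assuming hypothetical 2x2 bound arrays in place of `data2`/`data3` and the stock `RdYlGn` colormap in place of `GnRd`: the tile color comes from the midpoint of each interval, and `aspect='auto'` lets the tiles follow the figure's shape instead of staying square.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical lower and upper bounds (stand-ins for data2 and data3).
low = np.array([[1.0, 0.2], [0.2, 1.0]])
high = np.array([[1.0, 0.6], [0.6, 1.0]])
mid = (low + high) / 2  # one color per tile: the midpoint of the interval

# A wide figure plus aspect='auto' stretches each tile into a rectangle,
# leaving horizontal room for the two numbers.
fig, ax = plt.subplots(figsize=(8, 4))
im = ax.matshow(mid, cmap="RdYlGn", vmin=-1, vmax=1, aspect="auto")
fig.colorbar(im, ticks=[-1, 0, 1], shrink=0.8)
```

matshow defaults to `aspect='equal'`, which is why the tiles were square before; overriding it is enough, with no object drawing needed. The trade-off is that the color no longer encodes the width of the interval, so the printed bounds have to carry that information.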