I am trying to set individual x-labels for each subplot in a seaborn PairGrid object, but the plot won't update and only shows the x-labels for the bottom-most row of plots.
g = sns.PairGrid(dat, x_vars=inputs, y_vars=outputs, hue='variable')
def scatter_plt(x, y, *a, **kw):
    if x.equals(y):
        kw["color"] = (0, 0, 0, 0)
    plt.scatter(x, y, *a, **kw)
    plt.xticks(rotation=90)
    plt.subplots_adjust(wspace=0.4, hspace=0.4)
g.map(scatter_plt)
I tried the following, but it did not work: the plot looked the same as before.
xlabels, ylabels = [], []
for ax in g.axes[-1, :]:
    xlabel = ax.xaxis.get_label_text()
    xlabels.append(xlabel)
for ax in g.axes[:, 0]:
    ylabel = ax.yaxis.get_label_text()
    ylabels.append(ylabel)
for i in range(len(xlabels)):
    for j in range(len(ylabels)):
        g.axes[j, i].xaxis.set_label_text(xlabels[i])
        g.axes[j, i].yaxis.set_label_text(ylabels[j])
Seaborn sets these internal labels invisible, so you explicitly need to set them visible again.
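To illustrate why setting only the label text leaves the plot unchanged, here is a minimal sketch using plain matplotlib. It assumes that hiding the label by hand is a fair stand-in for what PairGrid does to the inner subplots; the figure layout and the label text are made up for the demonstration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 1, sharex=True)
top = axes[0]

# Assumption: mimic the hidden inner labels by hiding the x label manually.
top.xaxis.label.set_visible(False)

# Changing only the text keeps the label hidden ...
top.xaxis.set_label_text("sepal_length")
text_only_visible = top.xaxis.label.get_visible()

# ... while passing visible=True (a regular Text property) shows it again.
top.set_xlabel("sepal_length", visible=True)
restored_visible = top.xaxis.label.get_visible()

print(text_only_visible, restored_visible)
```

The label text is stored in both cases; only the visibility flag decides whether it is drawn.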
Here is what the full code could look like. Some details have also changed:
- The iris dataset is used for easy reproducibility.
- plt.subplots_adjust(...) only needs to be called once, as it changes the full figure. Instead of plt.subplots_adjust(), plt.tight_layout() is often easier, as it tries to optimize all distances.
- Creating xlabels and ylabels via list comprehension not only makes the code shorter, it also prevents errors and makes things easier to change.
- In Python, loops such as for i, xlabel in enumerate(xlabels) are commonly seen when both the index and the value are needed.

import matplotlib.pyplot as plt
import seaborn as sns
iris = sns.load_dataset('iris')
g = sns.PairGrid(iris, x_vars=iris.columns[0:4], y_vars=iris.columns[0:3], hue='species')
def scatter_plt(x, y, *a, **kw):
    # skip the subplots where a variable would be plotted against itself
    if not x.equals(y):
        plt.scatter(x, y, *a, **kw)
    plt.tick_params(axis='x', rotation=90)
g.map(scatter_plt)
# read the labels seaborn placed on the bottom row and the left column
xlabels = [ax.xaxis.get_label_text() for ax in g.axes[-1, :]]
ylabels = [ax.yaxis.get_label_text() for ax in g.axes[:, 0]]
for i, xlabel in enumerate(xlabels):
    for j, ylabel in enumerate(ylabels):
        # set the text and explicitly make the label visible again
        g.axes[j, i].set_xlabel(xlabel, visible=True)
        g.axes[j, i].set_ylabel(ylabel, visible=True)
plt.tight_layout()  # called once, after all subplots are drawn
plt.show()
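If the same relabeling is needed for several grids, the loop can be wrapped in a small helper. The sketch below is only an illustration: the name show_all_axis_labels is hypothetical, and it is exercised here on a plain matplotlib grid with hand-hidden labels rather than on a real PairGrid. It assumes the bottom row and the left column already carry the labels, as in the code above.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np

def show_all_axis_labels(axes):
    """Copy bottom-row x labels and left-column y labels to every subplot.

    Hypothetical helper; `axes` is a 2D array of matplotlib Axes.
    """
    axes = np.atleast_2d(axes)
    xlabels = [ax.get_xlabel() for ax in axes[-1, :]]
    ylabels = [ax.get_ylabel() for ax in axes[:, 0]]
    for j, row in enumerate(axes):
        for i, ax in enumerate(row):
            # set the text and force the label to be drawn
            ax.set_xlabel(xlabels[i], visible=True)
            ax.set_ylabel(ylabels[j], visible=True)

# Stand-in grid: labels only on the outer subplots, inner labels hidden.
fig, axes = plt.subplots(2, 3)
for i, ax in enumerate(axes[-1, :]):
    ax.set_xlabel(f"x{i}")
for j, ax in enumerate(axes[:, 0]):
    ax.set_ylabel(f"y{j}")
for ax in axes[:-1, :].flat:
    ax.xaxis.label.set_visible(False)
for ax in axes[:, 1:].flat:
    ax.yaxis.label.set_visible(False)

show_all_axis_labels(axes)
fig.tight_layout()
```

With a PairGrid, the call would presumably be show_all_axis_labels(g.axes) after g.map(...), followed by plt.tight_layout().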