In R, I would do the following to make a grid of facets with a raster plot in each facet:
# R code
library(ggplot2)

set.seed(20231116) # for reproducible data

# One row per (x, y, z) grid cell. expand.grid() already returns a
# data.frame, and naming its arguments names the columns directly.
DF <- expand.grid(x = 0:7, y = 0:7, z = 0:5, KEEP.OUT.ATTRS = FALSE)
DF$I <- runif(nrow(DF), min = 0, max = 1)
# DF has 8 * 8 * 6 = 384 rows; x varies fastest, then y, then z:
#   row 1 is (x = 0, y = 0, z = 0), row 384 is (x = 7, y = 7, z = 5),
#   and I is a uniform random value in [0, 1].

# One raster panel per z value, three panels per row, all sharing a single
# viridis fill scale whose legend sits below the panels (desired position).
ggplot(DF, aes(x = x, y = y, fill = I)) +
  facet_wrap(~z, ncol = 3) +
  geom_raster() +
  scale_fill_viridis_c() +
  theme(legend.position = "bottom")
How can I do that in Python (using matplotlib and probably seaborn)? I tried it with the following code, but had trouble plotting the images, which I tried to do with `plt.imshow`. As the data has to be reshaped for `plt.imshow`, I guess I need a custom plot function for `g.map`. I tried several things, but had problems with the Axes or the colors, and with using the data in the custom plot function.
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
# order of values different than in R, but that shouldn't matter for plotting
df['I'] = np.random.rand(df.shape[0])
#        x   y   z         I
# 0      0   0   0  0.076338
# 1      0   0   1  0.148386
# 2      0   0   2  0.481053
# ..    ..  ..  ..       ...
# 382    7   7   4  0.144188
# 383    7   7   5  0.700624

# One color scale shared by all facets (like ggplot's single fill scale).
vmin, vmax = df['I'].min(), df['I'].max()


def raster(data, color=None, **kwargs):
    """Draw one facet's 'I' values as an image on the current axes.

    FacetGrid.map_dataframe passes the facet's rows as ``data`` and also
    injects a ``color`` keyword. imshow has no use for it, so it is dropped.
    Remaining keyword arguments (cmap, vmin, vmax, ...) go to plt.imshow.
    """
    # imshow needs a 2-D array: rows are y, columns are x (both sorted).
    grid = data.pivot(index='y', columns='x', values='I')
    # origin='lower' puts y == 0 at the bottom, as in the R raster plot.
    # Pixel centers sit at 0..7, which here coincide with the x/y values.
    plt.imshow(grid, origin='lower', **kwargs)


# col_wrap=3 matches facet_wrap(~z, ncol = 3) in the R code.
g = sns.FacetGrid(df, col='z', col_wrap=3, height=4, aspect=1)
# map_dataframe (not map) hands the whole facet DataFrame to the function,
# which is what a reshape-then-imshow plot needs.
g.map_dataframe(raster, cmap='viridis', vmin=vmin, vmax=vmax)
g.set_axis_labels('x', 'y')

# FacetGrid adds no colorbar for custom image plots, so add one horizontal
# colorbar below all facets. Any facet's image works because all facets
# share vmin/vmax.
g.figure.colorbar(g.axes.flat[0].images[0], ax=g.axes, location='bottom',
                  label='I', shrink=0.6)
plt.show()
Group the data by `'z'` and reshape each group's data with `pandas.DataFrame.pivot` into the correct format for `seaborn.heatmap`.
Set `vmin` and `vmax` to the `min` and `max` of the entire dataset, `vmin=df.I.min()` and `vmax=df.I.max()`, so that every heatmap uses the same color scale.
Create the `fig` and all `axes` with `plt.subplots`.
Alternatively, the figure can be built with `fig = plt.figure()` and by adding subplots with `fig.add_subplot(2, 3, idx)`.
Tested in `python v3.12.0`, `pandas v2.1.2`, `matplotlib v3.8.1`, `seaborn v0.13.0`.
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# sample data
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
np.random.seed(20231116)  # for reproducible data
df['I'] = np.random.rand(df.shape[0])

# one color scale shared by every heatmap; computed once, not per facet
vmin, vmax = df.I.min(), df.I.max()

# create the figure and axes
fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)

# flatten the axes into a 1d iterator for easy access
axes = axes.flat

# add a separate axes for the colorbar, centered below the subplots
cbar_ax = fig.add_axes([0.3, .03, .4, .03])

# enumerate is specifically for adding the colorbar
# zip each group of 'z' data to the appropriate axes
for i, (ax, (z, data)) in enumerate(zip(axes, df.groupby('z'))):
    # pivot data into the correct shape for heatmap (rows: y, columns: x)
    data = data.pivot(index='y', columns='x', values='I')

    # plot the heatmap; only the first one (i == 0) draws the colorbar,
    # into cbar_ax, and the rest skip it
    sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=vmin, vmax=vmax,
                cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))

    # add a title
    ax.set(title=f'Z: {z}')

    # invert the yaxis so y increases upwards, matching the OP's R plot
    ax.invert_yaxis()

plt.show()
The pivoted `data` for `z: 5`:
x 0 1 2 3 4 5 6 7
y
0 0.488408 0.855913 0.339374 0.452842 0.510380 0.690491 0.448773 0.500916
1 0.273653 0.561840 0.860269 0.387470 0.170281 0.718488 0.256749 0.463527
2 0.546085 0.093934 0.273339 0.503968 0.063212 0.537974 0.867814 0.135719
3 0.071505 0.792265 0.919784 0.559663 0.733996 0.032003 0.475792 0.690789
4 0.474310 0.265576 0.841875 0.496676 0.603356 0.328808 0.039460 0.461778
5 0.439142 0.119253 0.842653 0.155213 0.798092 0.093709 0.899745 0.927067
6 0.548373 0.259983 0.295939 0.700694 0.040197 0.679880 0.153048 0.328768
7 0.216977 0.176777 0.238436 0.610802 0.705161 0.614877 0.813430 0.527120
Using `plt.figure` and `fig.add_subplot`, instead of `plt.subplots`:
# one color scale shared by every heatmap; computed once, not per facet
vmin, vmax = df.I.min(), df.I.max()

# group once; the grid shape follows from the number of 'z' values
# (6 groups with 3 columns gives the same 2 x 3 layout as before)
groups = df.groupby('z')
ncols = 3
nrows = -(-groups.ngroups // ncols)  # ceiling division

# create the figure
fig = plt.figure(figsize=(15, 10))

# add a separate axes for the colorbar, centered below the subplots
cbar_ax = fig.add_axes([0.3, .03, .4, .03])

# enumerate is specifically for adding the colorbar and adding an axes
for i, (z, data) in enumerate(groups):
    # pivot data into the correct shape for heatmap (rows: y, columns: x)
    data = data.pivot(index='y', columns='x', values='I')

    # create the axes; subplot indices are 1-based
    ax = fig.add_subplot(nrows, ncols, i + 1)

    # plot the heatmap; only the first one (i == 0) draws the colorbar,
    # into cbar_ax, and the rest skip it
    sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=vmin, vmax=vmax,
                cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))

    # add a title
    ax.set(title=f'Z: {z}')

    # invert the yaxis so y increases upwards, matching the OP's R plot
    ax.invert_yaxis()