Tags: python, pandas, matplotlib, seaborn, heatmap

How to plot columns with a value and x-y positions as a color grid in subplots


In R, I would do the following to make a grid of facets with a raster plot in each facet:

# R Code

DF <- data.frame(expand.grid(seq(0, 7), seq(0, 7), seq(0, 5)))
names(DF) <- c("x", "y", "z")
DF$I <- runif(nrow(DF), 0, 1)
#      x y z          I
#   1: 0 0 0 0.70252977
#   2: 1 0 0 0.74346071
#  ---                 
# 383: 6 7 5 0.93409337
# 384: 7 7 5 0.14143277
library(ggplot2)
ggplot(DF, aes(x = x, y = y, fill = I)) + 
  facet_wrap(~z, ncol = 3) +
  geom_raster() + 
  scale_fill_viridis_c() +
  theme(legend.position = "bottom") # desired legend position should be bottom

[Figure: output of ggplot + facet_wrap + geom_raster]

How can I do this in Python (using matplotlib and probably seaborn)? I tried the code below, but had trouble plotting the images with plt.imshow. Since the data has to be reshaped for plt.imshow, I guess I need a custom plot function for g.map. I tried several things, but ran into problems with the Axes, with the colors, and with using the data inside the custom plot function.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import itertools

df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))), 
                  columns=['x', 'y', 'z'])
# order of values different than in R, but that shouldn't matter for plotting
df['I'] = np.random.rand(df.shape[0])
#      x  y  z         I
# 0    0  0  0  0.076338
# 1    0  0  1  0.148386
# 2    0  0  2  0.481053
# ..  .. .. ..       ...
# 382  7  7  4  0.144188
# 383  7  7  5  0.700624
g = sns.FacetGrid(df, col='z', col_wrap=2, height=4, aspect=1)
g.map(plt.imshow, color = 'I') # <- plt.imshow does not work here. 
# How can this be corrected (probably with a custom plot function)?
plt.show()

Solution

    import itertools
    
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns
    
    # sample data
    df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))), 
                      columns=['x', 'y', 'z'])
    np.random.seed(20231116)  # for reproducible data
    df['I'] = np.random.rand(df.shape[0])
    
    # create the figure and axes
    fig, axes = plt.subplots(2, 3, figsize=(15, 10), sharex=True, sharey=True)
    
    # flatten the axes into a 1d array for easy access
    axes = axes.flat
    
    # add a separate axes for the colorbar
    cbar_ax = fig.add_axes([0.3, .03, .4, .03])
    
    # enumerate is specifically for adding the colorbar
    # zip each group of 'z' data to the appropriate axes
    for i, (ax, (z, data)) in enumerate(zip(axes, df.groupby('z'))):
    
        # pivot data into the correct shape for heatmap
        data = data.pivot(index='y', columns='x', values='I')
    
        # plot the heatmap
        sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=df.I.min(), vmax=df.I.max(),
                    cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))
    
        # add a title
        ax.set(title=f'Z: {z}')
    
        # seaborn draws y = 0 at the top; invert the y-axis so y = 0 is at the bottom, as in the R plot
        ax.invert_yaxis()
    
    plt.show()
    


    Pivoted data for z: 5 (the value of data after the last loop iteration):

    x         0         1         2         3         4         5         6         7
    y                                                                                
    0  0.488408  0.855913  0.339374  0.452842  0.510380  0.690491  0.448773  0.500916
    1  0.273653  0.561840  0.860269  0.387470  0.170281  0.718488  0.256749  0.463527
    2  0.546085  0.093934  0.273339  0.503968  0.063212  0.537974  0.867814  0.135719
    3  0.071505  0.792265  0.919784  0.559663  0.733996  0.032003  0.475792  0.690789
    4  0.474310  0.265576  0.841875  0.496676  0.603356  0.328808  0.039460  0.461778
    5  0.439142  0.119253  0.842653  0.155213  0.798092  0.093709  0.899745  0.927067
    6  0.548373  0.259983  0.295939  0.700694  0.040197  0.679880  0.153048  0.328768
    7  0.216977  0.176777  0.238436  0.610802  0.705161  0.614877  0.813430  0.527120
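The table above is the long-format data for one z value pivoted into a grid with one row per y and one column per x, which is the layout sns.heatmap (and plt.imshow) expects. A minimal sketch comparing that pivot with a plain NumPy reshape, assuming the same complete 8 x 8 x 6 grid built with itertools.product (the seed is arbitrary, so the values differ from the table):

```python
import itertools

import numpy as np
import pandas as pd

# same layout as the sample data; the seed is an arbitrary choice
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
rng = np.random.default_rng(1)
df['I'] = rng.random(len(df))

sub = df[df['z'] == 5]

# long format -> wide grid: one row per y, one column per x
grid = sub.pivot(index='y', columns='x', values='I')
print(grid.shape)  # (8, 8)

# the same grid via reshape: product() yields rows sorted by x, then y,
# so reshape gives [x, y]; transposing gives [y, x] like the pivot
manual = sub['I'].to_numpy().reshape(8, 8).T
print(np.allclose(grid.to_numpy(), manual))  # True
```

The reshape shortcut only works because every (x, y) pair is present and the rows are sorted by x, then y. pivot does not depend on row order, so it is the safer choice for real data.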
    

Alternatively, the same figure can be built with plt.figure and fig.add_subplot instead of plt.subplots:
    # create the figure and axes
    fig = plt.figure(figsize=(15, 10))
    
    # add a separate axes for the colorbar
    cbar_ax = fig.add_axes([0.3, .03, .4, .03])
    
    # enumerate is specifically for adding the colorbar and adding an axes
    for i, (z, data) in enumerate(df.groupby('z')):
    
        # pivot data into the correct shape for heatmap
        data = data.pivot(index='y', columns='x', values='I')
    
        # create the axes
        ax = fig.add_subplot(2, 3, i+1)
    
        # plot the heatmap
        sns.heatmap(data=data, cmap='viridis', ax=ax, cbar=i == 0, vmin=df.I.min(), vmax=df.I.max(),
                    cbar_ax=None if i else cbar_ax, cbar_kws=dict(location="bottom"))
    
        # add a title
        ax.set(title=f'Z: {z}')
    
        # seaborn draws y = 0 at the top; invert the y-axis so y = 0 is at the bottom, as in the R plot
        ax.invert_yaxis()
    
    plt.show()
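The FacetGrid route from the question can also work if the custom function is passed to g.map_dataframe instead of g.map. map_dataframe hands each facet's rows to the function as the data keyword and makes that facet's Axes current, so the function can pivot the rows and call imshow itself. A minimal sketch under those assumptions; the function name raster is hypothetical, and the extent calculation assumes unit-spaced integer coordinates as in the sample data:

```python
import itertools

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# same sample data as in the question (arbitrary seed)
df = pd.DataFrame(list(itertools.product(range(8), range(8), range(6))),
                  columns=['x', 'y', 'z'])
rng = np.random.default_rng(0)
df['I'] = rng.random(len(df))


def raster(x, y, value, data=None, **kwargs):
    """Hypothetical helper: pivot one facet's rows to a grid and draw it with imshow."""
    kwargs.pop('color', None)  # FacetGrid always passes a color; imshow has no use for it
    grid = data.pivot(index=y, columns=x, values=value)
    xs, ys = grid.columns, grid.index
    # assumes unit-spaced integer coordinates, so each cell spans +/- 0.5
    extent = (xs.min() - 0.5, xs.max() + 0.5, ys.min() - 0.5, ys.max() + 0.5)
    plt.gca().imshow(grid.to_numpy(), origin='lower', extent=extent, **kwargs)


g = sns.FacetGrid(df, col='z', col_wrap=3, height=4, aspect=1)
# a shared vmin/vmax keeps the colours comparable across facets
g.map_dataframe(raster, 'x', 'y', 'I', cmap='viridis',
                vmin=df['I'].min(), vmax=df['I'].max())

# one horizontal colorbar below all facets, built from the first facet's image
g.figure.colorbar(g.axes.flat[0].images[0], ax=g.axes.tolist(),
                  orientation='horizontal')
plt.show()
```

Passing the same vmin and vmax to every panel is what makes a single colorbar valid for all facets, and origin='lower' puts y = 0 at the bottom, as in the R plot.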