Tags: r, ggplot2, tidyr, gam, mgcv

How to visualize GAM results with contour & tile plot (using ggplot2)


I would like to make a contour plot with ggplot2 from the results of a GAM. Below is a detailed explanation of what I want:

#packages
library(mgcv)
library(ggplot2)
library(tidyr)  
#prepare data
df <- data.frame(x = iris$Sepal.Width,
                 y = iris$Sepal.Length,
                 z = iris$Petal.Length)
#fit gam
gam_fit  <- gam(z ~
                  s(x) +
                  s(y),
                data=df,na.action = "na.fail")

To predict z values from gam_fit, I followed the approach at https://drmowinckels.io/blog/2019-11-16-plotting-gamm-interactions-with-ggplot2/

#predict z values
df_pred <- expand_grid(
  x = seq(from=min(df$x), 
              to=max(df$x), 
              length.out = 100),
  y = seq(from=min(df$y), 
              to=max(df$y), 
              length.out = 100)
)
df_pred <- predict(gam_fit, newdata = df_pred, 
                      se.fit = TRUE) %>%  
    as_tibble() %>% 
    cbind(df_pred)
gg <- ggplot() +
  geom_tile(data=df_pred, aes(x=x, y=y, fill = fit)) +
  geom_point(data=df,aes(x=x, y=y))+
  scale_fill_distiller(palette = "YlGnBu")+
  geom_contour(data=df_pred, aes(x=x, y=y, z = fit), colour = "white")
print(gg)

This gives me the plot below:

[Figure: tile and contour plot of the predicted values with the observed points overlaid]

My goal is to remove the tiles and contours in regions where there are no measured x-y points. For example, there are no measured points around the top-right and top-left corners of the plot.

I wonder whether mgcViz can achieve this, but it requires including x and y as an interaction term, as below (I am also not sure how to add the measured points to the resulting figure):

library(mgcViz)
gamm_fit2  <- gam(z ~
                   s(x,y),
                data=df,na.action = "na.fail") #,REML=TRUE
b <- getViz(gamm_fit2)
plot(sm(b, 1))

[Figure: mgcViz plot of the s(x, y) smooth]

I think df_pred may not be the best format for achieving my goal, but I am not sure how else to do this. I would be grateful for any solution using ggplot2.


Solution

  • To get something more akin to how mgcv::plot.gam() and mgcViz produce their plots for something like this, you need to identify pairs of covariate values that lie too far from the support of your data. The reason we might prefer this over, say, clipping the predictions to the convex hull of the observations is that some mild extrapolation beyond the data is probably not much of a concern, even though the spline penalties apply only over the range of the data. From a more pragmatic view, and this is something shown in Anderson's iris data used in the example, there are regions of the covariate space where we would have to interpolate that lie as far from the support of the data as, if not further than, points we might extrapolate to.

    mgcv has a function for doing this called exclude.too.far(), so if you want total control you can do the following, reusing (with slight modifications) code from @jared_mamrot's excellent answer:

    library("dplyr")
    library("tidyr")
    library("ggplot2")
    library("mgcv")
    
    # prepare data
    df <- with(iris, data.frame(x = Sepal.Width,
                                y = Sepal.Length,
                                z = Petal.Length))
    #fit gam
    gam_fit  <- gam(z ~ s(x) + s(y), data = df, method = "REML")
    
    df_new <- with(df, expand_grid(x = seq(from = min(x), to = max(x),
                                           length.out = 100),
                                   y = seq(from = min(y), to = max(y), 
                                           length.out = 100)))
    
    df_pred <- predict(gam_fit, newdata = df_new)
    df_pred <- tibble(fitted = df_pred) |>
      bind_cols(df_new)
    

    Now we can find out which rows of the prediction grid represent covariate pairs that are too far from the support of the original data. exclude.too.far() rescales the pairs of covariates in the prediction grid to the unit square, with [0,0] representing the coordinate (min(x), min(y)) and [1,1] the coordinate (max(x), max(y)). It applies the same transformation to the original covariate data. It then computes the Euclidean distance between each point in the grid (on the unit square) and each row of the observed data (projected onto the unit square).

    Any node in the prediction grid whose nearest observation lies more than dist away is then flagged for exclusion as lying too far from the support of the data. dist is the argument that controls what we mean by "too far". Because dist is specified in terms of the unit square, the largest possible distance between two points on the unit square (its diagonal) is

    r$> dist(data.frame(x = c(0,1), y = c(0,1)))                                    
             1
    2 1.414214
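To make the description above concrete, here is a minimal plain-R sketch of that rule. too_far_sketch() is a hypothetical helper written for illustration, not an mgcv function. It assumes, as the mgcv help page describes, that the rescaling uses the range of the prediction grid and that a node is flagged when its nearest observation is more than dist away. On a grid that spans the range of the data, as here, it should agree with exclude.too.far() apart from possible floating-point ties.

```r
library(mgcv)

# Hypothetical helper (not part of mgcv): a plain-R sketch of the rule
# that exclude.too.far() applies to a prediction grid.
too_far_sketch <- function(gx, gy, dx, dy, dist = 0.1) {
  # Rescale using the grid's range so the grid spans the unit square;
  # the observed data are mapped with the same transformation.
  to_unit_x <- function(v) (v - min(gx)) / (max(gx) - min(gx))
  to_unit_y <- function(v) (v - min(gy)) / (max(gy) - min(gy))
  gxs <- to_unit_x(gx)
  gys <- to_unit_y(gy)
  dxs <- to_unit_x(dx)
  dys <- to_unit_y(dy)
  # Euclidean distance from each grid node to its nearest observation
  nearest <- vapply(seq_along(gxs), function(i) {
    min(sqrt((dxs - gxs[i])^2 + (dys - gys[i])^2))
  }, numeric(1))
  # TRUE means "too far from the data", i.e. exclude this node
  nearest > dist
}

df <- with(iris, data.frame(x = Sepal.Width, y = Sepal.Length))
grid <- expand.grid(x = seq(min(df$x), max(df$x), length.out = 100),
                    y = seq(min(df$y), max(df$y), length.out = 100))

drop_sketch <- too_far_sketch(grid$x, grid$y, df$x, df$y, dist = 0.1)
drop_mgcv <- exclude.too.far(grid$x, grid$y, df$x, df$y, dist = 0.1)

# Proportion of grid nodes on which the two versions agree
mean(drop_sketch == drop_mgcv)
```

The loop over grid nodes is a brute-force nearest-neighbour search, which is fine for a 100 x 100 grid; mgcv does this computation in compiled code.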
    

    The default in plot.gam() (where the argument is called too.far) and, IIRC, in mgcViz is 0.1. For our example:

    drop <- exclude.too.far(df_pred$x, df_pred$y, df$x, df$y, dist = 0.1)
    

    drop is now a logical vector of length nrow(df_pred), with TRUE indicating that the corresponding grid point should be excluded.
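Because dist is on the unit-square scale, one way to get a feel for it is to see what fraction of the grid is flagged at a few values. This sketch uses only exclude.too.far() from mgcv on a grid like the one above; the particular dist values are arbitrary choices for illustration.

```r
library(mgcv)

df <- with(iris, data.frame(x = Sepal.Width, y = Sepal.Length))
grid <- expand.grid(x = seq(min(df$x), max(df$x), length.out = 100),
                    y = seq(min(df$y), max(df$y), length.out = 100))

# Illustrative dist values; sqrt(2) is the diagonal of the unit square,
# so no grid node can be flagged at that value
dists <- c(0.05, 0.1, 0.2, sqrt(2))
frac_dropped <- vapply(dists, function(d) {
  mean(exclude.too.far(grid$x, grid$y, df$x, df$y, dist = d))
}, numeric(1))
names(frac_dropped) <- round(dists, 3)
frac_dropped
```

Larger values of dist can only flag fewer nodes, because a node whose nearest observation is further than a large dist is also further than any smaller one.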

    Using drop we can set fitted to NA for the points we want to exclude:

    df_pred <- df_pred |>
      mutate(fitted = if_else(drop, NA_real_, fitted))
    

    Now we can plot:

    df_pred |>
    ggplot(aes(x = x, y = y, fill = fitted)) +
      geom_tile() +
      geom_point(data = df, aes(x = x, y = y, fill = NULL)) +
      scale_fill_distiller(palette = "YlGnBu") +
      geom_contour(aes(z = fitted, fill = NULL), colour = "white")
    

    producing

    [Figure: resulting plot]
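The answer contrasts this distance rule with clipping the predictions to the convex hull of the observations. A minimal sketch of the hull alternative, assuming grDevices::chull() for the hull vertices and mgcv::in.out() for the point-in-polygon test, lets you cross-tabulate the two rules on the same grid; the off-diagonal cells of the table are grid nodes that one rule keeps and the other drops.

```r
library(mgcv)

df <- with(iris, data.frame(x = Sepal.Width, y = Sepal.Length))
grid <- expand.grid(x = seq(min(df$x), max(df$x), length.out = 100),
                    y = seq(min(df$y), max(df$y), length.out = 100))

# Convex-hull clipping: chull() returns the indices of the hull vertices,
# and in.out() tests which grid nodes fall inside that polygon
hull <- as.matrix(df[chull(df$x, df$y), c("x", "y")])
outside_hull <- !in.out(hull, as.matrix(grid[, c("x", "y")]))

# Distance-based rule used in this answer
far_from_data <- exclude.too.far(grid$x, grid$y, df$x, df$y, dist = 0.1)

# Cross-tabulate the two rules to see where they disagree
table(outside_hull, far_from_data)
```

Either logical vector can then be used in the same if_else() step as drop to blank out the corresponding tiles.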

    You can do this a bit more easily (IMHO) using my gratia package, but the general idea is the same:

    # remotes::install_github("gavinsimpson/gratia") # needs the dev version
    library("gratia")
    library("mgcv")     # gam()
    library("dplyr")    # mutate(), if_else()
    library("ggplot2")
    
    # prepare data
    df <- with(iris, data.frame(x = Sepal.Width,
                                y = Sepal.Length,
                                z = Petal.Length))
    
    # fit model
    gam_fit  <- gam(z ~ s(x) + s(y), data = df, method = "REML")
    
    # prepare a data slice through the covariate space
    ds <- data_slice(gam_fit, x = evenly(x, n = 100), y = evenly(y, n = 100))
    
    # predict
    fv <- fitted_values(gam_fit, data = ds)
    
    # exclude points that are too far
    # (in newer gratia releases this column may be named .fitted instead of fitted)
    drop <- too_far(ds$x, ds$y, df$x, df$y, dist = 0.1)
    fv <- fv |>
      mutate(fitted = if_else(drop, NA_real_, fitted))
    
    # then plot
    fv |>
    ggplot(aes(x = x, y = y, fill = fitted)) +
      geom_tile() +
      geom_point(data = df, aes(x = x, y = y, fill = NULL)) +
      scale_fill_distiller(palette = "YlGnBu") +
      geom_contour(aes(z = fitted, fill = NULL), colour = "white")