
Find the prediction of a GAM from the estimated equation


Consider this code for fitting a GAM and then finding the prediction:

    library('mgcv')
    
    a0 = 1
    a1 = 5
    
    f2 = function(x) x^3
    
    ysim1 = function(n = 500) {
      set.seed(10)
      x1 = runif(n)
      x2 = runif(n)
      e = rnorm(n)
      f = f2(x1)
      y = a0 + a1*x1 + f + e
      data.frame(y = y, x1, x2, f2 = f)
    }
    
    df1 = ysim1()
    head(df1)
    
    m = gam(y ~ x1 + s(x2), data = df1)
    pred1 = predict(m)[1]

Here's the prediction:

    > pred1
           1 
    3.846924

How can I reproduce this prediction by hand, using the estimated equation and the output of gam()? Here is what I mean, for a linear regression:

    ysim2 = function(n = 500) {
      set.seed(10)
      x = runif(n)
      e = rnorm(n)
      y = a0 + a1*x + e
      data.frame(y, x)
    }
    
    df2 = ysim2()
    head(df2)
    
    m2 = lm(y ~ x, df2)
    
    pred2 = predict(m2)[1]

The prediction value is:

    > pred2
         1 
    3.5459 

We can also calculate pred2 from the estimated equation and the lm output:

    coeff = coef(m2)
    coeff[[1]] + coeff[[2]]*df2$x[1]
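In matrix form, the same hand calculation is the model (design) matrix times the coefficient vector, which is the form that carries over to gam. Below is a minimal sketch assuming the ysim2() setup above, repeated so it runs on its own; X2, pred_all and pred_first are illustrative names, not part of the original code:

```r
# Sketch: the lm hand calculation written as "design matrix %*% coefficients"
a0 <- 1
a1 <- 5

ysim2 <- function(n = 500) {
  set.seed(10)
  x <- runif(n)
  e <- rnorm(n)
  y <- a0 + a1 * x + e
  data.frame(y, x)
}

df2 <- ysim2()
m2 <- lm(y ~ x, data = df2)

# Design matrix: a column of 1s for the intercept and a column holding x
X2 <- model.matrix(m2)

# Linear predictor for every row: X %*% beta
pred_all <- drop(X2 %*% coef(m2))

# First row only; same as coeff[[1]] + coeff[[2]] * df2$x[1]
pred_first <- sum(X2[1, ] * coef(m2))
print(pred_first)
```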

What is the equivalent code for gam?


Solution

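You want the so-called Lp matrix (linear predictor matrix), returned by predict(m, type = "lpmatrix"). It plays the same role as the model matrix (design matrix) of an LM or GLM: it has one column per model coefficient, including one for each basis function of every smooth, and each row holds those columns evaluated at that observation's covariate values. Multiplying it by the coefficient vector gives the linear predictor; with the default Gaussian family and identity link, that is also the prediction on the response scale. For example: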

    library('mgcv')
    
    a0 <- 1
    a1 <- 5
    
    f2 <- function(x) x^3
    
    ysim1 = function(n = 500) {
      set.seed(10)
      x1 <- runif(n)
      x2 <- runif(n)
      e <- rnorm(n)
      f <- f2(x1)
      y <- a0 + a1*x1 + f + e
      data.frame(y = y, x1, x2, f2 = f)
    }
    
    df1 <- ysim1()
    
    m <- gam(y ~ x1 + s(x2), data = df1)
    
    Xp <- predict(m, type = "lpmatrix")
    
    pred <- drop(Xp %*% coef(m))
    head(pred)
    head(predict(m)) # same
    
Both calls print the same values:

    > head(pred)
           1        2        3        4        5        6 
    3.846924 2.651058 3.361998 4.927578 1.271078 2.139436 
    > head(predict(m)) # same
           1        2        3        4        5        6 
    3.846924 2.651058 3.361998 4.927578 1.271078 2.139436
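To mirror the lm expression coeff[[1]] + coeff[[2]]*df2$x[1] term by term, a row of Xp can be split into the parametric part and the columns belonging to s(x2). The sketch below assumes the model m from above (rebuilt so it runs standalone) and uses the first.para and last.para fields that mgcv stores in each element of m$smooth to locate the smooth's coefficients; fx2_1, pred1_by_hand, nd and pred_new are illustrative names. It also shows the same lpmatrix approach at new covariate values via newdata:

```r
library(mgcv)

a0 <- 1
a1 <- 5
f2 <- function(x) x^3

ysim1 <- function(n = 500) {
  set.seed(10)
  x1 <- runif(n)
  x2 <- runif(n)
  e <- rnorm(n)
  f <- f2(x1)
  y <- a0 + a1 * x1 + f + e
  data.frame(y = y, x1, x2, f2 = f)
}

df1 <- ysim1()
m <- gam(y ~ x1 + s(x2), data = df1)

Xp <- predict(m, type = "lpmatrix")
b <- coef(m)

# Column (= coefficient) indices belonging to the smooth s(x2)
sm <- m$smooth[[1]]
cols <- sm$first.para:sm$last.para

# Observation 1: intercept + slope * x1 + sum(basis functions of x2 * their coefficients)
fx2_1 <- sum(Xp[1, cols] * b[cols])
pred1_by_hand <- b[["(Intercept)"]] + b[["x1"]] * df1$x1[1] + fx2_1
print(pred1_by_hand)

# New covariate values: build the Lp matrix for newdata, then multiply by b
nd <- data.frame(x1 = 0.5, x2 = 0.3)
Xnew <- predict(m, newdata = nd, type = "lpmatrix")
pred_new <- drop(Xnew %*% b)
print(pred_new)
```

This reproduces predict(m) because the model uses the default Gaussian family with identity link. For other links, Xp %*% coef(m) is on the link scale, and m$family$linkinv() maps it to the response scale.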