My understanding is that `rcs()` (from the `rms` package) uses a truncated-power basis to represent natural (restricted) cubic splines. Alternatively, I could use `ns()` (from the `splines` package), which uses a B-spline basis.
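For reference, the truncated-power form of a restricted cubic spline with knots $t_1 < \dots < t_k$ is usually written, following Harrell's formulation, as:

$$
f(x) = \beta_0 + \beta_1 x + \sum_{j=1}^{k-2} \beta_{j+1}\left[(x - t_j)_+^3 - (x - t_{k-1})_+^3\,\frac{t_k - t_j}{t_k - t_{k-1}} + (x - t_k)_+^3\,\frac{t_{k-1} - t_j}{t_k - t_{k-1}}\right],
\qquad (u)_+ = \max(u, 0).
$$

Each bracketed term is zero for $x \le t_1$, and its cubic and quadratic parts cancel for $x \ge t_k$, so the curve is linear in both tails and has $k$ parameters in total. As far as I understand, `rcs()` may also rescale these columns by a constant, which changes the coefficients but not the fitted curve.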
However, I noticed that the fitted values on the training data and the predictions on new data can be very different between the two (especially when `x` is extrapolated). I'm trying to understand the differences between `rcs()` and `ns()`, and whether I can use the two functions interchangeably.
Simulate some fake non-linear data:
```r
library(tidyverse)
library(splines)
library(rms)
set.seed(100)
xx <- rnorm(1000)
yy <- 10 + 5*xx - 0.5*xx^2 - 2*xx^3 + rnorm(1000, 0, 4)
df <- data.frame(x=xx, y=yy)
```
Fit one model with `ns` and another with `rcs`, using the same knots:
```r
ns_mod <- lm(y ~ ns(x, knots=c(-2, 0, 2)), data=df)
ddist <- datadist(df)
options("datadist" = "ddist")
trunc_power_mod <- ols(y ~ rcs(x, knots=c(-2, 0, 2)), data=df)
```
Compare their in-sample fits (mean squared error) and plot the fitted curves:
```r
mean(ns_mod$residuals^2)
mean(trunc_power_mod$residuals^2)
df$pred_ns <- ns_mod$fitted.values
df$pred_trunc_power <- trunc_power_mod$fitted.values
df_melt <- df %>%
  gather(key="model", value="predictions", -x, -y)
ggplot(df_melt, aes(x=x, y=y)) +
  geom_point(alpha=0.1) +
  geom_line(aes(x=x, y=predictions, group=model, linetype=model))
```
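To see why even the in-sample fits differ, it helps to look at the basis `ns()` actually builds. This sketch uses only the `splines` package and regenerates the same simulated `x`. It shows that `knots = c(-2, 0, 2)` in `ns()` supplies interior knots only, with boundary knots added at `range(x)`, so the spline really has five knots. For comparison, a restricted cubic spline whose only knots are -2, 0 and 2 would have just 2 non-constant columns (`x` and one cubic term).

```r
library(splines)

# Same simulated x as in the question
set.seed(100)
x <- rnorm(1000)

# `knots` in ns() are interior knots; Boundary.knots defaults to range(x)
basis <- ns(x, knots = c(-2, 0, 2))

attr(basis, "knots")           # the interior knots -2, 0, 2
attr(basis, "Boundary.knots")  # identical to range(x)
ncol(basis)                    # 4: five knots in total, minus the intercept column
```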
Generate a test data set whose `x` extends well beyond the observed range, and plot the two models' predictions on it:
```r
newdata <- data.frame(x=seq(-10, 10, 0.1))
pred_ns_new <- predict(ns_mod, newdata=newdata)
pred_trunc_new <- predict(trunc_power_mod, newdata=newdata)
newdata$pred_ns_new <- pred_ns_new
newdata$pred_trunc_new <- pred_trunc_new
newdata_melted <- newdata %>%
  gather(key="model", value="predictions", -x)
ggplot(newdata_melted, aes(x=x, y=predictions, group=model, linetype=model)) +
  geom_line()
```
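The shape of the extrapolated curves follows from natural splines being constrained to be linear beyond their outermost knots. Here is a quick check with `ns()` alone, regenerating the same simulated data: on an evenly spaced grid to the right of the data, the second differences of the predictions are zero up to floating-point error. Both models extrapolate linearly. However, their outermost knots differ (`range(x)` for `ns()`, whatever outer knots `rcs()` ended up using), so the straight tails start at different points and with different slopes.

```r
library(splines)

# Same simulated data as in the question
set.seed(100)
x <- rnorm(1000)
y <- 10 + 5*x - 0.5*x^2 - 2*x^3 + rnorm(1000, 0, 4)
fit <- lm(y ~ ns(x, knots = c(-2, 0, 2)))

# Evenly spaced grid entirely to the right of the largest observed x
grid <- data.frame(x = seq(max(x) + 1, 10, length.out = 20))
pred <- predict(fit, newdata = grid)

# A linear function has zero second differences on an evenly spaced grid
second_diff <- diff(pred, differences = 2)
max(abs(second_diff))
```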
There's a fairly simple explanation: `knots` is not an argument to `rcs()`; it expects the knot locations to be given through the `parms` parameter. Another issue is that the `knots` parameter to `ns()` doesn't specify the "boundary knots", which default to `range(x)`. So to get the same predictions, you need:
```r
trunc_power_mod <- ols(y ~ rcs(x, parms=c(min(x), -2, 0, 2, max(x))), data=df)
```
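As a sanity check of this answer that doesn't need `rms`, here is a sketch using a hand-rolled truncated-power basis. `rcs_basis()` is a hypothetical helper written only for this illustration. It implements Harrell's restricted cubic spline formula without any column rescaling, so its coefficients won't match `rms` output, although the fitted curve should. With all five knots spelled out it reproduces the `ns()` fit both in-sample and on the extrapolation grid. It also shows the reverse mapping: a three-knot restricted spline at -2, 0, 2 corresponds to `ns()` with one interior knot and `Boundary.knots = c(-2, 2)`.

```r
library(splines)

# Hypothetical helper: restricted cubic spline (truncated-power) basis,
# Harrell's formula, no rescaling of the columns.
rcs_basis <- function(x, knots) {
  k <- length(knots)
  pos3 <- function(u) pmax(u, 0)^3  # truncated cube (u)_+^3
  cols <- lapply(seq_len(k - 2), function(j) {
    pos3(x - knots[j]) -
      pos3(x - knots[k - 1]) * (knots[k] - knots[j]) / (knots[k] - knots[k - 1]) +
      pos3(x - knots[k]) * (knots[k - 1] - knots[j]) / (knots[k] - knots[k - 1])
  })
  out <- cbind(x, do.call(cbind, cols))
  colnames(out) <- paste0("b", seq_len(ncol(out)))
  out
}

# Same simulated data as in the question
set.seed(100)
x <- rnorm(1000)
y <- 10 + 5*x - 0.5*x^2 - 2*x^3 + rnorm(1000, 0, 4)

# ns() with interior knots c(-2, 0, 2); boundary knots default to range(x)
fit_ns <- lm(y ~ ns(x, knots = c(-2, 0, 2)))

# Truncated-power basis with all five knots written out, as in the answer
all_knots <- c(min(x), -2, 0, 2, max(x))
fit_tp <- lm(y ~ rcs_basis(x, all_knots))

newdata <- data.frame(x = seq(-10, 10, 0.1))
pred_ns <- predict(fit_ns, newdata = newdata)
pred_tp <- predict(fit_tp, newdata = newdata)

# Same knot set -> same function space -> same least-squares fit
all.equal(unname(fitted(fit_ns)), unname(fitted(fit_tp)))
all.equal(unname(pred_ns), unname(pred_tp))

# Reverse mapping: 3 restricted-spline knots <-> 1 interior + 2 boundary knots
fit_ns3 <- lm(y ~ ns(x, knots = 0, Boundary.knots = c(-2, 2)))
fit_tp3 <- lm(y ~ rcs_basis(x, c(-2, 0, 2)))
all.equal(unname(fitted(fit_ns3)), unname(fitted(fit_tp3)))
```

What matters is the full knot set, not the basis. B-splines tend to be better conditioned numerically, but two natural cubic splines with identical knots span the same functions and give the same fitted curve.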