Search code examples
rperformance

Faster smsurv function


I am trying to make an R function (much) more efficient. Find a working example below.

    # Baseline survival estimate at every observed time.
    #
    # For each distinct event time t_j the hazard jump is
    #   lambda_j = sum(Status[Time == t_j]) / sum(weight[Time >= t_j]),
    # where weight is w * exp(beta %*% t(X[, -1])) for model "ph" and w for
    # model "aft". The cumulative hazard at Time[i] is the sum of the jumps at
    # event times <= Time[i]; it is set to Inf after the last event time and to
    # 0 before the first one.
    #
    # Arguments:
    #   Time    numeric vector of observed times.
    #   Status  event indicator, one per element of Time (1 = event).
    #   X       covariate matrix; its first column is dropped. Only read
    #           when model is "ph".
    #   beta    coefficients for the remaining columns of X. Only read when
    #           model is "ph".
    #   w       numeric weights, one per element of Time.
    #   model   either "ph" or "aft".
    #
    # Returns list(survival = exp(-cumulative hazard)) with one value per
    # element of Time.
    smsurv <- function(Time, Status, X, beta, w, model) {
      # Fail early with a readable message instead of "object 'temp' not found".
      if (length(model) != 1L || !model %in% c("ph", "aft")) {
        stop("`model` must be either \"ph\" or \"aft\"", call. = FALSE)
      }

      death_point <- sort(unique(Time[which(Status == 1)]))
      if (length(death_point) == 0L) {
        # No events: there are no hazard jumps, so survival stays at 1.
        return(list(survival = rep(1, length(Time))))
      }

      # The at-risk weight does not depend on the event time; compute it once.
      risk_weight <- if (model == "ph") {
        w * drop(exp(beta %*% t(X[, -1, drop = FALSE])))
      } else {
        w
      }

      lambda <- numeric(length(death_point))
      for (i in seq_along(death_point)) {
        at_risk <- sum((Time >= death_point[i]) * risk_weight)
        lambda[i] <- sum(Status * (Time == death_point[i])) / at_risk
      }

      # Preallocate instead of growing the vector one element at a time.
      HHazard <- numeric(length(Time))
      for (i in seq_along(Time)) {
        HHazard[i] <- sum((Time[i] >= death_point) * lambda)
      }
      HHazard[Time > max(death_point)] <- Inf
      HHazard[Time < min(death_point)] <- 0

      list(survival = exp(-HHazard))
    }

# Simulated benchmark data: 50,000 observations.
nr_obs <- 50000

# Observed times, a 0/1 event indicator, and weights equal to the indicator.
Time_input <- rnorm(n = nr_obs, mean = 100, sd = 36)
Status_input <- sample(x = c(0, 1), size = nr_obs, replace = TRUE)
w_input <- Status_input

# Nine simulated covariates, preceded by an intercept column of ones.
n_variables <- 9
X_input <- cbind(
  Intercept = rep(1, nr_obs),
  matrix(rnorm(nr_obs * n_variables), nrow = nr_obs)
)

# One coefficient per covariate (the intercept column is dropped by smsurv).
beta_input <- runif(n = n_variables, min = -1, max = 1)
model_input <- "ph"
output <- smsurv(
  Time = Time_input,
  Status = Status_input,
  X = X_input,
  beta = beta_input,
  w = w_input,
  model = model_input
)

I have already attempted to replace the for-loops with lapply and sapply, but this actually made the function even slower:

    # Apply-based variant of smsurv: baseline survival at every observed time.
    #
    # For each distinct event time z the hazard jump is
    #   sum(Status[Time == z]) / sum(weight[Time >= z]),
    # with weight = w * exp(beta %*% t(X[, -1])) for model "ph" and weight = w
    # for model "aft". The cumulative hazard at each time is the sum of the
    # jumps at event times <= that time, set to Inf after the last event time
    # and to 0 before the first one.
    #
    # Arguments:
    #   Time    numeric vector of observed times.
    #   Status  event indicator, one per element of Time (1 = event).
    #   X       covariate matrix; its first column is dropped. Only read
    #           when model is "ph".
    #   beta    coefficients for the remaining columns of X. Only read when
    #           model is "ph".
    #   w       numeric weights, one per element of Time.
    #   model   either "ph" or "aft".
    #
    # Returns list(survival = exp(-cumulative hazard)).
    smsurv2 <- function(Time, Status, X, beta, w, model) {
      if (length(model) != 1L || !model %in% c("ph", "aft")) {
        stop("`model` must be either \"ph\" or \"aft\"", call. = FALSE)
      }

      death_point <- sort(unique(Time[which(Status == 1)]))
      if (length(death_point) == 0L) {
        # No events: there are no hazard jumps, so survival stays at 1.
        return(list(survival = rep(1, length(Time))))
      }

      # Use the function's own arguments throughout; the earlier version read
      # the global *_input objects, so it only worked for that one data set.
      risk_weight <- if (model == "ph") {
        w * drop(exp(beta %*% t(X[, -1, drop = FALSE])))
      } else {
        w
      }

      lambda <- vapply(
        death_point,
        function(z) sum(Status * (Time == z)) / sum((Time >= z) * risk_weight),
        numeric(1)
      )
      HHazard <- vapply(
        Time,
        function(t) sum((t >= death_point) * lambda),
        numeric(1)
      )
      HHazard[Time > max(death_point)] <- Inf
      HHazard[Time < min(death_point)] <- 0

      list(survival = exp(-HHazard))
    }

# Apply-based variant of smsurv: baseline survival at every observed time.
#
# For each distinct event time dp the hazard jump is
#   sum(Status[Time == dp]) / sum(weight[Time >= dp]),
# with weight = w * exp(beta %*% t(X[, -1])) for model "ph" and weight = w for
# model "aft". The cumulative hazard at each time is the sum of the jumps at
# event times <= that time, set to Inf after the last event time and to 0
# before the first one.
#
# Arguments:
#   Time    numeric vector of observed times.
#   Status  event indicator, one per element of Time (1 = event).
#   X       covariate matrix; its first column is dropped. Only read when
#           model is "ph".
#   beta    coefficients for the remaining columns of X. Only read when
#           model is "ph".
#   w       numeric weights, one per element of Time.
#   model   either "ph" or "aft".
#
# Returns list(survival = exp(-cumulative hazard)).
smsurv3 <- function(Time, Status, X, beta, w, model) {
  if (length(model) != 1L || !model %in% c("ph", "aft")) {
    stop("`model` must be either \"ph\" or \"aft\"", call. = FALSE)
  }

  death_point <- sort(unique(Time[which(Status == 1)]))
  if (length(death_point) == 0L) {
    # No events: there are no hazard jumps, so survival stays at 1.
    return(list(survival = rep(1, length(Time))))
  }

  # "aft" uses the plain weights; previously coxexp was referenced for both
  # models although it is only defined for "ph".
  risk_weight <- switch(
    model,
    ph = w * drop(exp(beta %*% t(X[, -1, drop = FALSE]))),
    aft = w
  )

  # vapply (unlike sapply) always returns a numeric vector.
  lambda <- vapply(
    death_point,
    function(dp) sum(Status * (Time == dp)) / sum((Time >= dp) * risk_weight),
    numeric(1)
  )
  HHazard <- vapply(
    Time,
    function(t) sum((t >= death_point) * lambda),
    numeric(1)
  )
  HHazard[Time > max(death_point)] <- Inf
  HHazard[Time < min(death_point)] <- 0

  list(survival = exp(-HHazard))
}

With this in mind, does anyone have other suggestions that I can try? I recently read about the Rcpp package, but I am not sure how to replace the for loops with C++ code. Any suggestions are very welcome.


Solution

  • Function smsurv2 below is faster, and its results are identical to those of the question's function.

    Here are some of the changes I have made.

    • subset is slow, indexed subsetting is faster;
    • I removed the computation of coxexp from the first for loop, making the loop code simpler; it now always multiplies by coxexp, which can be a vector of 1's, which removes the if tests;
    • seq_along is safer than 1:n;
    • I removed the assignments to Inf and 0 from the second loop; they can be vectorized outside the loop;
    • as.numeric is never needed here, so several function calls are eliminated.
    
    # Faster smsurv: baseline survival at every observed time.
    #
    # For each distinct event time t_j the hazard jump is
    #   lambda_j = sum(Status[Time == t_j]) / sum(weight[Time >= t_j]),
    # with weight = w * coxexp, where coxexp is exp(beta %*% t(X[, -1])) for
    # model "ph" and a vector of 1's for model "aft". The cumulative hazard at
    # Time[i] is the sum of the jumps at event times <= Time[i]; it is set to
    # Inf after the last event time and to 0 before the first one.
    #
    # Arguments:
    #   Time    numeric vector of observed times.
    #   Status  event indicator, one per element of Time (1 = event).
    #   X       covariate matrix; its first column is dropped. Only read
    #           when model is "ph".
    #   beta    coefficients for the remaining columns of X. Only read when
    #           model is "ph".
    #   w       numeric weights, one per element of Time.
    #   model   either "ph" or "aft".
    #
    # Returns list(survival = exp(-cumulative hazard)) with one value per
    # element of Time.
    smsurv2 <- function(Time, Status, X, beta, w, model) {
      # Fail early with a readable message instead of "object 'coxexp' not found".
      if (length(model) != 1L || !model %in% c("ph", "aft")) {
        stop("`model` must be either \"ph\" or \"aft\"", call. = FALSE)
      }

      death_point <- Time[Status == 1] |> unique() |> sort()
      if (length(death_point) == 0L) {
        # No events: there are no hazard jumps, so survival stays at 1.
        return(list(survival = rep(1, length(Time))))
      }

      if (model == "ph") {
        coxexp <- exp(beta %*% t(X[, -1])) |> drop()
      } else {
        coxexp <- rep(1, length(Time))
      }
      # Loop-invariant: the at-risk weight does not depend on the event time.
      risk_weight <- w * coxexp

      lambda <- numeric(length(death_point))
      for (i in seq_along(death_point)) {
        temp <- sum((Time >= death_point[i]) * risk_weight)
        lambda[i] <- sum(Status * (Time == death_point[i])) / temp
      }

      # findInterval() counts, for every time, how many (sorted) event times
      # are <= it; the cumulative hazard is then a lookup into the running sum
      # of the jumps, with a leading 0 for times before the first event.
      events_passed <- findInterval(Time, death_point)
      HHazard <- c(0, cumsum(lambda))[events_passed + 1L]
      HHazard[Time > max(death_point)] <- Inf
      HHazard[Time < min(death_point)] <- 0

      list(survival = exp(-HHazard))
    }
    
    
    # Simulated benchmark data: 50,000 observations.
    nr_obs <- 50000

    # Observed times, a 0/1 event indicator, and weights equal to the indicator.
    Time_input <- rnorm(n = nr_obs, mean = 100, sd = 36)
    Status_input <- sample(x = c(0, 1), size = nr_obs, replace = TRUE)
    w_input <- Status_input

    # Nine simulated covariates, preceded by an intercept column of ones.
    n_variables <- 9
    X_input <- cbind(
      Intercept = rep(1, nr_obs),
      matrix(rnorm(nr_obs * n_variables), nrow = nr_obs)
    )

    beta_input <- runif(n = n_variables, min = -1, max = 1)
    model_input <- "ph"

    # Timing of the question's implementation.
    system.time({
      output <- smsurv(
        Time = Time_input, Status = Status_input, X = X_input,
        beta = beta_input, w = w_input, model = model_input
      )
    })
    #>    user  system elapsed 
    #>   25.08    4.95   33.39

    # Timing of the rewritten implementation on the same inputs.
    system.time({
      output2 <- smsurv2(
        Time = Time_input, Status = Status_input, X = X_input,
        beta = beta_input, w = w_input, model = model_input
      )
    })
    #>    user  system elapsed 
    #>   16.92    1.45   19.86

    # Both implementations must return exactly the same object.
    identical(output, output2)
    #> [1] TRUE
    

    Created on 2023-09-30 with reprex v2.0.2