
Use nlm function to minimize the MSE of a nonlinear regression function


I have a dataset of chicken weights as a function of time. I would like to predict the weight using the following Gompertz equation:

$$P(t) = \alpha \exp\left(\ln\left(\frac{p_0}{\alpha}\right) e^{-\beta t}\right)$$

where I set p0 = 40 and alpha = 2500, and I need to estimate beta. To do that, I tried to minimize the MSE between the observed and predicted weights using R's nlm function.
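Evaluating the R function at a few points is a quick way to check that it matches this equation. The sketch below uses the same gompertz_model() definition as the question (p0 = 40, alpha = 2500). The beta values are arbitrary and chosen only for illustration. It shows that the curve equals p0 at t = 0. It approaches alpha as t grows when beta > 0, and stays flat at p0 when beta = 0.

```r
# Gompertz curve as defined in the question: p0 = 40, alpha = 2500
gompertz_model <- function(times, beta) {
  p0 <- 40
  alpha <- 2500
  alpha * exp(log(p0 / alpha) * exp(-beta * times))
}

# At t = 0 the inner exp(-beta * t) equals 1, so the curve starts at p0
start_weight <- gompertz_model(0, beta = 0.03)

# For beta > 0 the inner exp() tends to 0 as t grows, so the curve approaches alpha
late_weight <- gompertz_model(1000, beta = 0.03)

# With beta = 0 there is no growth: the curve stays at p0 for every t
flat_curve <- gompertz_model(c(0, 10, 100), beta = 0)

print(c(start_weight, late_weight))
print(flat_curve)
```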

gompertz_model <- function(times, beta) {
  p0 <- 40
  alpha <- 2500
  
  alpha * exp(log(p0 / alpha) * exp(-beta * times))
}
mean_squared_error <- function(param, model, times) {
  true_weights <- chicken$weight
  predicted_weights <- model(times, param)

  return(mean((true_weights - predicted_weights)^2))
}

I tried several starting values from 0 to 1 in steps of 0.1, but I think I get the wrong answer: nlm gives an estimate of beta = -999, and the MSE at that value is very large. Am I using nlm the right way?

find_best_param <- function(model, params) {
  res <- nlm(mean_squared_error, p = params, model, chicken$time)
  return(res$estimate)
}

params <- seq(from = 0, to = 1, by = 0.1)
optim_B <- find_best_param(gompertz_model, 1)

# = -999 => gives an MSE of 19736.45
optim_B 

Solution

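  • I can see two problems in your analysis. First, the p argument of nlm should be a single number here: it is the starting point for beta. nlm treats a vector p as several parameters to optimize jointly, not as several starting points to try. Second, you actually pass only 1 as the starting value, so you always get the same result.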
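The first point can be illustrated with a minimal sketch using two hypothetical toy objectives, f and g, which are not from the question. nlm() optimizes over every element of p at once. To try several starting values for one parameter, you need a separate call for each start.

```r
# Hypothetical toy objective with two parameters, minimum at (1, 2)
f <- function(x) sum((x - c(1, 2))^2)

# p of length 2 means "two parameters", not "two starting points"
res <- nlm(f, p = c(0, 0))
print(res$estimate)  # one estimate per parameter

# Hypothetical one-parameter objective, minimum at 3
g <- function(b) (b - 3)^2

# Several starting values for a single parameter: one nlm() call per start
starts <- c(0, 0.5, 1)
ests <- sapply(starts, function(s) nlm(g, p = s)$estimate)
print(ests)
```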

    I tried to reproduce your analysis using the ChickWeight dataset (it ships with base R's datasets package and originally came from nlme):

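    library(nlme)  # not strictly required: ChickWeight is available in base R
    
    # Gompertz growth curve with p0 and alpha fixed; only beta is estimated
    gompertz_model <- function(times, beta) {
      p0 <- 40
      alpha <- 2500
      alpha * exp(log(p0 / alpha) * exp(-beta * times))
    }
    
    # MSE between the observed weights and the predictions for a given beta
    mean_squared_error <- function(param, model, times) {
      true_weights <- ChickWeight$weight
      predicted_weights <- model(times, param)
      return(mean((true_weights - predicted_weights)^2))
    }
    
    # Run nlm() from ONE starting value and return the estimate; model and
    # times are passed through nlm's ... argument to mean_squared_error
    find_best_param <- function(model, params) {
      res <- nlm(mean_squared_error,
                 p = params,
                 model,
                 ChickWeight$Time)
      return(res$estimate)
    }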
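find_best_param() returns only res$estimate, so it hides the rest of what nlm() reports. The sketch below uses the same model and loss on the built-in ChickWeight data. It keeps the whole result for two starting values: 1, as in the question, and 0.01. You can then compare the final MSE ($minimum), the gradient and the termination code ($code, whose values are documented in ?nlm).

```r
# Same model and loss as above, on the built-in ChickWeight data
gompertz_model <- function(times, beta) {
  p0 <- 40
  alpha <- 2500
  alpha * exp(log(p0 / alpha) * exp(-beta * times))
}
mean_squared_error <- function(param, model, times) {
  mean((ChickWeight$weight - model(times, param))^2)
}

# Keep the full nlm() result instead of only $estimate
res_bad  <- nlm(mean_squared_error, p = 1,    gompertz_model, ChickWeight$Time)
res_good <- nlm(mean_squared_error, p = 0.01, gompertz_model, ChickWeight$Time)

# $minimum is the MSE at the estimate; compare it across starting values
print(res_bad[c("estimate", "minimum", "gradient", "code")])
print(res_good[c("estimate", "minimum", "gradient", "code")])
```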
    

    Then I checked how sensitive the result is to the starting value by trying several of them:

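    library(purrr)
    library(dplyr)
    # Starting values: a fine grid on [0, 0.1] and a coarse grid on [0.1, 1]
    # (0.1 appears in both, which is why rows 11 and 12 below are identical)
    params <- c(seq(from = 0, to = 0.1, by = 0.01),
                seq(from = 0.1, to = 1, by = 0.1))
    # Run nlm() from each starting value, then compute the MSE at each estimate
    data.frame(opt = map_dbl(params,
                             ~find_best_param(gompertz_model, .x)),
               start = params) |>
        mutate(mse = map_dbl(opt,
                             mean_squared_error,
                             gompertz_model,
                             ChickWeight$Time))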
    
    ##>              opt start       mse
    ##> 1     0.02598375  0.00  1486.403
    ##> 2     0.02598375  0.01  1486.403
    ##> 3     0.02598375  0.02  1486.403
    ##> 4     0.02598375  0.03  1486.403
    ##> 5     0.02598411  0.04  1486.403
    ##> 6   -39.95910231  0.05 19736.448
    ##> 7   -92.14594360  0.06 19736.448
    ##> 8   -92.14066069  0.07 19736.448
    ##> 9  -211.26161107  0.08 19736.448
    ##> 10 -211.25607875  0.09 19736.448
    ##> 11 -211.25105141  0.10 19736.448
    ##> 12 -211.25105141  0.10 19736.448
    ##> 13 -999.80000000  0.20 19736.448
    ##> 14 -999.70000000  0.30 19736.448
    ##> 15 -999.60000000  0.40 19736.448
    ##> 16 -999.50000000  0.50 19736.448
    ##> 17 -999.40000000  0.60 19736.448
    ##> 18 -999.30000000  0.70 19736.448
    ##> 19 -999.20000000  0.80 19736.448
    ##> 20 -999.10000000  0.90 19736.448
    ##> 21 -999.00000000  1.00 19736.448
    

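    The results suggest that, for this dataset, the optimization is very sensitive to the starting value. Starting values in [0, 0.04] all converge to beta ≈ 0.026 (MSE ≈ 1486). Starting values of 0.05 or more end at a negative beta with a much larger MSE (≈ 19736). For a sufficiently negative beta, exp(-beta * times) overflows for every time point after day 0, so the model predicts 0 there. The MSE is then identical for all such beta values, so the gradient is zero and nlm stops. For starting values of 0.2 and above, the estimate is exactly the start minus 1000. This is consistent with a single step of nlm's default maximum step size (stepmax).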
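The flat region can be checked directly. A minimal sketch on the built-in ChickWeight data (mse_at is a hypothetical helper, not from the answer): with a strongly negative beta, the predictions collapse to p0 at day 0 and to 0 at every later time. Two very different negative betas therefore give exactly the same MSE.

```r
gompertz_model <- function(times, beta) {
  p0 <- 40
  alpha <- 2500
  alpha * exp(log(p0 / alpha) * exp(-beta * times))
}

# exp(-beta * times) overflows to Inf for times > 0 when beta = -999,
# so exp(log(p0 / alpha) * Inf) = exp(-Inf) = 0
pred_bad <- gompertz_model(c(0, 2, 10, 21), beta = -999)
print(pred_bad)  # p0 (40) at day 0, then 0

# Hypothetical helper: MSE on ChickWeight as a function of beta alone
mse_at <- function(beta) {
  mean((ChickWeight$weight - gompertz_model(ChickWeight$Time, beta))^2)
}

# Same predictions, hence exactly the same MSE: the loss surface is flat here
print(identical(mse_at(-40), mse_at(-999)))
print(c(flat = mse_at(-999), near_optimum = mse_at(0.026)))
```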

    Finally, a plot of the results (the fitted curve is in black):

    library(ggplot2)
    
    # Fit from a starting value inside the good range
    optim_B <- find_best_param(gompertz_model, 0)
    optim_B
    
    #> [1] 0.02598375
    
    # Observed weights per chick (coloured by diet) plus the fitted curve in black
    ChickWeight |>
        mutate(pred = gompertz_model(Time, optim_B)) |>
        ggplot(aes(x = Time)) +
        geom_line(aes(y = weight, color = Diet, group = Chick)) +
        geom_line(aes(y = pred)) +
        theme_minimal()
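Only beta is estimated, so another option is base R's optimize(). It is not used in the answer above. It is a one-dimensional minimizer restricted to an interval, so beta cannot step out to large negative values. The sketch below uses the built-in ChickWeight data and assumes the optimum lies in [0, 1], as the starting-value scan suggests. mse_beta is a hypothetical helper.

```r
gompertz_model <- function(times, beta) {
  p0 <- 40
  alpha <- 2500
  alpha * exp(log(p0 / alpha) * exp(-beta * times))
}

# Hypothetical helper: MSE on ChickWeight as a function of beta alone
mse_beta <- function(beta) {
  mean((ChickWeight$weight - gompertz_model(ChickWeight$Time, beta))^2)
}

# Bounded search over [0, 1]; returns list(minimum = beta, objective = MSE)
opt <- optimize(mse_beta, interval = c(0, 1))
print(opt$minimum)    # beta minimizing the MSE
print(opt$objective)  # MSE at that beta
```

optimize() only guarantees a local minimum inside the interval, so the interval still has to be chosen with care. If you prefer to keep nlm(), its stepmax argument limits the size of each step.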
    
