
Problems with ODE solver in R


Let's say we have an arbitrary system of ODEs in R that we want to solve, for example a SIR model:

    dS <- -beta * I * S
    dI <-  beta * I * S - gamma * I
    dR <-  gamma * I
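For reference, deSolve's ode() calls the gradient function as func(t, y, parms) and expects it to return a list whose first element is the vector of derivatives, with the same length as the vector of initial values (this is the condition the checkFunc error further down complains about). A minimal base-R sketch of the model above in that form, with constant parameters; the name sir_rhs and the values beta = 3e-6 and gamma = 0.1 are assumptions chosen only for illustration:

```r
# SIR gradient in the form deSolve's ode() expects: func(time, state, parms)
sir_rhs <- function(time, state, parms) {
  with(as.list(c(state, parms)), {
    dS <- -beta * I * S
    dI <-  beta * I * S - gamma * I
    dR <-  gamma * I
    list(c(dS, dI, dR))   # first list element: one derivative per state
  })
}

state0 <- c(S = 99999, I = 1, R = 0)
parms  <- c(beta = 3e-6, gamma = 0.1)   # assumed constant values

res <- sir_rhs(0, state0, parms)
length(res[[1]]) == length(state0)   # TRUE: 3 derivatives for 3 states
sum(res[[1]])                        # approximately 0: S + I + R is conserved
```

With deSolve installed, ode(state0, seq(0, 19), sir_rhs, parms) would integrate this constant-parameter version.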

I want beta and gamma to be time-varying parameters, for example:

    beta_vector <- seq(0.05, 1, by = 0.05)
    gamma_vector <- seq(0.05, 1, by = 0.05)

User @Ben Bolker advised me to use beta <- beta_vector[ceiling(time)] inside the gradient function:

    sir_1 <- function(beta, gamma, S0, I0, R0, times) {
      require(deSolve) # for the "ode" function

      # the differential equations:
      sir_equations <- function(time, variables, parameters) {
        beta <- beta_vector[ceiling(time)]
        gamma <- gamma_vector[ceiling(time)]
        with(as.list(c(variables, parameters)), {
          dS <- -beta * I * S
          dI <-  beta * I * S - gamma * I
          dR <-  gamma * I
          return(list(c(dS, dI, dR)))
        })
      }

      # the parameters values:
      parameters_values <- c(beta = beta, gamma = gamma)

      # the initial values of variables:
      initial_values <- c(S = S0, I = I0, R = R0)

      # solving
      out <- ode(initial_values, times, sir_equations, parameters_values)

      # returning the output:
      as.data.frame(out)
    }

    sir_1(beta = beta, gamma = gamma, S0 = 99999, I0 = 1, R0 = 0, times = seq(0, 19))

When I execute it, I get the following error:

    Error in checkFunc(Func2, times, y, rho) : 
      The number of derivatives returned by func() (1) must equal the length of the initial conditions vector (3)

The problem must lie somewhere here:

    parameters_values <- c(beta = beta, gamma = gamma)

I have tried changing parameters_values to a matrix with two rows (beta in the first, gamma in the second) or with two columns, but it did not work. What do I have to do to make this work?
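Two details of the code above can be checked in isolation, without calling ode() at all. The sketch below reuses the question's beta_vector; the helper names safe_index and shadow_demo are hypothetical and only for illustration. It shows that ceiling(0) is 0, so beta_vector[ceiling(0)] is an empty vector at the first output time, and that with() looks names up in the list it is given before the function's local variables, so a local beta is hidden by a parameter entry with the same name.

```r
beta_vector <- seq(0.05, 1, by = 0.05)

# At the first output time (t = 0), ceiling(0) is 0, and indexing with 0
# returns a zero-length vector instead of the first element
empty_beta <- beta_vector[ceiling(0)]
length(empty_beta)   # 0

# Hypothetical index that starts at 1 for t = 0 (one value per interval
# [k, k + 1)), clamped so that t >= 19 does not run past the end
safe_index <- function(t, n) pmin(floor(t) + 1, n)
beta_vector[safe_index(0, length(beta_vector))]   # 0.05

# with() evaluates its expression in an environment built from the list,
# so an entry named beta hides a local variable called beta
shadow_demo <- function(parameters) {
  beta <- 0.5                      # local value computed inside the function
  with(as.list(parameters), beta)  # returns the list entry, not the local value
}
shadow_demo(c(beta = 99))          # 99
```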


Solution

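  • Your code had several issues. One is that times starts at zero, while indexing with ceiling(time) needs to start at one: ceiling(0) is 0, and beta_vector[0] returns an empty vector. There was also some confusion with parameter names: inside with(as.list(c(variables, parameters)), ...), the entries named beta and gamma in parameters_values take precedence over the local beta and gamma computed from the vectors. Below I show one of several possible approaches, which uses approxfun instead of ceiling. This is more robust, although ceiling also has some advantages. The parameters are then functions that are passed to ode as a list. An even simpler approach would be to use global variables.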

    Another consideration is whether the time-dependent beta and gamma should be linearly interpolated or stepwise. The approxfun function supports both (method = "linear" or method = "constant"); below I use linear interpolation.

    library(deSolve) # for the "ode" function
    
    beta_vector <- seq(0.05, 1, by = 0.05)
    gamma_vector <- seq(0.05, 1, by = 0.05)
    
    sir_1 <- function(f_beta, f_gamma, S0, I0, R0, times) {
    
      # the differential equations; beta and gamma are evaluated at the
      # current time from the functions passed in via 'parameters'
      sir_equations <- function(time, variables, parameters) {
        beta  <- parameters$f_beta(time)
        gamma <- parameters$f_gamma(time)
        with(as.list(variables), {
          dS <- -beta * I * S
          dI <-  beta * I * S - gamma * I
          dR <-  gamma * I
          # include beta and gamma as auxiliary variables for debugging
          return(list(c(dS, dI, dR), beta = beta, gamma = gamma))
        })
      }
      
      # time-dependent parameter functions
      parameters_values <- list(
        f_beta  = f_beta,
        f_gamma = f_gamma
      )
      
      # the initial values of variables
      initial_values <- c(S = S0, I = I0, R = R0)
      
      # solving
      # return the deSolve object as is, not a data.frame, to make plotting easier
      ode(initial_values, times, sir_equations, parameters_values)
    }
    
    times <- seq(0, 19)
    
    # approxfun is a function that returns a function
    f_gamma <- approxfun(x = times, y = gamma_vector, rule = 2)
    f_beta  <- approxfun(x = times, y = beta_vector, rule = 2)
    
    # check how the approxfun functions work
    f_beta(5)
    
    out <- sir_1(f_beta = f_beta, f_gamma = f_gamma, S0 = 99999, I0 = 1, R0 = 0, times = times)
    
    # plot method of class "deSolve": plots states and auxiliary variables
    plot(out)
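To compare the two interpolation choices mentioned above, the sketch below builds both variants from the same vector with base R's approxfun: method = "linear" versus method = "constant". With f = 0, the constant version holds the value from the most recent time point until the next one, which matches a one-value-per-day reading of the vectors. The names f_beta_lin and f_beta_step are illustrative only.

```r
times <- seq(0, 19)
beta_vector <- seq(0.05, 1, by = 0.05)

# Linear interpolation between the values given at integer times;
# rule = 2 returns the nearest end value outside [0, 19]
f_beta_lin <- approxfun(x = times, y = beta_vector,
                        method = "linear", rule = 2)

# Step function: f = 0 keeps the value set at the last time point
f_beta_step <- approxfun(x = times, y = beta_vector,
                         method = "constant", f = 0, rule = 2)

f_beta_lin(2.5)    # about 0.175, halfway between 0.15 and 0.20
f_beta_step(2.5)   # about 0.15, the value set at t = 2
f_beta_step(25)    # 1, the last value is held beyond t = 19
```

Either function can be passed as f_beta (or, built from gamma_vector, as f_gamma) to sir_1 unchanged. Note that stepwise parameters make the right-hand side discontinuous at each integer time, which adaptive solvers usually handle at the cost of extra small steps.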