Tags: callback, julia, integrator

Changing problem and vector size with a callback during an OrdinaryDiffEq time integration retains the original solution vector


I am using OrdinaryDiffEq to solve a set of ordinary differential equations. During the time integration I would like to extend the solution vector and change the problem. To do so, I use a callback of type DiscreteCallback and reinitialize the integrator with the new problem:

using OrdinaryDiffEq

function modify_integrator!(integrator)
    # Define the extended ODE.
    function growth!(du, u, p, t)
        println("growth!")
        du[1] = 0.1*u[1]
        du[2] = 0.2*u[2]
        du[3] = 0.3*u[3]
    end
    u0 = [integrator.u[1]; integrator.u[2]; 1.0]
    tspan = (integrator.t, 10.0)

    # Define the modified problem.
    prob = ODEProblem(growth!, u0, tspan, callback=callbacks)

    # Set up the integrator for the modified problem.
    iter_old = integrator.iter
    naccept_old = integrator.stats.naccept
    integrator = init(prob, integrator.alg, dt=integrator.dt, callback=integrator.opts.callback)
    integrator.iter = iter_old
    integrator.stats.naccept = naccept_old
end


# Define the ODE.
function growth!(du, u, p, t)
    du[1] = 0.1*u[1]
    du[2] = 0.2*u[2]
end

# Define the initial conditions and time span.
u0 = [1.0; 1.0]
tspan = (0.0, 10.0)

# Define the callbacks.
condition(u, t, integrator) = t > 9
cb = DiscreteCallback(condition, modify_integrator!, save_positions=(false,false))
callbacks = CallbackSet(cb)

prob = ODEProblem(growth!, u0, tspan, callback=callbacks)
sol = solve(prob, Tsit5())

The integrator starts with a u vector of 2 components, and the final result should be a vector of 3 components. However, the result looks as if the modified integrator were ignored entirely and the original problem were solved instead. While debugging, I verified that the modified growth! function is indeed called, but it has no effect on the final solution. It is as if the ODEProblem somehow retained the original problem.

The function reinit! looks like it might help here, but I don't see how to pass it a modified problem.
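The symptom is consistent with how Julia passes arguments: a function receives a reference to the caller's object, so mutating one of its fields is visible to the caller, but assigning a new value to the argument name (as integrator = init(...) does inside modify_integrator!) only rebinds a local variable, and the solver keeps stepping the integrator it passed in. The package-free sketch below illustrates the difference; Holder, rebind! and mutate! are hypothetical names used only for illustration.

```julia
# Hypothetical stand-in for an object the caller keeps a reference to
# (playing the role of the integrator held by the solver).
mutable struct Holder
    x::Int
end

# Rebinding the argument name: only the local variable `h` changes,
# just like `integrator = init(...)` inside the callback.
function rebind!(h::Holder)
    h = Holder(42)
    return nothing
end

# Mutating a field: the caller's object itself changes.
function mutate!(h::Holder)
    h.x = 42
    return nothing
end

h = Holder(1)
rebind!(h)
println(h.x)  # prints 1: the caller's object is untouched
mutate!(h)
println(h.x)  # prints 42
```

The same reasoning applies to the callback: whatever the affect! function does has to act on the integrator object it receives (for example via resize! or by setting its fields), not on a freshly created one.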


Solution

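  • This worked for me. The original attempt fails because integrator = init(...) inside the callback only rebinds a local variable, while the solver keeps stepping the integrator it passed in. I therefore modified the callback so that it creates a second integrator (integrator2) for the new problem, resizes the original integrator, and copies the state, cache and right-hand side of the new integrator into the original one.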

    using OrdinaryDiffEq
    
    # Define the extended ODE.
    function growth_new!(du, u, p, t)
        du[1] = 0.1*u[1]
        du[2] = 0.2*u[2]
        du[3] = 0.3*u[3]
        println("growth_new: t = ", t)
    end
    
    function modify_integrator!(integrator)
        if length(integrator.u) == 2
            u0 = [integrator.u[1]; integrator.u[2]; 1.0]
            tspan = (integrator.t, 60.0)
    
            # Define the modified problem.
            prob = ODEProblem(growth_new!, u0, tspan, callback=callbacks)
    
            # Set up a second integrator for the modified problem.
            integrator2 = init(prob, integrator.alg, dt=integrator.dt, callback=integrator.opts.callback)
            resize!(integrator, 3)
    
            # Copy the new state, cache and right-hand side into the original
            # integrator, which is the one the solver keeps stepping.
            integrator.cache = integrator2.cache
            integrator.dt = integrator2.dt
            integrator.f = integrator2.f
            integrator.iter = integrator2.iter
            integrator.q11 = integrator2.q11
            integrator.qold = integrator2.qold
            integrator.u = integrator2.u
            integrator.uprev = integrator2.uprev
            integrator.uprev2 = integrator2.uprev2
        end
    end
    
    
    # Define the ODE.
    function growth!(du, u, p, t)
        du[1] = 0.1*u[1]
        du[2] = 0.1*u[2]
        println("growth: t = ", t)
    end
    
    # Define the initial conditions and time span.
    u0 = [1.0; 1.0]
    tspan = (0.0, 60.0)
    
    # Define the callbacks.
    condition(u, t, integrator) = t > 6  # modify_integrator! only acts while length(u) == 2
    cb = DiscreteCallback(condition, modify_integrator!, save_positions=(false,false))
    callbacks = CallbackSet(cb)
    
    prob = ODEProblem(growth!, u0, tspan, callback=callbacks)
    sol = solve(prob, Tsit5())
    

    This might not be the cleanest solution, but it works and can be used as a starting point for something more elegant.
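A possibly simpler route, sketched here under assumptions, avoids the second integrator entirely: mutate the integrator the solver is already stepping with resize!, and write one right-hand side that handles both the 2- and 3-component states. This assumes the extended problem can be expressed as a single function that branches on length(u) (here with the question's rates 0.1, 0.2 and 0.3), and it uses tstops so that a step ends exactly at t = 9; extend! is a hypothetical name. The pattern resembles the growing-cell-population example in the DifferentialEquations.jl callback documentation, but it has not been checked against every OrdinaryDiffEq version.

```julia
using OrdinaryDiffEq

# One right-hand side for both phases: the third equation is only
# active once the state has been extended to three components.
function growth!(du, u, p, t)
    du[1] = 0.1 * u[1]
    du[2] = 0.2 * u[2]
    if length(u) == 3
        du[3] = 0.3 * u[3]
    end
    return nothing
end

# Fire once: at t = 9, while the state still has two components.
condition(u, t, integrator) = t >= 9 && length(u) == 2

# Hypothetical affect! function: it mutates the integrator it receives
# instead of rebinding the name to a new integrator.
function extend!(integrator)
    # resize! keeps u[1] and u[2] and resizes the solver's internal caches.
    resize!(integrator, 3)
    # The new entry is not initialized by resize!, so set it explicitly.
    integrator.u[3] = 1.0
    return nothing
end

cb = DiscreteCallback(condition, extend!)

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
prob = ODEProblem(growth!, u0, tspan)

# tstops forces a step to end exactly at t = 9 so the switch happens there.
sol = solve(prob, Tsit5(); callback = cb, tstops = [9.0])

# Inspect the state at the final time.
println(sol.u[end])
```

With this setup, sol.u should hold vectors of length 2 up to t = 9 and of length 3 afterwards, so the saved states have to be accessed per time point.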