Tags: callback, julia, differential-equations, differentialequations.jl

How can I implement an integration termination callback in DifferentialEquations.jl to solve an ODE?


Need help implementing an integration termination callback in DifferentialEquations.jl.

Greetings,

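I have the following code:

    using DifferentialEquations

    function height(dh, h, p, t)
        dh[1] = -sqrt(h[1])   # dh/dt = -√h
    end

    h0 = [14.0]               # floating-point initial height
    tspan = (0.0, 10.0)

    prob = ODEProblem(height, h0, tspan)   # no parameters are needed, so p is omitted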

but when I try solving the ODE with:

    sol = solve(prob)

I get:

"DomainError with -0.019520634518403183: sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."

Evidently, during the integration h[1] becomes negative, which causes the error.
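To see how this can happen, here is a minimal sketch in plain Julia (no DifferentialEquations.jl) of a single explicit Euler step. The starting height and step size are illustrative assumptions, and the default solver is adaptive and higher order, so this only shows the mechanism: once the step size exceeds √h, the update h + dt·(-√h) lands below zero, and the next right-hand-side evaluation throws the same kind of DomainError.

```julia
# Hypothetical illustration: one explicit Euler step for dh/dt = -sqrt(h).
# This is NOT the adaptive method DifferentialEquations.jl uses by default;
# it only shows how a finite step can carry h below zero.
rhs(h) = -sqrt(h)          # defined only for h >= 0

h = 0.01                   # small but still positive height
dt = 0.5                   # step size larger than sqrt(h) = 0.1
h_next = h + dt * rhs(h)   # 0.01 - 0.5 * 0.1, about -0.04

# The next right-hand-side evaluation needs sqrt of a negative number.
result = try
    rhs(h_next)
catch e
    e
end
println("h_next = ", h_next, ", next evaluation gives ", typeof(result))
```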

I tried mitigating the problem with an integration termination callback, since I only want the solution for h(t) >= 0.

Here's my callback code:

    condition(h, t, integrator) = h[1]
    affect!(integrator) = terminate!(integrator)
    cb = ContinuousCallback(condition, affect!)

I thought this would terminate the integration at the time when h[1] = 0, but when I then tried:

    sol = solve(prob, callback = cb)

I get the same error. I'm new to these callback features, so clearly there is something I'm not understanding about how to implement them. If you have an idea of what I need to change in my code to get it working, I would appreciate your feedback.

Thanks, Gary
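A callback cannot prevent this error on its own, because its condition is only examined after a step has been computed, while the solver evaluates the right-hand side at trial states inside the step. The plain-Julia sketch below models that ordering with a hypothetical fixed-step Heun integrator; heun_step, integrate_until_zero, raw_rhs and clamped_rhs are illustrative names, not DifferentialEquations.jl API, and the real solver is adaptive. With the raw sqrt, a trial state goes negative and throws before the sign check runs; with the clamped right-hand side the loop reaches the sign change and stops near 2√14.

```julia
# Simplified, hypothetical model of "take a step, then check the callback".
# It mimics the ordering only; it is NOT OrdinaryDiffEq's implementation.
raw_rhs(h) = -sqrt(h)                  # throws DomainError for h < 0
clamped_rhs(h) = -sqrt(max(0.0, h))    # clamp negative states to zero

# One step of Heun's method (explicit trapezoidal rule). Like many
# Runge-Kutta methods, it evaluates the right-hand side at a trial
# state inside the step, before any termination check can run.
function heun_step(rhs, h, dt)
    k1 = rhs(h)
    h_trial = h + dt * k1          # Euler predictor, may already be < 0
    k2 = rhs(h_trial)              # raw_rhs fails here when h_trial < 0
    return h + dt / 2 * (k1 + k2)
end

# Fixed-step loop that stops when h changes sign, playing the role of
# condition(h, t, integrator) = h[1] together with terminate!(integrator).
function integrate_until_zero(rhs, h0, dt, tmax)
    t, h = 0.0, h0
    while t < tmax
        h_new = heun_step(rhs, h, dt)
        if h_new <= 0
            # Estimate the crossing time by linear interpolation.
            return t + dt * h / (h - h_new)
        end
        t, h = t + dt, h_new
    end
    return t
end

# With the clamped right-hand side the loop stops near 2*sqrt(14).
t_stop = integrate_until_zero(clamped_rhs, 14.0, 0.01, 10.0)
println("clamped: stopped at t = ", t_stop)

# With the raw sqrt, a trial state goes negative before the check runs.
raw_failed = try
    integrate_until_zero(raw_rhs, 14.0, 0.01, 10.0)
    false
catch e
    e isa DomainError
end
println("raw sqrt raised DomainError: ", raw_failed)
```

The clamp only changes the right-hand side where h < 0, a region the exact solution never visits before reaching zero, so it removes the crash without altering the solution on the interval of interest.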


Solution

Your equation has a positive solution only on a finite time interval. Solving the equation analytically by separation of variables: h' = -√h <=> dh/dt = -√h <=> dh/√h = -dt <=> 2√h = -t + C.

With h₀ = 14 we get C = 2√14, hence 2√h = 2√14 - t, i.e. h(t) = (√14 - t/2)². Because 2√h ≥ 0, -t + C must be greater than or equal to 0, i.e. t ≤ C = 2√14 ≈ 7.483314773547883. Following this FAQ: https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/ I replaced sqrt(h[1]) in the ODE definition by sqrt(max(0, h[1])). With this change the code is as follows:

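    using DifferentialEquations, Plots

    function height(dh, h, p, t)
        # clamp at zero so sqrt is never called with a negative argument
        dh[1] = -sqrt(max(0, h[1]))
    end

    h0 = [14.0]            # floating-point values for u0 and tspan
    p = [0]                # parameter vector (unused by height)
    tspan = (0.0, 10.0)

    # stop the integration when h[1] crosses zero
    condition(h, t, integrator) = h[1]
    affect!(integrator) = terminate!(integrator)
    cb = ContinuousCallback(condition, affect!)

    prob = ODEProblem(height, h0, tspan, p)
    sol = solve(prob, callback = cb)
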

sol.t is:

     0.13020240538639785
     1.0122592548078624
     2.492004954051219
     3.9874743468989706
     5.525543090227709
     6.22154035990045
     6.677342684567405
     6.9977821858188936
     7.176766608936562
     7.281444926891483
     7.36468081569681
     7.415386846800273
     7.449319963724896
     7.47183932499115
     7.479193894248527
     7.479193894248527


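That is, the last t is close to 2√14 ≈ 7.4833: the callback stopped the integration where the computed h[1] reached zero, before the solution could pass 2√14. The final value appears twice because, by default, a ContinuousCallback saves the solution both before and after affect! is applied. Plotting sol, the plot displays sol extended to [0, 10]: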

    plt1 = plot(sol, xlabel="Time (t)", ylabel="y(t)",
                framestyle=:box, size=(400,300), legend=false)


but with:

    plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)", 
              framestyle=:box, size=(400,300), legend=false)
    

we get the solution of the given ODE, plotted only where it exists, i.e. on [0, sol.t[end]].
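The closed-form solution from the derivation, h(t) = (√14 - t/2)² for 0 ≤ t ≤ 2√14, gives a quick check of the numerical result. Below is a minimal plain-Julia sketch comparing the analytic time at which h reaches zero with the last value of sol.t listed above (7.479193894248527, copied from that output); h_exact, T and t_final are illustrative names.

```julia
# Closed-form solution from the derivation in the answer:
# 2*sqrt(h) = -t + C with C = 2*sqrt(h0), i.e. h(t) = (sqrt(h0) - t/2)^2
# for 0 <= t <= T, where T = 2*sqrt(h0) is the time at which h reaches 0.
h0 = 14.0
T = 2 * sqrt(h0)                        # about 7.4833
h_exact(t) = (sqrt(h0) - t / 2)^2       # valid for 0 <= t <= T

# Final time from the sol.t listing of the DifferentialEquations.jl run.
t_final = 7.479193894248527

println("T                = ", T)
println("T - t_final      = ", T - t_final)
println("h_exact(t_final) = ", h_exact(t_final))
```

The gap T - t_final is a few thousandths, consistent with the numerical solution reaching zero slightly before the analytic one under the default tolerances.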