Need help implementing an integration termination callback in DifferentialEquations.jl
Greetings,
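I have the code:
using DifferentialEquations

# dh/dt = -sqrt(h), written in in-place form
function height(dh, h, p, t)
    dh[1] = -sqrt(h[1])
end
h0 = [14.0]           # Float64 initial value
tspan = (0.0, 10.0)
prob = ODEProblem(height, h0, tspan)   # the model has no parameters, so p is omitted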
but when I try solving the ODE with:
sol = solve(prob)
I get:
"DomainError with -0.019520634518403183: sqrt will only return a complex result if called with a complex argument. Try sqrt(Complex(x))...."
Evidently, h[1] becomes negative at some point during the integration, which causes the error.
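To see how h[1] can become negative even though the exact solution never does, here is a hypothetical single explicit Euler step (the values of h and dt are made up for illustration; the default solver uses higher-order Runge-Kutta stages, but the effect is similar): a trial step taken from a small positive h can overshoot below zero, and evaluating sqrt there throws.

```julia
# Hypothetical illustration: one explicit Euler step for dh/dt = -sqrt(h).
# The step size is chosen by the solver, not by the physics, so a trial
# value can overshoot below zero when h is already small.
f(h) = -sqrt(h)              # right-hand side in scalar form

h = 0.01                     # small but positive current value (made up)
dt = 0.5                     # a trial step size (made up)
h_trial = h + dt * f(h)      # Euler update: 0.01 - 0.5 * 0.1, roughly -0.04

println("trial value: ", h_trial)

# Evaluating the right-hand side at the trial value throws, as in the ODE solve.
try
    f(h_trial)
catch err
    println("caught: ", typeof(err))
end
```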
I tried to mitigate the problem with an integration termination callback, since I only want the solution for h(t) >= 0.
Here's my callback code:
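# event condition: the callback fires when h[1] crosses zero
condition(h, t, integrator) = h[1]
# action at the event: stop the integration
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)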
I thought this would terminate the integration at the timestep when h[1] = 0, but when I then tried:
sol = solve(prob, callback = cb)
I get the same error. I'm new to using these callback features, so clearly there is something I'm not understanding about how to implement them. If you have an idea of what I need to change in my code to get it working, I would appreciate your feedback.
Thanks, Gary
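Why the callback alone does not help can be illustrated with a simplified, hypothetical integrator loop: a fixed-step RK4 written by hand, not the DifferentialEquations.jl implementation (the step size dt = 0.01 and the linear-interpolation event location are simplifications chosen for this sketch). The assumption it models is that the solver evaluates the right-hand side at several stage values inside a step and only checks the ContinuousCallback condition once the step is complete, so a negative stage value raises the DomainError before the callback can act. Clamping the argument of sqrt lets each step finish, after which the sign change of h is detected and the loop stops.

```julia
# Simplified, hypothetical sketch of an integrator loop (NOT the actual
# DifferentialEquations.jl implementation). The right-hand side is evaluated
# at stage points *inside* a step; the event condition is checked only after
# the step is complete.

function rk4_step(f, h, dt)
    k1 = f(h)
    k2 = f(h + dt / 2 * k1)
    k3 = f(h + dt / 2 * k2)
    k4 = f(h + dt * k3)
    return h + dt / 6 * (k1 + 2k2 + 2k3 + k4)
end

# Integrate h' = f(h) over tspan; stop when the condition h crosses zero.
function integrate_with_event(f, h0, tspan; dt = 0.01)
    t, h = tspan[1], h0
    ts, hs = [t], [h]
    while t < tspan[2]
        hnew = rk4_step(f, h, dt)       # may throw here, before any event check
        if h > 0 && hnew <= 0
            # Event inside [t, t + dt]: locate it by linear interpolation
            # (a real solver would use its interpolant and a root finder).
            θ = h / (h - hnew)
            push!(ts, t + θ * dt)
            push!(hs, 0.0)
            break                       # analogue of terminate!(integrator)
        end
        t += dt
        h = hnew
        push!(ts, t)
        push!(hs, h)
    end
    return ts, hs
end

f_raw(h) = -sqrt(h)                 # original right-hand side
f_clamped(h) = -sqrt(max(0.0, h))   # clamped version, as suggested in the FAQ

# The unclamped version fails inside a step, so the event check is never reached.
try
    integrate_with_event(f_raw, 14.0, (0.0, 10.0))
catch err
    println("unclamped: ", typeof(err))
end

# The clamped version completes each step and stops at the zero crossing.
ts, hs = integrate_with_event(f_clamped, 14.0, (0.0, 10.0))
println("clamped: terminated at t = ", ts[end])
```

The unclamped run should report a DomainError, while the clamped run should stop close to 2√14 ≈ 7.4833, the time at which the exact solution reaches zero.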
Your equation has a (real) solution only for those t for which h(t) ≥ 0. Solving the equation analytically by separation of variables:
dh/dt = -√h ⇔ dh/√h = -dt ⇔ 2√h = -t + C.
With h(0) = h₀ = 14 we get C = 2√14. Since √h ≥ 0, -t + C must be greater than or equal to 0, i.e. t ≤ C = 2√14 ≈ 7.483314773547883.
Following this FAQ: https://docs.sciml.ai/DiffEqDocs/stable/basics/faq/ I replaced sqrt(h[1]) in the ODE definition with sqrt(max(0, h[1])). With this change the code is as follows:
using DifferentialEquations, Plots

function height(dh, h, p, t)
    # clamp at 0 so that trial stages with h[1] < 0 do not throw a DomainError
    dh[1] = -sqrt(max(0, h[1]))
end

h0 = [14.0]
tspan = (0.0, 10.0)

# stop the integration when h[1] crosses zero
condition(h, t, integrator) = h[1]
affect!(integrator) = terminate!(integrator)
cb = ContinuousCallback(condition, affect!)

prob = ODEProblem(height, h0, tspan)   # p was unused, so it is omitted
sol = solve(prob, callback = cb)
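sol.t is (the last time appears twice because, by default, a ContinuousCallback saves the state both before and after affect! is applied):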
0.13020240538639785
1.0122592548078624
2.492004954051219
3.9874743468989706
5.525543090227709
6.22154035990045
6.677342684567405
6.9977821858188936
7.176766608936562
7.281444926891483
7.36468081569681
7.415386846800273
7.449319963724896
7.47183932499115
7.479193894248527
7.479193894248527
That is, the last t is close to 2√14 ≈ 7.4833; one more time step would exceed 2√14. Plotting sol, the plot will display a solution extended to [0, 10]:
plt1 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
but with:
plt2 = plot(sol.t, getindex.(sol.u, 1), xlabel="Time (t)", ylabel="y(t)",
framestyle=:box, size=(400,300), legend=false)
we get the solution of the given ODE.
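The numbers above can be cross-checked against the closed-form solution. From 2√h = -t + 2√14 it follows that h(t) = (√14 - t/2)² for 0 ≤ t ≤ 2√14. The short sketch below is plain Julia with no packages; the last entry of sol.t is copied by hand from the output above, and the finite-difference check at t = 3.0 is an arbitrary choice for illustration.

```julia
# Closed-form solution from the derivation above (valid for 0 <= t <= 2*sqrt(14)):
#   2*sqrt(h) = -t + 2*sqrt(14)   =>   h(t) = (sqrt(14) - t/2)^2
h_exact(t) = (sqrt(14.0) - t / 2)^2
t_star = 2 * sqrt(14.0)                 # exact time at which h reaches 0

# Sanity check: h_exact satisfies dh/dt = -sqrt(h) at an interior point
# (central finite difference for the derivative).
t, δ = 3.0, 1e-6
residual = (h_exact(t + δ) - h_exact(t - δ)) / (2δ) + sqrt(h_exact(t))

# Compare with the last entry of sol.t reported above.
t_end = 7.479193894248527
gap = t_star - t_end                    # how early the callback stopped the integration
h_at_end = h_exact(t_end)               # exact h at that time

println("t* = ", t_star, ", gap = ", gap, ", exact h(t_end) = ", h_at_end, ", residual = ", residual)
```

With these values the gap is about 0.0041 and the exact h at the termination time is about 4.2e-6, so the callback stopped the integration where h was already essentially zero.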