Search code examples
Tags: julia, ode, differential-equations, flux.jl, differentialequations.jl

How can I access the trained parameters of a Neural ODE in Julia?


I'm trying to fit a Neural ODE to a time series using Julia's DiffEqFlux. Here is my code:

u0 = Float32[2, 0]            # initial state of the reference ODE
train_size = 15               # number of time points sampled for training
tspan_train = (0f0, 0.75f0)   # time window used for training

# In-place right-hand side of the reference ODE that generates the training data:
# du = A' * (u .^ 3) with A = [-0.1 2.0; -2.0 -0.1]. `p` and `t` are unused.
# Returns `du`.
function trueODEfunc(du, u, p, t)
    true_A = [-0.1 2.0; -2.0 -0.1]
    cubed = u .^ 3
    du .= true_A' * cubed
    return du
end

# Sample the reference ODE at `train_size` evenly spaced times in `tspan_train`.
t_train = range(tspan_train[1], tspan_train[2], length = train_size)
prob = ODEProblem(trueODEfunc, u0, tspan_train)
ode_data_train = Array(solve(prob, Tsit5(), saveat = t_train))

# Network used as the right-hand side of the neural ODE: 2 -> 50 (tanh) -> 2.
dudt = Chain(
    Dense(2, 50, tanh),
    Dense(50, 2))
# Parameters of the Chain itself (not referenced by the training code in this snippet).
ps = Flux.params(dudt)
n_ode = NeuralODE(dudt, tspan_train, Tsit5();
                  saveat = t_train, reltol = 1e-7, abstol = 1e-9)

# Inspect the initial, untrained parameter vector of the neural ODE.
n_ode.p

# Solve the neural ODE from `u0` with parameter vector `p`.
predict_n_ode(p) = n_ode(u0, p)

# Squared-error loss against the training data. The prediction is returned as
# well so the training callback can plot it.
function loss_n_ode(p)
    prediction = predict_n_ode(p)
    return sum(abs2, ode_data_train .- prediction), prediction
end

final_p = []  # parameter snapshot recorded at each callback invocation
losses = []   # loss value recorded at each callback invocation

# Training callback: shows the loss and parameters, records a snapshot of each,
# and plots the first state component of the prediction against the data.
cb = function (p, l, pred)
    display(l)
    display(p)
    # Store a copy: the optimizer may reuse and mutate the same parameter array
    # between iterations, which would leave every stored entry identical.
    push!(final_p, copy(p))
    push!(losses, l)
    pl = scatter(t_train, ode_data_train[1, :], label = "data")
    scatter!(pl, t_train, pred[1, :], label = "prediction")
    display(plot(pl))
    # sciml_train expects a Bool from the callback; `false` means keep training.
    return false
end

# `sciml_train` returns a result object instead of writing the trained values back
# into `n_ode.p`; the optimized parameters are available as `res.minimizer`.
res = DiffEqFlux.sciml_train(loss_n_ode, n_ode.p, ADAM(0.05), cb = cb, maxiters = 100)

n_ode.p                    # unchanged: still the initial parameters
trained_p = res.minimizer  # parameters after training

The problem is that calling `n_ode.p` (or `Flux.params(dudt)`) before and after the training function gives me back the same values. I would have expected to receive the latest updated values from the training. That's why I've created an array to gather all parameter values during the training and then access it to get the updated parameters.

Am I doing something wrong in the code? Does the training function automatically update the parameters? If not, how can I enforce it?

Thanks in advance!


Solution

  • `sciml_train` does not update `n_ode.p` in place; it returns a result object that holds the best parameters found. Here's a complete example:

    using DiffEqFlux, OrdinaryDiffEq, Flux, Optim, Plots
    
    u0 = Float32[2, 0]        # initial state of the reference ODE
    datasize = 30             # number of saved time points
    tspan = (0f0, 1.5f0)      # integration time window
    
    # In-place right-hand side of the reference ODE used to generate the data:
    # du = A' * (u .^ 3) with A = [-0.1 2.0; -2.0 -0.1]. `p` and `t` are unused.
    # Returns `du`.
    function trueODEfunc(du, u, p, t)
        true_A = [-0.1 2.0; -2.0 -0.1]
        cubed = u .^ 3
        du .= true_A' * cubed
        return du
    end
    # Training data: the reference ODE solution sampled at `datasize` evenly
    # spaced times in `tspan`.
    t = range(tspan[1],tspan[2],length=datasize)
    prob = ODEProblem(trueODEfunc,u0,tspan)
    ode_data = Array(solve(prob,Tsit5(),saveat=t))
    
    # Right-hand side network: cube the state first (mirroring the u.^3 term in
    # trueODEfunc), then a 2 -> 50 (tanh) -> 2 dense network.
    dudt2 = FastChain((x,p) -> x.^3,
                FastDense(2,50,tanh),
                FastDense(50,2))
    # Neural ODE over `tspan`, saving the solution at the data times `t`.
    n_ode = NeuralODE(dudt2,tspan,Tsit5(),saveat=t)
    
    # Solve the neural ODE from `u0` with parameter vector `p`.
    predict_n_ode(p) = n_ode(u0, p)

    # Squared-error loss against `ode_data`; the prediction is returned too so the
    # callback can plot it.
    function loss_n_ode(p)
        prediction = predict_n_ode(p)
        return sum(abs2, ode_data .- prediction), prediction
    end

    loss_n_ode(n_ode.p) # n_ode.p stores the initial parameters of the neural ODE
    
    # Callback used to monitor training: always shows the current loss and, when
    # `doplot=true`, also plots the first state component of the prediction
    # against the data. Returns `false` on every call.
    cb = function (p, l, pred; doplot = false)
        display(l)
        doplot || return false
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
        return false
    end
    
    # Show the loss for the initial parameter values. `loss_n_ode(...)...` splats the
    # (loss, prediction) tuple into cb's `l` and `pred`; doplot defaults to false,
    # so this call only displays the loss.
    cb(n_ode.p,loss_n_ode(n_ode.p)...)
    
    # Stage 1: up to 300 ADAM iterations starting from the initial parameters.
    # This does not modify n_ode.p; the trained parameters are in res1.minimizer.
    res1 = DiffEqFlux.sciml_train(loss_n_ode, n_ode.p, ADAM(0.05), cb = cb, maxiters = 300)
    cb(res1.minimizer,loss_n_ode(res1.minimizer)...;doplot=true)
    # Stage 2: refine with LBFGS, warm-started from the stage-1 parameters.
    res2 = DiffEqFlux.sciml_train(loss_n_ode, res1.minimizer, LBFGS(), cb = cb)
    cb(res2.minimizer,loss_n_ode(res2.minimizer)...;doplot=true)
    
    # The final result is res2, an Optim.jl result object:
    # res2.minimizer holds the best parameters found
    # res2.minimum is the best loss
    

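    At the end, `sciml_train` returns a result object that holds information about the optimization, including the optimized parameters as `.minimizer`. Use those, for example `n_ode(u0, res2.minimizer)`, rather than `n_ode.p`, which keeps its initial values.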