I've been trying to fit a system of differential equations to some data. There are 18 parameters to fit; ideally, some of them should be zero or go to zero. While googling, I came across the idea of building DE layers into neural networks and found a few GitHub repos with Julia code examples, but I am new to both Julia and Neural ODEs. In particular, I have been modifying the code from this example:
https://computationalmindset.com/en/neural-networks/experiments-with-neural-odes-in-julia.html
Differences: I have a system of 3 DEs rather than 2, I have 18 parameters, and I import two CSVs with data to fit instead of generating a toy dataset.
My dilemma: while googling, I came across LASSO/L1 regularization, and I hope that by adding an L1 penalty to the cost function I can "zero out" some of the parameters. The problem is that I don't understand how to modify the cost function to incorporate it. My loss function right now is just
    function loss_func()
        pred = net()
        sum(abs2, truth[1] .- pred[1,:]) +
        sum(abs2, truth[2] .- pred[2,:]) +
        sum(abs2, truth[3] .- pred[3,:])
    end
but I would like to incorporate the L1 penalty into this. For L1 regularization, I came across this equation for the cost function: J′(θ; X, y) = J(θ; X, y) + aΩ(θ), "where θ denotes the trainable parameters, X the input... y [the] target labels. a is a hyperparameter that weights the contribution of the norm penalty", and for L1 regularization the penalty is Ω(θ) = ‖w‖₁ = ∑ᵢ |wᵢ| (source: https://theaisummer.com/regularization/). I understand that the first term on the RHS, J(θ; X, y), is the loss I already have, that a is a hyperparameter that I choose (it could be 0.001, 0.1, 1, 100000000, etc.), and that the L1 penalty is the sum of the absolute values of the parameters. What I don't understand is how to add the a∑ᵢ |wᵢ| term to my current function. I want to edit it to be something like so:
    function cost_func(lambda)
        pred = net()
        # L1 penalty: lambda times the sum of absolute parameter values
        penalty = lambda * (sum(abs, param[1]) +
                            sum(abs, param[2]) +
                            sum(abs, param[3]))
        sum(abs2, truth[1] .- pred[1,:]) +
        sum(abs2, truth[2] .- pred[2,:]) +
        sum(abs2, truth[3] .- pred[3,:]) +
        penalty
    end
where param[1], param[2], param[3] refer to the parameters for DEs u[1], u[2], u[3] that I'm trying to learn. I don't know whether this logic is correct or the proper way to implement it, and I also don't know how or where I would access the learned parameters. I suspect that the answer may lie somewhere in this chunk of code:
    callback_func = function ()
        loss_value = loss_func()
        println("Loss: ", loss_value)
    end
    fparams = Flux.params(p)
    Flux.train!(loss_func, fparams, data, optimizer, cb = callback_func);
but I don't know for certain, nor how I would use it if it is the answer.
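As a small illustration of the formula J′ = J + aΩ(θ) quoted above, here is a minimal sketch in plain Julia with no packages. The helper names (`sse`, `l1_penalty`, `penalized_cost`) and all numbers are hypothetical and only stand in for the real predictions and parameters. Note that `abs` does not accept a vector directly, so the penalty uses the `sum(f, x)` form.

```julia
# Hypothetical helper names and toy numbers, for illustration only.

# Sum of squared errors between a target vector and a prediction vector.
sse(truth, pred) = sum(abs2, truth .- pred)

# L1 penalty: sum of the absolute values of all parameters.
l1_penalty(w) = sum(abs, w)

# Penalized cost J' = J + a * Ω(θ).
penalized_cost(truth, pred, w, a) = sse(truth, pred) + a * l1_penalty(w)

truth = [1.0, 2.0, 3.0]
pred  = [1.0, 2.5, 2.0]
w     = [0.5, -2.0, 0.0]

J  = sse(truth, pred)                      # 0.0 + 0.25 + 1.0
Ω  = l1_penalty(w)                         # 0.5 + 2.0 + 0.0
Jp = penalized_cost(truth, pred, w, 0.1)   # J + 0.1 * Ω
```

If the parameters are stored as one flat vector, a single `sum(abs, p)` covers all of them; if they are split into groups, the group sums can simply be added together.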
I've been experimenting with this and looking at some other NODE implementations (this one in particular), and have adjusted my cost function so that it is:
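    function cost_fnct(param)
        # Solve the ODE system with the current parameter values
        prob = ODEProblem(model, u0, tspan, param)
        prediction = Array(concrete_solve(prob, Tsit5(), p = param, saveat = trange))
        # Data-fit term: mean absolute error against the observations
        loss = Flux.mae(prediction, data)
        # L1 penalty: sum of absolute values of all parameters
        penalty = sum(abs, param)
        # lambda is a global defined before training
        loss + lambda*penalty
    end;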
where lambda is the tuning parameter, and the L1 penalty is defined as the sum of the absolute values of the parameters. Then, for training:
    lambda = 0.01
    resinit = DiffEqFlux.sciml_train(cost_fnct, p, ADAM(), maxiters = 3000)
    res = DiffEqFlux.sciml_train(cost_fnct, resinit.minimizer, BFGS(initial_stepnorm = 1e-5))
where p is initially just my parameter "guesses", i.e., a vector of ones with the same length as the number of parameters I am attempting to fit.
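The cost function above reads lambda from a global variable. One possible alternative, sketched below under stated assumptions, is to build the one-argument cost function from a given lambda with a closure. To keep the sketch runnable without any packages, `toy_predict` is a hypothetical closed-form stand-in for the ODE solve, and `mae` is written out by hand as the mean absolute error; `make_cost` is likewise a made-up name.

```julia
# Hypothetical toy stand-in for the ODE solve: a closed-form curve in t.
toy_predict(param, ts) = param[1] .* exp.(-param[2] .* ts) .+ param[3] .* ts

# Mean absolute error, written out so the sketch needs no packages.
mae(pred, data) = sum(abs, pred .- data) / length(data)

# Build a one-argument cost function for a given lambda,
# so lambda does not have to be a global variable.
function make_cost(lambda, ts, data)
    return function (param)
        prediction = toy_predict(param, ts)
        loss = mae(prediction, data)
        penalty = sum(abs, param)      # L1 penalty over all parameters
        loss + lambda * penalty
    end
end

ts     = 0.0:0.5:2.0
p_true = [2.0, 1.0, 0.0]               # the third parameter is truly zero here
data   = toy_predict(p_true, ts)

cost0 = make_cost(0.0, ts, data)       # no penalty
cost  = make_cost(0.01, ts, data)      # with L1 penalty

p0 = ones(3)                           # initial guesses, as in the post
println("unpenalized cost at p0: ", cost0(p0))
println("penalized cost at p0:   ", cost(p0))
```

The returned function takes only the parameter vector, which is the same shape as `cost_fnct(param)` above, so in principle it could be passed to the training call in its place.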
If you're looking at the first link I had in the original post (here), you can redefine the loss function to add this penalty term and then define lambda before the callback function and subsequent training:
    lambda = 0.01
    callback_func = function ()
        loss_value = cost_fnct()
        println("Loss: ", loss_value)
        println("\nLearned parameters: ", p)
    end
    fparams = Flux.params(p)
    Flux.train!(cost_fnct, fparams, data, optimizer, cb = callback_func);
None of this, of course, includes any sort of cross-validation or tuning-parameter optimization! I'll go ahead and accept my own answer, because it's my understanding that unanswered questions get pushed to encourage answers and I want to avoid clogging the tag. If anyone has a different solution or wants to comment, please feel free to do so.
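For what it's worth, choosing lambda could look roughly like the sketch below: fit once per candidate lambda, score each fit with the unpenalized loss on held-out data, and keep the best. This is only a toy under stated assumptions, not the ODE problem. `toy_fit` is a hypothetical stand-in for a full training run; it uses the fact that the per-parameter cost 0.5*(w - z)^2 + lambda*|w| is minimized by soft-thresholding z. All names and numbers are made up.

```julia
# Hypothetical toy "fit": for the cost 0.5*(w - z)^2 + lambda*|w| per parameter,
# the minimizer is soft-thresholding of z (a stand-in for a real training run).
soft_threshold(z, lambda) = sign(z) * max(abs(z) - lambda, 0.0)
toy_fit(z_train, lambda) = soft_threshold.(z_train, lambda)

# Unpenalized loss on held-out data; the penalty is only used while fitting.
validation_loss(w, z_val) = sum(abs2, w .- z_val)

# Try each lambda, fit on the training set, score on the validation set,
# and return the lambda with the lowest validation loss plus all scores.
function select_lambda(lambdas, z_train, z_val)
    scores = [validation_loss(toy_fit(z_train, l), z_val) for l in lambdas]
    best = argmin(scores)
    return lambdas[best], scores
end

z_train = [2.0, 0.1, -0.05]   # noisy toy estimates; two of them should be zero
z_val   = [2.0, 0.0, 0.0]     # held-out toy targets
lambdas = [0.0, 0.1, 1.0]

best_lambda, scores = select_lambda(lambdas, z_train, z_val)
println("best lambda: ", best_lambda)
println("fit at best lambda: ", toy_fit(z_train, best_lambda))
```

In the real problem, `toy_fit` would be replaced by the training call with the penalized cost, and `validation_loss` by the unpenalized loss on time points or datasets that were held out of training.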