Tags: machine-learning, julia, regression, loss-function, lasso-regression

How to add an L1 penalty to the loss function for Neural ODEs?


I've been trying to fit a system of differential equations to some data I have, with 18 parameters to fit; ideally, some of these parameters should be zero or be driven to zero. While googling this, one thing I came across was building DE layers into neural networks, and I have found a few GitHub repos with Julia code examples; however, I am new to both Julia and Neural ODEs. In particular, I have been modifying the code from this example:

https://computationalmindset.com/en/neural-networks/experiments-with-neural-odes-in-julia.html

Differences: I have a system of 3 DEs rather than 2, I have 18 parameters, and I import two CSVs with the data to fit instead of generating a toy dataset.

My dilemma: while googling I came across LASSO/L1 regularization, and I hope that by adding an L1 penalty to the cost function I can "zero out" some of the parameters. The problem is that I don't understand how to modify the cost function to incorporate it. My loss function right now is just

function loss_func()
    pred = net()                           # solve the Neural ODE with the current parameters

    sum(abs2, truth[1] .- pred[1, :]) +    # squared error for u[1]
    sum(abs2, truth[2] .- pred[2, :]) +    # squared error for u[2]
    sum(abs2, truth[3] .- pred[3, :])      # squared error for u[3]
end

but I would like to incorporate the L1 penalty into this. For L1 regression, I came across the equation for the cost function: J′(θ; X, y) = J(θ; X, y) + aΩ(θ), where "θ denotes the trainable parameters, X the input... y [the] target labels. a is a hyperparameter that weights the contribution of the norm penalty", and for L1 regularization the penalty is Ω(θ) = ||w||₁ = ∑|w| (source: https://theaisummer.com/regularization/). I understand that the first term on the RHS, J(θ; X, y), is the loss I already have; that a is a hyperparameter I choose, which could be 0.001, 0.1, 1, 100000000, etc.; and that the L1 penalty is the sum of the absolute values of the parameters. What I don't understand is how to add the a∑|w| term to my current function. I want to edit it to be something like this:

function cost_func(lambda)
    pred = net()
    penalty(lambda) = lambda * (sum(abs, param[1]) +
                                sum(abs, param[2]) +
                                sum(abs, param[3]))

    sum(abs2, truth[1] .- pred[1, :]) +
    sum(abs2, truth[2] .- pred[2, :]) +
    sum(abs2, truth[3] .- pred[3, :]) +
    penalty(lambda)
end

where param[1], param[2], and param[3] refer to the parameters for the DEs u[1], u[2], and u[3] that I'm trying to learn. I don't know whether this logic is correct, though, or what the proper way to implement it is, and I also don't know how or where I would access the learned parameters. I suspect that the answer may lie somewhere in this chunk of code

callback_func = function ()
    loss_value = loss_func()
    println("Loss: ", loss_value)
end
fparams = Flux.params(p)    # p holds the trainable parameters
Flux.train!(loss_func, fparams, data, optimizer, cb = callback_func);

but I don't know that for certain, or even how I would use it if it is the answer.
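My best guess at wiring the penalty in, completely untested and assuming p in the chunk above is the flat parameter vector, would be a zero-argument loss that closes over p, so the penalty sees the current parameter values on every call:

lambda = 0.01
function cost_func()
    pred = net()
    sum(abs2, truth[1] .- pred[1, :]) +
    sum(abs2, truth[2] .- pred[2, :]) +
    sum(abs2, truth[3] .- pred[3, :]) +
    lambda * sum(abs, p)    # L1 penalty: a * Σ|w| over all 18 parameters
end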


Solution

  • I've been messing with this and looking at some other NODE implementations (this one in particular), and I have adjusted my cost function so that it is:

    function cost_fnct(param)
        prob = ODEProblem(model, u0, tspan, param)
        prediction = Array(concrete_solve(prob, Tsit5(), p = param, saveat = trange))

        loss = Flux.mae(prediction, data)   # mean absolute error against the data
        penalty = sum(abs, param)           # L1 penalty: sum of |parameter| values
        loss + lambda*penalty
    end;
    

    where lambda is the tuning parameter, using the definition that the L1 penalty is the sum of the absolute values of the parameters. Then, for training:

    lambda = 0.01
    resinit = DiffEqFlux.sciml_train(cost_fnct, p, ADAM(), maxiters = 3000)                    # coarse fit with ADAM
    res = DiffEqFlux.sciml_train(cost_fnct, resinit.minimizer, BFGS(initial_stepnorm = 1e-5))  # refine with BFGS
    

    where p is initially just my parameter "guesses", i.e., a vector of ones with the same length as the number of parameters I am attempting to fit.
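    For concreteness, with the 18 parameters in my system that initial guess is simply:

    p = ones(18)    # one initial "guess" of 1.0 per parameter to fit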

    If you're looking at the first link I had in the original post (here), you can redefine the loss function to add this penalty term and then define lambda before the callback function and subsequent training:

    lambda = 0.01
    callback_func = function ()
        loss_value = cost_fnct()
        println("Loss: ", loss_value)
        println("\nLearned parameters: ", p)
    end
    fparams = Flux.params(p)
    Flux.train!(cost_fnct, fparams, data, optimizer, cb = callback_func);
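    Since the whole point is to drive some parameters to zero, one quick way to check the result afterwards (assuming res.minimizer holds the fitted parameter vector, as returned by sciml_train above) is:

    # The 1e-6 cutoff is arbitrary: L1 shrinks parameters toward zero,
    # but the optimizer may stop just above exact zero.
    fitted = res.minimizer
    zeroed = findall(abs.(fitted) .< 1e-6)   # indices of ~zero parameters
    println("Parameters driven to ~0: ", zeroed)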
    

    None of this, of course, includes any sort of cross-validation or tuning-parameter optimization! I'll go ahead and accept my own response to my question, since it's my understanding that unanswered questions get bumped to encourage answers and I want to avoid clogging the tag; but if anyone has a different solution, or wants to comment, please feel free to do so.
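    As a rough sketch of what tuning-parameter selection could look like, a simple grid search over lambda might work (fit_for_lambda and validation_error are hypothetical helpers, not anything defined above):

    function pick_lambda(lambdas)
        best_lambda, best_err = lambdas[1], Inf
        for lambda in lambdas
            params = fit_for_lambda(lambda)    # hypothetical: refit with this lambda
            err = validation_error(params)     # hypothetical: error on held-out data
            if err < best_err
                best_lambda, best_err = lambda, err
            end
        end
        best_lambda
    end

    best = pick_lambda([1e-4, 1e-3, 1e-2, 1e-1, 1.0])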