
How to use a Loss function in Flux.jl


While reading through the Flux docs, I noticed that many different loss functions are already defined for us to use. I understand that the loss tells us how far we are from the target value, but where do I actually use the loss function in the training loop?


Solution

  • If you are using the built-in train!() function, you can define your loss function and pass it to train! as follows:

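    using Flux

    m = Dense(10, 1)                       # example model (assumed; any Flux model works)
    opt = Descent(0.1)                     # example optimiser (assumed)
    data = [(rand(Float32, 10, 32), rand(Float32, 1, 32))]  # one (x, y) batch of dummy data

    loss(x, y) = Flux.Losses.mse(m(x), y)  # compares the prediction m(x) with the target y
    ps = Flux.params(m)                    # implicit parameters (older Flux API)
    Flux.train!(loss, ps, data, opt)       # calls loss(x, y) for each (x, y) in data
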

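    Here Flux.Losses.mse is Flux's built-in mean squared error, which measures how far the model's prediction m(x) is from the target y. For each (x, y) pair in data, train! evaluates loss(x, y), takes the gradient of that loss with respect to ps, and lets opt update the parameters, so the loss is what drives every training step. You can read more about loss functions in Flux here: https://fluxml.ai/Flux.jl/stable/training/training/#Loss-Functions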
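To make "where does the loss go?" concrete, here is a minimal sketch of a hand-written training loop in the explicit-parameters style used by recent Flux versions (assumed: Flux.setup, Flux.gradient and Flux.update! as described in the current Flux training docs). The toy problem (learning y = 2x), the Dense(1 => 1) model, the learning rate and the number of epochs are illustrative assumptions, not part of the original answer. The loss is evaluated inside Flux.gradient on every batch, and the resulting gradient is what the optimiser uses to update the model.

```julia
using Flux

# Hypothetical toy problem: learn y = 2x from random inputs.
x = rand(Float32, 1, 64)
y = 2 .* x
data = [(x, y)]                              # one (input, target) batch; a DataLoader would also work

model = Dense(1 => 1)                        # tiny linear model (assumed for illustration)
opt_state = Flux.setup(Descent(0.1), model)  # optimiser state for the explicit API

# In the explicit style the loss takes the model as its first argument.
loss(m, x, y) = Flux.Losses.mse(m(x), y)

loss_before = loss(model, x, y)

for epoch in 1:200
    for (xb, yb) in data
        # This is where the loss is used: gradient differentiates it w.r.t. the model's parameters.
        grads = Flux.gradient(m -> loss(m, xb, yb), model)
        # The gradient then drives one optimiser step, updating model in place.
        Flux.update!(opt_state, model, grads[1])
    end
end

loss_after = loss(model, x, y)
```

In these newer versions, Flux.train!(loss, model, data, opt_state) with a loss of the form loss(m, x, y) runs essentially the inner loop above once over data, so the built-in function and a hand-written loop use the loss in the same place.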
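As a quick check on what "distance between m(x) and y" means for mse, this small sketch computes the mean of the squared differences by hand and compares it with Flux.Losses.mse. The prediction and target vectors are made-up numbers chosen only for illustration.

```julia
using Flux

yhat = [1.0, 2.0, 3.0]   # hypothetical model predictions m(x)
y    = [1.0, 2.0, 5.0]   # hypothetical targets

# Squared errors are [0, 0, 4], so their mean is 4/3.
by_hand = sum((yhat .- y) .^ 2) / length(y)

# Flux's built-in mean squared error (aggregates with the mean by default).
builtin = Flux.Losses.mse(yhat, y)
```

Both values come out to 4/3, which is the number train! would differentiate if these were the model's outputs for one batch.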