
Is there a way to print the loss from Flux.train!?


I'm trying to train a UNet in Julia with the help of Flux.

    Flux.train!(loss, Flux.params(model), train_data_loader, opt)
    batch_loss = loss(train_data, train_targets)

where the loss is

logitcrossentropy

and train_data_loader is

    train_data_loader = DataLoader((train_data |> device, train_targets |> device), batchsize=batch_size, shuffle=true)
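For context, the implicit-parameter form Flux.train!(loss, ps, data, opt) iterates over the loader and calls loss once per mini-batch with that batch's (x, y). So loss has to be a two-argument function of a batch that wraps logitcrossentropy, which itself expects logits and targets. Here is a minimal sketch, assuming a toy Dense model and random Float32 data as hypothetical stand-ins for the UNet and the real dataset:

```julia
using Flux

# Hypothetical stand-ins for the UNet and the real training data.
model = Chain(Dense(4 => 8, relu), Dense(8 => 3))
X = rand(Float32, 4, 64)                    # 64 samples, 4 features each
Y = Flux.onehotbatch(rand(1:3, 64), 1:3)    # one-hot targets, 3 classes

# logitcrossentropy takes (logits, targets), so wrap it in a function of one batch.
loss(x, y) = Flux.Losses.logitcrossentropy(model(x), y)

train_data_loader = Flux.DataLoader((X, Y), batchsize=16, shuffle=true)

# Each iteration yields one (x, y) mini-batch, which is what train! passes to loss.
for (x, y) in train_data_loader
    println("batch size: ", size(x, 2), ", loss: ", loss(x, y))
end
```

This also shows why the line batch_loss = loss(train_data, train_targets) in the question is costly. It runs a second forward pass over the whole training set, not just the current batch.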

I don't understand how to get the loss out of Flux.train! so that I can print it (and is that the validation loss?). An evalcb callback would also trigger a separate call to compute the loss, so it is no different, and I want to avoid that extra computation. What I did instead is call the loss function again, store the result in a variable, and print it per batch. Is there a way to print the loss from Flux.train! itself instead of calling loss again?


Solution

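  • Adding to @Dan's answer: you can also record the loss on the fly by passing the loss function to Flux.train! with do-block syntax and logging each batch's value inside it. The logging is wrapped in ignore_derivatives so that it is excluded from the gradient computation, and the value stored is the one from the forward pass train! already performs, so no extra loss call is needed: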

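    using Flux
    using ChainRulesCore  # provides ignore_derivatives
    
    loss_history = Float32[]
    # The do-block closure becomes train!'s first argument: the per-batch loss.
    Flux.train!(Flux.params(model), train_data_loader, opt) do x, y
        err = loss(x, y)
        # Record the value without differentiating through push!.
        ChainRulesCore.ignore_derivatives() do
            push!(loss_history, err)
        end
        return err
    end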
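Newer Flux releases (0.14 and later, an assumption about your installed version) move away from the implicit Flux.params style used above toward explicit gradients. In that style the most direct way to avoid a second forward pass is a hand-written loop: Flux.withgradient returns the loss value together with the gradient, so the number you print is the one computed for the update. Here is a minimal sketch, again with a hypothetical toy model and random data standing in for the UNet:

```julia
using Flux

# Hypothetical stand-ins for the UNet and the real training data.
model = Chain(Dense(4 => 8, relu), Dense(8 => 3))
X = rand(Float32, 4, 64)
Y = Flux.onehotbatch(rand(1:3, 64), 1:3)
train_data_loader = Flux.DataLoader((X, Y), batchsize=16, shuffle=true)

opt_state = Flux.setup(Adam(1f-3), model)   # explicit optimiser state
loss_history = Float32[]

for epoch in 1:2
    epoch_losses = Float32[]
    for (x, y) in train_data_loader
        # A single forward/backward pass yields both the loss and the gradients.
        l, grads = Flux.withgradient(m -> Flux.Losses.logitcrossentropy(m(x), y), model)
        Flux.update!(opt_state, model, grads[1])
        push!(epoch_losses, l)
    end
    append!(loss_history, epoch_losses)
    println("epoch $epoch: mean batch loss = ", sum(epoch_losses) / length(epoch_losses))
end
```

The per-batch values are training losses measured while the weights are still changing. A validation loss would still need a separate evaluation on held-out data.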