Tags: performance, optimization, julia, flux, backpropagation

Training is 500x slower than inference for a custom loss function


I am trying to use gradient descent to optimize a matrix (a single-layer neural network) using a custom loss function. The loss is the sum of a Gaussian-kernel maximum mean discrepancy (MMD), evaluated at several kernel bandwidths, plus the L1 norm of the model weights after converting them from dB to linear scale with `invdB`. Training is extremely slow and I don't know why: one pass over the data takes several hundred times longer than training the same model with an MSE loss, or than simply evaluating the custom loss (134.8 s versus 0.28 s and 0.23 s in the timings below). I have gone through the Flux performance tips, but the problem remains.
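
Written out as transcribed from the MWE's code: x_i and y_j run over every entry of a model-output batch W H_b and the matching target batch U_b, M and N count those entries, σ runs over the bandwidths passed as `σs`, and λ is the L1 weight. The function named `mmd` therefore returns the biased estimate of the squared MMD, and the training loss is:

$$
\begin{aligned}
k_\sigma(a,b) &= \exp\!\left(-\frac{(a-b)^2}{2\sigma^2}\right)\\
\widehat{\mathrm{MMD}}^2_\sigma(x,y) &= \frac{1}{M^2}\sum_{i=1}^{M}\sum_{j=1}^{M} k_\sigma(x_i,x_j) - \frac{2}{MN}\sum_{i=1}^{M}\sum_{j=1}^{N} k_\sigma(x_i,y_j) + \frac{1}{N^2}\sum_{i=1}^{N}\sum_{j=1}^{N} k_\sigma(y_i,y_j)\\
\mathcal{L}(W) &= \sum_{\sigma\in\{2,5,10,20,40,80\}} \widehat{\mathrm{MMD}}^2_\sigma\big(W H_b,\,U_b\big) + \lambda \sum_{k,l} 10^{W_{kl}/10}
\end{aligned}
$$

The last term is `norm(invdB.(W), 1)`; the absolute values of the L1 norm drop out because 10^(W_kl/10) is always positive.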

Julia 1.8 with

  [587475ba] Flux v0.13.14
  [2913bbd2] StatsBase v0.33.21
  [37e2e46d] LinearAlgebra
  [9a3f8284] Random

Any help would be appreciated.


Here is the output of my minimal working example (MWE):

❯ julia mwe.jl
[ Info: Train with desired loss function
134.781981 seconds (1.03 G allocations: 115.077 GiB, 22.93% gc time, 18.78% compilation time)
[ Info: Train with MSE loss function
  0.277686 seconds (1.70 M allocations: 84.768 MiB, 99.89% compilation time)
[ Info: Benchmark desired loss function
  0.227674 seconds (396.55 k allocations: 20.628 MiB, 43.10% compilation time)

and here is my MWE:

using Flux
using Flux.Optimise: Adam, train!
using Flux.Data: DataLoader
using LinearAlgebra
using Random: AbstractRNG, default_rng
using StatsBase: sample

function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)

    mmd = zero(T)
    running_total = zero(T)

    for i in 1:M, j in 1:M
        running_total += gaussian_kernel(x[i], x[j]; σ=σ)
    end
    mmd += (running_total / convert(T, M)^convert(T, 2))

    running_total = zero(T)
    for i in 1:M, j in 1:N
        running_total += gaussian_kernel(x[i], y[j]; σ=σ)
    end
    mmd -= (convert(T, 2) / convert(T, M * N) * running_total)

    running_total = zero(T)
    for i in 1:N, j in 1:N
        running_total += gaussian_kernel(y[i], y[j]; σ=σ)
    end
    mmd += (running_total / convert(T, N)^convert(T, 2))

    return mmd
end

function gaussian_kernel(x, y; σ=1)
    return exp(
        -one(typeof(x)) / (oftype(x / 1, 2) * oftype(x / 1, σ)^oftype(x / 1, 2)) *
        abs(x - y)^oftype(x / 1, 2),
    )
end

function mmd_loss(x, x̂; σs=[1])
    return sum(mmd(x, x̂; σ=σ) for σ in σs)
end

function generate_data(rng::AbstractRNG, n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    # Generate a Gaussian random matrix
    H = randn(rng, Float32, n, n_samples) ./ p
    # Set all but p entries in each column to zero
    for h in eachcol(H)
        indices = sample(1:n, n - p; replace=false)
        h[indices] .= 0
    end
    # Rescale
    H /= sqrt(norm(H) / n_samples)

    # Compute the label data
    U = randn(rng, Float32, m, n_samples)
    for u in eachcol(U)
        u .= u / norm(u)
    end
    return Float32.(H), Float32.(U)
end
function generate_data(n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    return generate_data(default_rng(), n_samples, m, n, p)
end

invdB(x) = oftype(x / 1, 10)^(x / oftype(x / 1, 10))

function main()
    model = Dense(16 => 100, identity; bias=false)
    opt_state = Flux.setup(Adam(0.0001f0, (0.9f0, 0.999f0)), model)
    λ = one(eltype(model.weight))

    H, U = generate_data(20, 100, 16, 3)
    dataloader = DataLoader((H, U), batchsize=4)

    @info "Train with desired loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        this_mmd_loss = mmd_loss(m(x), y; σs = [2, 5, 10, 20, 40, 80])
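        # Use `m` (the argument being differentiated), not the captured `model`;
        # otherwise the L1 term's gradient is never propagated to `m`
        this_l1_loss = λ * norm(invdB.(m.weight), 1)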
        this_mmd_loss + this_l1_loss
    end

    @info "Train with MSE loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        Flux.mse(m(x), y)
    end

    @info "Benchmark desired loss function"
    @time for (x, y) in dataloader
        this_mmd_loss = mmd_loss(model(x), y; σs = [2, 5, 10, 20, 40, 80])
        this_l1_loss = λ * norm(invdB.(model.weight), 1)
        this_mmd_loss + this_l1_loss
    end
end

main()

I am optimizing a one-layer "neural network" with a simple custom loss function. The loss is cheap to evaluate in the forward pass, but training with it is roughly 500x slower than evaluating it. I went through the Flux documentation and checked that there is no type instability or unexpected type promotion (e.g., from Float32 to Float64), so I expected the gradient computation to add only a small constant-factor overhead.
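
For a sense of scale, here is the work the loops in `mmd` do during one `train!` call with the MWE's sizes. Because `mmd` uses `length` and linear indexing, every entry of a 100×4 batch counts as one scalar sample, so M = N = 400, and 20 samples in batches of 4 give 5 batches:

$$
\begin{aligned}
\text{kernel evaluations per } \texttt{mmd} \text{ call} &= M^2 + MN + N^2 = 3 \times 400^2 = 480{,}000\\
\text{per batch (6 bandwidths)} &= 6 \times 480{,}000 = 2{,}880{,}000\\
\text{per epoch (5 batches)} &= 5 \times 2{,}880{,}000 = 14{,}400{,}000
\end{aligned}
$$

Each kernel evaluation indexes two scalars, so reverse-mode AD has to record about 28.8 million scalar `getindex` calls per epoch. Assuming, for illustration only, that each of their pullbacks allocates a dense 400-element Float32 array (1,600 bytes), that alone comes to 28.8 × 10^6 × 1,600 B ≈ 46 GB of temporaries: the same order of magnitude as the 115 GiB reported for the training run, while the forward-only loop reported about 20.6 MiB.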


Solution

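  • As @mcabbott pointed out, the answer is in this Julia Discourse thread: https://discourse.julialang.org/t/training-is-500x-slower-than-inference-for-a-custom-loss-function/96283. Training is slow because Zygote, the automatic-differentiation engine behind Flux, creates a very large number of temporary arrays when it differentiates the element-by-element loops in `mmd`: the pullback of each scalar access such as `x[i]` allocates a zero-filled array the size of `x`. The usual remedy is to express the kernel sums as whole-array (broadcast) operations instead of scalar loops.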
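
A sketch of that remedy, keeping the same semantics as the loop version in the question (every matrix entry is one scalar sample; the result is the biased squared-MMD estimate). The names `gaussian_gram`, `mmd_vectorized` and `mmd_loss_vectorized` are hypothetical, and this is not necessarily the code posted in the Discourse thread. Each kernel sum becomes one broadcast over a matrix of pairwise differences, so the AD engine sees a handful of array operations instead of hundreds of thousands of scalar ones.

```julia
# Hypothetical vectorized replacement for the loop-based `mmd` in the question.
# Same semantics: every entry of x and y is one scalar sample, and the result
# is the biased estimate of the squared MMD with a Gaussian kernel.

# Matrix of kernel values k(a_i, b_j) = exp(-(a_i - b_j)^2 / (2σ^2)),
# built with a single broadcast over all pairs (i, j).
function gaussian_gram(a, b, σ)
    d = a .- transpose(b)              # length(a) × length(b) pairwise differences
    return exp.(-(d .^ 2) ./ (2 * σ^2))
end

function mmd_vectorized(x, y; σ=1)
    xv, yv = vec(x), vec(y)            # flatten, matching the linear indexing x[i]
    M, N = length(xv), length(yv)
    return sum(gaussian_gram(xv, xv, σ)) / M^2 -
           2 * sum(gaussian_gram(xv, yv, σ)) / (M * N) +
           sum(gaussian_gram(yv, yv, σ)) / N^2
end

# Sum over several kernel bandwidths, as in the question's `mmd_loss`.
mmd_loss_vectorized(x, x̂; σs=[1]) = sum(mmd_vectorized(x, x̂; σ=σ) for σ in σs)
```

In the question's training closure the call would become `mmd_loss_vectorized(m(x), y; σs = [2, 5, 10, 20, 40, 80])`. The trade-off is memory: each call materializes M×M, M×N and N×N kernel matrices (400×400 for the MWE's batches), which is small here but grows quadratically with the batch size.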