I am trying to use gradient descent to optimize a matrix (a single-layer neural network) with a custom loss function. The loss is the sum of a Gaussian-kernel maximum mean discrepancy (MMD) and the L1 norm of the model weights (the matrix elements). Training is incredibly slow and I don't know why: each step takes about 100x longer than it should. I have gone through the Flux performance tips, but the problem persists.
I am using Julia 1.8 with the following packages:
[587475ba] Flux v0.13.14
[2913bbd2] StatsBase v0.33.21
[37e2e46d] LinearAlgebra
[9a3f8284] Random
Any help would be appreciated.
Here is the output of my minimum working example (MWE):
❯ julia mwe.jl
[ Info: Train with desired loss function
134.781981 seconds (1.03 G allocations: 115.077 GiB, 22.93% gc time, 18.78% compilation time)
[ Info: Train with MSE loss function
0.277686 seconds (1.70 M allocations: 84.768 MiB, 99.89% compilation time)
[ Info: Benchmark desired loss function
0.227674 seconds (396.55 k allocations: 20.628 MiB, 43.10% compilation time)
and here is my MWE:
using Flux
using Flux.Optimise: Adam, train!
using Flux.Data: DataLoader
using LinearAlgebra
using Random: AbstractRNG, default_rng
using StatsBase: sample
function mmd(x, y; σ=1)
    T = eltype(x)
    M = length(x)
    N = length(y)
    mmd = zero(T)
    running_total = zero(T)
    for i in 1:M, j in 1:M
        running_total += gaussian_kernel(x[i], x[j]; σ=σ)
    end
    mmd += (running_total / convert(T, M)^convert(T, 2))
    running_total = zero(T)
    for i in 1:M, j in 1:N
        running_total += gaussian_kernel(x[i], y[j]; σ=σ)
    end
    mmd -= (convert(T, 2) / convert(T, M * N) * running_total)
    running_total = zero(T)
    for i in 1:N, j in 1:N
        running_total += gaussian_kernel(y[i], y[j]; σ=σ)
    end
    mmd += (running_total / convert(T, N)^convert(T, 2))
    return mmd
end
function gaussian_kernel(x, y; σ=1)
    return exp(
        -one(typeof(x)) / (oftype(x / 1, 2) * oftype(x / 1, σ)^oftype(x / 1, 2)) *
        abs(x - y)^oftype(x / 1, 2),
    )
end
function mmd_loss(x, x̂; σs=[1])
    return sum(mmd(x, x̂; σ=σ) for σ in σs)
end
function generate_data(rng::AbstractRNG, n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    # Generate a Gaussian random matrix
    H = randn(rng, Float32, n, n_samples) ./ p
    # Set all but p entries in each column to zero
    for h in eachcol(H)
        indices = sample(1:n, n - p; replace=false)
        h[indices] .= 0
    end
    # Rescale
    H /= sqrt(norm(H) / n_samples)
    # Compute the label data (unit-norm columns)
    U = randn(rng, Float32, m, n_samples)
    for u in eachcol(U)
        u .= u / norm(u)
    end
    return Float32.(H), Float32.(U)
end
function generate_data(n_samples::T, m::T, n::T, p::T) where {T<:Integer}
    return generate_data(default_rng(), n_samples, m, n, p)
end
invdB(x) = oftype(x / 1, 10)^(x / oftype(x / 1, 10))
function main()
    model = Dense(16 => 100, identity; bias=false)
    opt_state = Flux.setup(Adam(0.0001f0, (0.9f0, 0.999f0)), model)
    λ = one(eltype(model.weight))
    H, U = generate_data(20, 100, 16, 3)
    dataloader = DataLoader((H, U), batchsize=4)
    @info "Train with desired loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        this_mmd_loss = mmd_loss(m(x), y; σs = [2, 5, 10, 20, 40, 80])
        this_l1_loss = λ * norm(invdB.(model.weight), 1)
        this_mmd_loss + this_l1_loss
    end
    @info "Train with MSE loss function"
    @time train!(model, dataloader, opt_state) do m, x, y
        Flux.mse(m(x), y)
    end
    @info "Benchmark desired loss function"
    @time for (x, y) in dataloader
        this_mmd_loss = mmd_loss(model(x), y; σs = [2, 5, 10, 20, 40, 80])
        this_l1_loss = λ * norm(invdB.(model.weight), 1)
        this_mmd_loss + this_l1_loss
    end
end
main()
As @mcabbott mentioned, the answer is in this thread: https://discourse.julialang.org/t/training-is-500x-slower-than-inference-for-a-custom-loss-function/96283. Training is slow because Zygote creates a large number of temporary arrays while differentiating the loss, most likely from the element-by-element indexing inside the mmd loops.
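For anyone landing here later, one common way to avoid those temporary arrays is to replace the scalar loops with whole-array operations, so that Zygote differentiates a few broadcasts instead of many individual indexing calls. Below is a minimal sketch of that idea, not the exact code from the linked thread. The names gaussian_gram, mmd_vectorized and mmd_loss_vectorized are hypothetical, and I am assuming the inputs are real-valued arrays that should be treated elementwise, as the looped mmd above does. I have not benchmarked it.

```julia
# Hypothetical helper (not from the original post or the linked thread):
# Gaussian kernel matrix with K[i, j] = exp(-(a[i] - b[j])^2 / (2σ^2)).
function gaussian_gram(a::AbstractVector, b::AbstractVector; σ=1)
    T = eltype(a)
    neg_inv = -one(T) / (T(2) * T(σ)^2)
    # a is length M, b' is 1×N, so the broadcast builds an M×N matrix in one go.
    return exp.(neg_inv .* abs2.(a .- b'))
end

# Same estimate as the looped mmd:
# mean(k(x, x)) - 2 * mean(k(x, y)) + mean(k(y, y)) over all elements of x and y.
function mmd_vectorized(x, y; σ=1)
    a, b = vec(x), vec(y)
    kxx = gaussian_gram(a, a; σ=σ)
    kxy = gaussian_gram(a, b; σ=σ)
    kyy = gaussian_gram(b, b; σ=σ)
    return sum(kxx) / length(kxx) - 2 * sum(kxy) / length(kxy) + sum(kyy) / length(kyy)
end

# Sum of the MMD estimate over several kernel widths, as in mmd_loss.
function mmd_loss_vectorized(x, x̂; σs=(1,))
    return sum(σ -> mmd_vectorized(x, x̂; σ=σ), σs)
end
```

In the training loop this would stand in for mmd_loss, e.g. mmd_loss_vectorized(m(x), y; σs=(2, 5, 10, 20, 40, 80)). The trade-off is memory: each kernel matrix has length(x) * length(y) entries, which is small for the batch sizes in the MWE but grows quadratically with the batch.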