Tags: julia, flux

Bug when training a custom model using Flux


I am pretty new to Julia and Flux. I am trying to build a simple neural network that uses an attention layer. I wrote the following code, which works fine in inference (feed-forward) mode:

using Flux

struct Attention
    W
    v
end

Attention(vehicle_embedding_dim::Integer) = Attention(
    Dense(vehicle_embedding_dim => vehicle_embedding_dim, tanh),
    Dense(vehicle_embedding_dim => 1, bias=false, init=Flux.zeros32)
)

function (a::Attention)(inputs)
    alphas = [a.v(e) for e in a.W.(inputs)]
    alphas = sigmoid.(alphas)
    output = sum([alpha.*input for (alpha, input) in zip(alphas, inputs)])
    return output
end

Flux.@functor Attention

struct AttentionNet 
    embedding
    attention
    fc_output
    vehicle_num::Integer
    vehicle_dim::Integer
end

AttentionNet(vehicle_num::Integer, vehicle_dim::Integer, embedding_dim::Integer) = AttentionNet(
    Dense(vehicle_dim+1 => embedding_dim, relu),
    Attention(embedding_dim),
    Dense(1+embedding_dim => 1),
    vehicle_num,
    vehicle_dim
)

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]

    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)
    
    x = a_net.fc_output(vcat(time_idx, attention_output))
    return x
end

Flux.@functor AttentionNet
Flux.trainable(a_net::AttentionNet) = (a_net.embedding, a_net.attention, a_net.fc_output,)



fake_inputs = rand(22, 640)
fake_outputs = rand(1, 640)
a_net = AttentionNet(3, 7, 64) |> gpu
opt = Adam(0.01)
opt_state = Flux.setup(opt, a_net)

data = Flux.DataLoader((fake_inputs, fake_outputs) |> gpu, batchsize=32, shuffle=true)

Flux.train!(a_net, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end

But when I trained it, I got the following warning and error:

┌ Warning: trainable(x) should now return a NamedTuple with the field names, not a Tuple
└ @ Optimisers C:\Users\Herr LU\.julia\packages\Optimisers\SoKJO\src\interface.jl:164
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
Closest candidates are:
  +(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
  +(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at C:\Users\Herr LU\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:154
  +(::ChainRulesCore.Tangent{P}, ::P) where P at C:\Users\Herr LU\.julia\packages\ChainRulesCore\C73ay\src\tangent_arithmetic.jl:146
  ...
Stacktrace:
  [1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:17
  [2] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, zs::Base.RefValue{Any}) (repeats 2 times)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:22
  [3] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:39 [inlined]
  [4] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
  [5] Pullback
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:62 [inlined]
  [6] #208
    @ C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:206 [inlined]
  [7] #2066#back
    @ C:\Users\Herr LU\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
  [8] Pullback
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
  [9] (::typeof(∂(λ)))(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
 [10] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
 [11] withgradient(f::Function, args::AttentionNet)
    @ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:133
 [12] macro expansion
    @ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
 [13] macro expansion
    @ C:\Users\Herr LU\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
 [14] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}}; cb::Nothing)
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:100
 [15] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}})
    @ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:97
 [16] top-level scope
    @ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:61

I followed the instructions in the official tutorial on custom layers, but it doesn't explain how to get a custom layer to train properly. Could someone help me out?


Solution

  • For anyone who is interested, this problem was solved by @ToucheSir in this GitHub thread.
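
  • Reading the error message also gives a hint: the `NamedTuple{(:contents,), ...}` that Zygote's `accum` fails on is the gradient of a `Core.Box`, which Julia introduces when a variable captured by a closure is later reassigned. In the forward pass above, `x` is captured by the array comprehension over `1:a_net.vehicle_num` and then reassigned by `x = a_net.fc_output(...)`, so it gets boxed. Purely as a hedged sketch (the authoritative fix is in the thread linked above), rewriting the forward pass so that no captured variable is ever reassigned would look like this:

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    # Fresh names for every intermediate, so nothing the comprehensions
    # close over (in particular `x`) is ever reassigned and boxed.
    raw_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, s) for s in raw_states]

    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)

    # Return directly instead of reassigning `x`.
    return a_net.fc_output(vcat(time_idx, attention_output))
end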
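
  • The warning is a separate issue: newer versions of Optimisers.jl expect `Flux.trainable` to return a NamedTuple keyed by field name rather than a plain Tuple. A minimal sketch of that change:

Flux.trainable(a_net::AttentionNet) = (
    embedding = a_net.embedding,
    attention = a_net.attention,
    fc_output = a_net.fc_output,
)

    Since `vehicle_num` and `vehicle_dim` are plain integers, they are not optimised in any case (they show up as `Tuple{}` in the optimiser state printed above), so the `trainable` overload can most likely be dropped altogether.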