I am pretty new to Julia and Flux. I am trying to build a simple neural network with an attention layer. I wrote the following code, which works fine in inference (feed-forward) mode:
using Flux

struct Attention
    W
    v
end

Attention(vehicle_embedding_dim::Integer) = Attention(
    Dense(vehicle_embedding_dim => vehicle_embedding_dim, tanh),
    Dense(vehicle_embedding_dim => 1; bias=false, init=Flux.zeros32)
)

function (a::Attention)(inputs)
    # one scalar score per embedded input, squashed to (0, 1)
    alphas = [a.v(e) for e in a.W.(inputs)]
    alphas = sigmoid.(alphas)
    # attention-weighted sum of the raw inputs
    output = sum([alpha .* input for (alpha, input) in zip(alphas, inputs)])
    return output
end

Flux.@functor Attention
struct AttentionNet
    embedding
    attention
    fc_output
    vehicle_num::Integer
    vehicle_dim::Integer
end

AttentionNet(vehicle_num::Integer, vehicle_dim::Integer, embedding_dim::Integer) = AttentionNet(
    Dense(vehicle_dim + 1 => embedding_dim, relu),  # +1 for the time index
    Attention(embedding_dim),
    Dense(1 + embedding_dim => 1),
    vehicle_num,
    vehicle_dim
)

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    # split the flat feature vector into per-vehicle state blocks
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_states = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]
    vehicle_embeddings = a_net.embedding.(vehicle_states)
    attention_output = a_net.attention(vehicle_embeddings)
    x = a_net.fc_output(vcat(time_idx, attention_output))
    return x
end

Flux.@functor AttentionNet
Flux.trainable(a_net::AttentionNet) = (a_net.embedding, a_net.attention, a_net.fc_output,)
fake_inputs = rand(22, 640)  # 22 features = 1 time index + 3 vehicles × 7 state dims
fake_outputs = rand(1, 640)

a_net = AttentionNet(3, 7, 64) |> gpu
opt = Adam(0.01)
opt_state = Flux.setup(opt, a_net)
data = Flux.DataLoader((fake_inputs, fake_outputs) |> gpu, batchsize=32, shuffle=true)

Flux.train!(a_net, data, opt_state) do m, x, y
    Flux.mse(m(x), y)
end
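For reference, this is what I mean by inference mode working (a quick REPL sketch; a_net_cpu is just a throwaway CPU instance, and the shapes follow from the constructor arguments above):

julia> a_net_cpu = AttentionNet(3, 7, 64);

julia> size(a_net_cpu(rand(Float32, 22, 5)))  # 22 features × 5 samples -> 1 output per sample
(1, 5)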
But when I train it, I get the following warning and error:
┌ Warning: trainable(x) should now return a NamedTuple with the field names, not a Tuple
└ @ Optimisers C:\Users\Herr LU\.julia\packages\Optimisers\SoKJO\src\interface.jl:164
ERROR: MethodError: no method matching +(::Base.RefValue{Any}, ::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
Closest candidates are:
+(::Any, ::Any, ::Any, ::Any...) at operators.jl:591
+(::Union{InitialValues.NonspecificInitialValue, InitialValues.SpecificInitialValue{typeof(+)}}, ::Any) at C:\Users\Herr LU\.julia\packages\InitialValues\OWP8V\src\InitialValues.jl:154
+(::ChainRulesCore.Tangent{P}, ::P) where P at C:\Users\Herr LU\.julia\packages\ChainRulesCore\C73ay\src\tangent_arithmetic.jl:146
...
Stacktrace:
[1] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}})
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:17
[2] accum(x::Base.RefValue{Any}, y::NamedTuple{(:contents,), Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}}, zs::Base.RefValue{Any}) (repeats 2 times)
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:22
[3] Pullback
@ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:39 [inlined]
[4] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer})
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[5] Pullback
@ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:62 [inlined]
[6] #208
@ C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\lib\lib.jl:206 [inlined]
[7] #2066#back
@ C:\Users\Herr LU\.julia\packages\ZygoteRules\AIbCs\src\adjoint.jl:67 [inlined]
[8] Pullback
@ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
[9] (::typeof(∂(λ)))(Δ::Float32)
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface2.jl:0
[10] (::Zygote.var"#60#61"{typeof(∂(λ))})(Δ::Float32)
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:45
[11] withgradient(f::Function, args::AttentionNet)
@ Zygote C:\Users\Herr LU\.julia\packages\Zygote\SmJK6\src\compiler\interface.jl:133
[12] macro expansion
@ C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:102 [inlined]
[13] macro expansion
@ C:\Users\Herr LU\.julia\packages\ProgressLogging\6KXlp\src\ProgressLogging.jl:328 [inlined]
[14] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}}; cb::Nothing)
@ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:100
[15] train!(loss::Function, model::AttentionNet, data::MLUtils.DataLoader{Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}}, Random._GLOBAL_RNG, Val{nothing}}, opt::NamedTuple{(:embedding, :attention, :fc_output, :vehicle_num, :vehicle_dim), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:W, :v), Tuple{NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}, Tuple{}}}}}, NamedTuple{(:weight, :bias, :σ), Tuple{Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Optimisers.Leaf{Optimisers.Adam{Float64}, Tuple{CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, CUDA.CuArray{Float32, 1, CUDA.Mem.DeviceBuffer}, Tuple{Float64, Float64}}}, Tuple{}}}, Tuple{}, Tuple{}}})
@ Flux.Train C:\Users\Herr LU\.julia\packages\Flux\ZdbJr\src\train.jl:97
[16] top-level scope
@ e:\Master Thesis\lu_jizhou\toy exmaple\dqn_model.jl:61
I followed the instructions from the official tutorial on custom layers, but it doesn't explain how to get custom layers properly trained. Could someone help me out?
For anyone who is interested, this problem was solved by @ToucheSir in this GitHub thread.
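For those who want the gist without clicking through (a minimal sketch as I understand the answer there, not a verbatim quote): the culprit is that `x` is both captured by the comprehension and reassigned at the end of the forward pass, so Julia boxes it in a `Core.Box`, and the gradient Zygote produces for the box is exactly the `Base.RefValue{Any}` / `NamedTuple{(:contents,)}` pair that fails to add up in `accum`. Giving the output a fresh name avoids the box, and returning a NamedTuple from `trainable` fixes the warning:

function (a_net::AttentionNet)(x)
    time_idx = x[[1], :]
    vehicle_states = [x[2+a_net.vehicle_dim*(i-1):2+a_net.vehicle_dim*i-1, :] for i in 1:a_net.vehicle_num]
    vehicle_inputs = [vcat(time_idx, vehicle_state) for vehicle_state in vehicle_states]
    vehicle_embeddings = a_net.embedding.(vehicle_inputs)
    attention_output = a_net.attention(vehicle_embeddings)
    # fresh name instead of `x = ...`: reassigning the captured `x` is what
    # created the Core.Box that Zygote choked on
    return a_net.fc_output(vcat(time_idx, attention_output))
end

# NamedTuple instead of Tuple, as the Optimisers.jl warning asks; the integer
# fields vehicle_num/vehicle_dim stay out of the optimiser state
Flux.trainable(a_net::AttentionNet) = (; embedding=a_net.embedding, attention=a_net.attention, fc_output=a_net.fc_output)

With these two changes the boxed-variable gradient and the `trainable` warning should both go away.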