Tags: machine-learning, julia, flux.jl

Stop tracking Arrays in Flux (Julia)


I am currently trying to implement a batch update in Flux for Julia.

During my calculations, I obtain a batch of scalars by repeatedly doing

δ = Gt - model(St)[1]
push!(deltas,δ)

where model is a neural network:

global model = Chain(
    Dense(statesize,10, leakyrelu),
    Dense(10,10,leakyrelu),
    Dense(10,1))

I end up with the array deltas, and I would like to perform a batch gradient update (batch size = 19) on a second neural network, where each gradient is weighted by the corresponding delta. The update function I wrote is

function vupdate2!(S_batch, model, α, deltas)

    function v_loss_total(x)
        return sum(reshape(deltas, (1, 19)) .* model(x))
    end

    local ps = Flux.params(model)
    local gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
    for p in ps
        Flux.Tracker.update!(p, α .* gs[p])
    end
end

The problem is that the line where the gradients are calculated throws an error: MethodError: no method matching Float32(::Tracker.TrackedReal{Float64})

I think the problem is that my deltas array is tracked. Looking at the output of the v_loss_total function for a random input, I get:

julia> v_loss_total(S_batch)
-6752.433690476287 (tracked) (tracked)

Interestingly, this number is tracked twice (?), which I guess comes from multiplying two tracked quantities together (i.e. the entries of deltas and model(S_batch)). Is there a way to untrack the delta array first? I would appreciate any help.


Solution

  • Okay, as it turns out, there is a function

    Flux.Tracker.data()
    

    which does exactly what I needed: it takes a tracked value and returns the underlying plain value (e.g. a Float64 from a TrackedReal; it also works on tracked arrays). Also see: https://github.com/FluxML/Flux.jl/issues/640
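    Putting this together, a minimal sketch of the fixed update function might look like the following. This assumes the old Tracker-based Flux API (Flux ≤ 0.9, where Flux.Tracker.data, Flux.Tracker.gradient, and Flux.Tracker.update! exist); in Flux ≥ 0.10 Tracker was replaced by Zygote and values are no longer tracked, so this issue does not arise there.

    ```julia
    using Flux

    function vupdate2!(S_batch, model, α, deltas)
        # Untrack the deltas so they act as constant weights in the loss.
        # Broadcasting data over the array turns each TrackedReal into a
        # plain Float64, so the loss is only tracked once (via model(x)).
        δ = Flux.Tracker.data.(deltas)

        # Weighted batch loss: each column of model(x) is scaled by its delta.
        v_loss_total(x) = sum(reshape(δ, (1, length(δ))) .* model(x))

        ps = Flux.params(model)
        gs = Flux.Tracker.gradient(() -> v_loss_total(S_batch), ps)
        for p in ps
            Flux.Tracker.update!(p, α .* gs[p])
        end
    end
    ```

    Since the deltas are treated as fixed weights rather than functions of the parameters, untracking them is also conceptually the right thing to do: gradients should not flow back through them into the first network.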