Tags: julia, reinforcement-learning

Using a neural network approximator in ReinforcementLearning.jl


I am trying to create a simultaneous multi-agent environment using ReinforcementLearning.jl. I have successfully represented the environment, and it works with a RandomPolicy for every agent.

But my state space is large (it is actually a 14-tuple, with each value in a certain range), so I cannot use a TabularApproximator to estimate the Q or V values. That's why I decided to use a NeuralNetworkApproximator. But the docs do not say much about it, nor are there any examples where a neural network approximator is used, and I am stuck on how to use one. If anyone can explain how to go about it, or point me to an example, that would be helpful.

Moreover, I found from the docs that using a NeuralNetworkApproximator requires a CircularArraySARTTrajectory. But defining this trajectory requires a keyword argument called capacity. I don't know what it means, nor is it discussed in the docs or on GitHub.

I tried writing code that uses a NeuralNetworkApproximator, but I get an error.

using Flux
using ReinforcementLearning

# Create a Flux-based DNN for Q-value estimation
STATE_SIZE = length(myenv.channels) # 14
ACTION_SIZE = length(values)        # 2

model = Chain(
    Dense(STATE_SIZE, 24, tanh),
    Dense(24, 48, tanh),
    Dense(48, ACTION_SIZE)
) |> gpu

η = 1f-2 # Learning rate
η_decay = 1f-3
opt = Flux.Optimiser(ADAM(η), InvDecay(η_decay))

policies = MultiAgentManager(
    (
        Agent(
            policy = NamedPolicy(
                p => VBasedPolicy(;
                    learner = BasicDQNLearner(;
                        approximator = NeuralNetworkApproximator(;
                            model = model,
                            optimizer = opt
                        )
                    )
                )
            ),
            trajectory = CircularArraySARTTrajectory(;
                capacity = 14,
                state = Array{Float64, 1},
                action = Int,
                terminal = Bool
            )
        )
        for p in players(myenv)
    )...
)

Error/StackTrace

MethodError: no method matching iterate(::Type{Array{Float64,1}})
Closest candidates are:
  iterate(::Plots.NaNSegmentsIterator) at 
C:\Users\vchou\.julia\packages\Plots\lzHOt\src\utils.jl:124
  iterate(::Plots.NaNSegmentsIterator, ::Int64) at 
C:\Users\vchou\.julia\packages\Plots\lzHOt\src\utils.jl:124
  iterate(::LibGit2.GitBranchIter) at 
C:\buildbot\worker\package_win64\build\usr\share\julia\stdlib\v1.5\LibGit2\src\reference.jl:343
  ...

Stacktrace:
 [1] first(::Type{T} where T) at .\abstractarray.jl:341
 [2] (::ReinforcementLearningCore.var"#53#54"{Int64})(::Type{T} where T) at C:\Users\vchou\.julia\packages\ReinforcementLearningCore\NWrFY\src\policies\agents\trajectories\trajectory.jl:46
 [3] map(::ReinforcementLearningCore.var"#53#54"{Int64}, ::Tuple{DataType,DataType}) at .\tuple.jl:158
 [4] map(::Function, ::NamedTuple{(:state, :action),Tuple{DataType,DataType}}) at .\namedtuple.jl:187
 [5] CircularArrayTrajectory(; capacity::Int64, kwargs::Base.Iterators.Pairs{Symbol,DataType,Tuple{Symbol,Symbol},NamedTuple{(:state, :action),Tuple{DataType,DataType}}}) at C:\Users\vchou\.julia\packages\ReinforcementLearningCore\NWrFY\src\policies\agents\trajectories\trajectory.jl:45
 [6] Trajectory{var"#s57"} where var"#s57"<:(NamedTuple{(:state, :action, :reward, :terminal),var"#s16"} where var"#s16"<:(Tuple{var"#s15",var"#s14",var"#s12",var"#s84"} where var"#s84"<:CircularArrayBuffers.CircularArrayBuffer where var"#s12" <:CircularArrayBuffers.CircularArrayBuffer where var"#s14"<:CircularArrayBuffers.CircularArrayBuffer where var"#s15"<:CircularArrayBuffers.CircularArrayBuffer))(; capacity::Int64, state::Type{T} where T, action::Type{T} where T, reward::Pair{DataType,Tuple{}}, terminal::Type{T} where T) at C:\Users\vchou\.julia\packages\ReinforcementLearningCore\NWrFY\src\policies\agents\trajectories\trajectory.jl:76
 [7] (::var"#24#25")(::String) at .\none:0
 [8] iterate(::Base.Generator{Array{String,1},var"#24#25"}) at .\generator.jl:47
 [9] top-level scope at In[18]:15
 [10] include_string(::Function, ::Module, ::String, ::String) at .\loading.jl:1091

Solution

  • Here, capacity means the maximum length of the experience replay buffer. When applying DQN-related algorithms, we usually use a circular buffer to store the transitions collected at each step.
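
    For example, a trajectory defined as in the sketch below (the capacity value of 1000 is just illustrative) keeps at most the 1000 most recent transitions and overwrites the oldest ones once the buffer is full. Note that capacity = 14 in your code looks like it was copied from STATE_SIZE; the buffer length is unrelated to the size of a single state.

        trajectory = CircularArraySARTTrajectory(;
            capacity = 1000,                      # max number of stored transitions
            state = Array{Float64, 1} => (14,),   # element type => size of one state
            action = Int,
            terminal = Bool
        )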

    The error you posted means that you forgot to specify the size of the state when defining the CircularArraySARTTrajectory:

    -              state=Array{Float64, 1},
    +              state=Array{Float64, 1} => (STATE_SIZE,),
    

    You can find some example usage here. I'd suggest opening an issue in that package, because the docstring of CircularArraySARTTrajectory definitely should be included in the docs.
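
    For completeness, here is a sketch of your construction with that fix applied (same setup as in your question; I have also bumped capacity to an illustrative 1000, since the buffer length is unrelated to STATE_SIZE):

        policies = MultiAgentManager(
            (
                Agent(
                    policy = NamedPolicy(
                        p => VBasedPolicy(;
                            learner = BasicDQNLearner(;
                                approximator = NeuralNetworkApproximator(;
                                    model = model,
                                    optimizer = opt
                                )
                            )
                        )
                    ),
                    trajectory = CircularArraySARTTrajectory(;
                        capacity = 1000,
                        state = Array{Float64, 1} => (STATE_SIZE,),
                        action = Int,
                        terminal = Bool
                    )
                )
                for p in players(myenv)
            )...
        )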