deep-learning, priority-queue, reinforcement-learning, q-learning

Prioritized experience replay in deep Q-learning


I was implementing DQN for the Mountain Car problem in OpenAI Gym. This problem is special because the positive reward is very sparse, so I thought of implementing prioritized experience replay as proposed in this paper by Google DeepMind.

There are certain things that are confusing me:

  • How do we store the replay memory? I get that p_i is the priority of a transition and that there are two ways to compute it, but what is this P(i)?
  • If we follow the rules given, won't P(i) change every time a sample is added?
  • What does it mean when it says "we sample according to this probability distribution"? What is the distribution?
  • Finally, how do we sample from it? I get that if we stored it in a priority queue we could sample directly, but we are actually storing it in a sum tree.

Thanks in advance.


Solution

    • According to the paper, there are two ways of calculating the priority p_i (proportional and rank-based), and your implementation differs depending on which one you choose. Assuming you selected proportional prioritization, p_i = |δ_i| + ε, where δ_i is the transition's TD error, and you should use a sum-tree data structure that stores each transition together with its priority p_i^α. P(i) = p_i^α / Σ_k p_k^α is just the normalized version of that priority (α controls how much prioritization is used; α = 0 gives uniform sampling). It shows how important the transition is, or in other words how much replaying it can help improve your network. When P(i) is high, the TD error is large, meaning the outcome surprised the network, so replaying it can really help the network tune itself.
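    • You should add each new transition with maximal priority (the largest priority currently in the memory) to make sure it will be replayed at least once; a truly infinite priority would break the sums in the tree. Adding a transition does change every normalized P(i), because the denominator Σ_k p_k^α grows, but there is no need to update the whole replay memory: the sum tree keeps only the unnormalized priorities in its leaves and their total at the root, so normalization happens implicitly when you sample. During the experience replay process, you select a mini-batch and update the priorities of only the experiences in that mini-batch, using their new TD errors.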
    • Each experience i has a probability P(i), so together the experiences form a categorical distribution over the replay memory. We draw the next mini-batch from this distribution, which means transitions with larger priority are replayed more often.
    • You can sample from your sum tree with this procedure, where n starts at the root and s is a value drawn uniformly from [0, total priority]:
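    def retrieve(n, s):
        # Walk down from node n to the leaf whose cumulative-priority range contains s.
        # Assumes each node has .left/.right (None for a leaf) and .val (sum of its subtree).
        if n.left is None:  # leaf node
            return n
        if n.left.val >= s:
            return retrieve(n.left, s)
        return retrieve(n.right, s - n.left.val)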
    

    I have taken the code from here.
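To make the relationship between p_i and P(i) concrete, here is a minimal sketch of the proportional variant, assuming p_i = |δ_i| + ε and P(i) = p_i^α / Σ_k p_k^α. The function name `sampling_probabilities` and the default values of `alpha` and `eps` are illustrative assumptions, not taken from the paper's code.

```python
def sampling_probabilities(td_errors, alpha=0.6, eps=1e-6):
    """Return P(i) for every transition, given its latest TD error (proportional variant)."""
    # Priority of each transition: p_i = |delta_i| + eps (eps keeps p_i > 0).
    priorities = [abs(delta) + eps for delta in td_errors]
    # alpha controls how strongly prioritization is applied (alpha = 0 gives uniform sampling).
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    # P(i) = p_i^alpha / sum_k p_k^alpha, so the values sum to 1.
    return [x / total for x in scaled]


print(sampling_probabilities([2.0, 1.0, 0.0, 1.0], alpha=1.0, eps=0.0))
# [0.5, 0.25, 0.0, 0.25]
print(sampling_probabilities([2.0, 1.0, 0.0, 1.0, 4.0], alpha=1.0, eps=0.0))
# [0.25, 0.125, 0.0, 0.125, 0.5]
```

Adding the fifth transition leaves every p_i unchanged but halves each of the first four P(i), because only the denominator changed. That is why a sum tree stores priorities rather than probabilities and keeps the denominator at its root.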
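The recursive `retrieve(n, s)` above assumes a linked tree of node objects. A common alternative keeps the whole sum tree in one flat list, with `capacity` leaves stored after `capacity - 1` internal nodes. The sketch below uses that layout to cover the whole cycle: new transitions enter at maximal priority, only the leaves of a replayed mini-batch are updated, and a mini-batch is drawn by splitting [0, total] into equal segments and sampling one value from each (the stratified scheme the paper describes for the proportional variant). The class and method names are hypothetical, and the initial priority of 1.0 plus the running maximum used for new transitions are simplifying assumptions.

```python
import random


class SumTree:
    """Array-backed sum tree for proportional prioritized replay (illustrative sketch).

    Leaves store the priorities p_i ** alpha of the stored transitions, and every
    internal node stores the sum of its two children, so tree[0] is the total.
    """

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # internal nodes first, then the leaves
        self.data = [None] * capacity           # transition stored at each leaf
        self.write = 0                          # next slot to overwrite (ring buffer)
        self.size = 0
        self.max_priority = 1.0                 # assumed priority for the very first transition

    def total(self):
        return self.tree[0]

    def update(self, leaf, priority):
        """Set one leaf's priority and recompute the sums on its path to the root."""
        self.tree[leaf] = priority
        # Running maximum of all priorities seen so far (a simplification).
        self.max_priority = max(self.max_priority, priority)
        idx = leaf
        while idx != 0:
            idx = (idx - 1) // 2
            self.tree[idx] = self.tree[2 * idx + 1] + self.tree[2 * idx + 2]

    def add(self, transition):
        """Store a new transition at maximal priority so it is replayed at least once."""
        leaf = self.write + self.capacity - 1
        self.data[self.write] = transition
        self.update(leaf, self.max_priority)
        self.write = (self.write + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def retrieve(self, s):
        """Iterative retrieve: return the index of the leaf whose range contains s."""
        idx = 0
        while 2 * idx + 1 < len(self.tree):  # stop once idx is a leaf
            left = 2 * idx + 1
            if self.tree[left] >= s:
                idx = left
            else:
                s -= self.tree[left]
                idx = left + 1
        return idx

    def sample(self, batch_size):
        """Stratified sampling: one uniform draw from each of batch_size equal segments."""
        segment = self.total() / batch_size
        batch = []
        for k in range(batch_size):
            s = random.uniform(k * segment, (k + 1) * segment)
            leaf = self.retrieve(s)
            # (leaf index, priority, transition); P(i) is priority / self.total().
            batch.append((leaf, self.tree[leaf], self.data[leaf - self.capacity + 1]))
        return batch
```

In a training loop you would call `add` once per environment step, `sample` to build a mini-batch, and then `update(leaf, (abs(td_error) + eps) ** alpha)` for each sampled leaf once its new TD error is known. The returned leaf index is what ties a sampled transition back to its position in the tree.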