python tensorflow reinforcement-learning tf-agent

How to pass the batch size for a custom environment in TF-Agents


I am using the TF-Agents library to build a contextual bandit, and for this I am building a custom environment.
I am creating a BanditPyEnvironment and wrapping it in a TFPyEnvironment.

The TFPyEnvironment automatically adds the batch dimension (in the observation spec). I need to account for this batch dimension in the _observe and _apply_action methods: _observe should return one observation per batch entry, and _apply_action should take one action per batch entry and return the corresponding rewards.

I am unable to find a single example of how to tell the TFPyEnvironment what the batch size is, instead of letting it automatically add a 1 as the first dimension. Can someone please clarify?

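    # Imports and class header were not shown in the original post; assumed here.
    import numpy as np
    import tensorflow as tf
    from tf_agents.bandits.environments import bandit_py_environment
    from tf_agents.specs.tensor_spec import BoundedTensorSpec


    class SampleEnvironment(bandit_py_environment.BanditPyEnvironment):

      def __init__(self, batch_size):
        self.batchsize = batch_size
        observation_spec = BoundedTensorSpec(
            (2,), np.int32, minimum=[1, 1], maximum=[5, 2], name='observation')
        action_spec = BoundedTensorSpec(
            shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
        super(SampleEnvironment, self).__init__(observation_spec, action_spec)

      def _observe(self):
        # Build one random observation per batch entry, then stack them.
        batch = []
        for _ in range(self.batchsize):
          each = tf.cast(
              np.array([np.random.choice([1, 2, 3, 4, 5]),
                        np.random.choice([1, 2])]), 'int32')
          batch.append(each)
        self.observation = np.array(batch)
        print("in observe", self.observation)
        return np.array(self.observation)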

When I try to account for the batch size in the _observe method as above (using a for loop over the batch size), the TFPyEnvironment still adds a 1 as the first dimension. Is there a way to tell the environment that the batch size is, say, 3, instead of it automatically adding 1? And how would I account for this batch size in the replay buffer and agents?
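
The symptom can be illustrated in plain Python, without TF-Agents. In this sketch, `shape` and `wrap_unbatched` are hypothetical helpers, and `wrap_unbatched` only mimics a wrapper that assumes a non-batched environment and prepends a dimension of size 1; it is not the library's implementation.

```python
def shape(nested):
    """Return the shape of a rectangular nested list as a tuple."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


def wrap_unbatched(observation):
    # Stand-in for a wrapper that treats the environment as non-batched
    # and therefore prepends a batch dimension of size 1.
    return [observation]


single = [3, 1]                     # one observation matching the (2,) spec
stacked = [[3, 1], [5, 2], [1, 1]]  # output of a manual loop with batchsize=3

single_shape = shape(wrap_unbatched(single))    # one leading dimension added
stacked_shape = shape(wrap_unbatched(stacked))  # the extra 1 sits in front of the 3
print(single_shape, stacked_shape)
```

Stacking inside `_observe` does not replace the outer dimension; it only pushes the manual batch one level deeper.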


Solution

  • This can be done using the BatchedPyEnvironment class, as shown in the example below. It looks like the bandit environment above is a non-batched environment.

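    SampleEnvironment below is the BanditPyEnvironment from the question, used here as a non-batched environment (constructed without a batch_size argument, so each instance should return a single observation).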

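    from tf_agents.environments import batched_py_environment, tf_py_environment

    batch_size = 4
    # Create a separate environment instance per batch entry, so the entries
    # do not share (and overwrite) one another's state.
    py_envs = [SampleEnvironment() for _ in range(batch_size)]
    batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
    tfenv = tf_py_environment.TFPyEnvironment(batched_env)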
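
For intuition about what a batching wrapper does, here is a plain-Python sketch that uses no TF-Agents code. `SingleBanditEnv`, `ToyBatchedEnv`, and the reward rule are hypothetical stand-ins, not the library's classes. The point is only that each wrapped environment stays non-batched, while the wrapper stacks their observations and hands each one its own action.

```python
import random


class SingleBanditEnv:
    """Toy non-batched bandit: one observation out, one action in, one reward out."""

    def __init__(self, seed):
        self._rng = random.Random(seed)
        self._observation = None

    def observe(self):
        # Same ranges as the question: first feature in 1..5, second in 1..2.
        self._observation = [self._rng.randint(1, 5), self._rng.randint(1, 2)]
        return self._observation

    def apply_action(self, action):
        # Hypothetical reward rule, for illustration only.
        return 1.0 if action == self._observation[0] else 0.0


class ToyBatchedEnv:
    """Holds N independent environments and supplies the leading batch dimension."""

    def __init__(self, envs):
        self._envs = list(envs)

    @property
    def batch_size(self):
        return len(self._envs)

    def observe(self):
        # One row per wrapped environment: shape (batch_size, 2).
        return [env.observe() for env in self._envs]

    def apply_actions(self, actions):
        if len(actions) != self.batch_size:
            raise ValueError("expected one action per wrapped environment")
        # Each environment receives only its own action and returns its own reward.
        return [env.apply_action(a) for env, a in zip(self._envs, actions)]


batch_size = 4
batched = ToyBatchedEnv(SingleBanditEnv(seed=i) for i in range(batch_size))
observations = batched.observe()
actions = [obs[0] for obs in observations]  # pick the rewarded action for each row
rewards = batched.apply_actions(actions)
print(batched.batch_size, observations, rewards)
```

With the real classes, the same idea suggests writing `_observe` and `_apply_action` for a single observation and a single action, and letting BatchedPyEnvironment supply the batch dimension. The resulting `tfenv.batch_size` would then be the natural value to pass wherever a batch size is required, such as a replay buffer's `batch_size` argument; treat that as an assumption to check against your TF-Agents version.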