I am using the tf-agents library to build a contextual bandit.
For this I am building a custom environment.
I am creating a BanditPyEnvironment and wrapping it in a TFPyEnvironment.
The TFPyEnvironment automatically adds a batch size dimension to the observation spec. I need to account for this batch dimension in the _observe and _apply_action methods: depending on the batch size, _observe should return that many observations, and _apply_action should take in that many actions and return that many rewards.
I am unable to find a single example of how to tell the TF environment what the batch size is, without it automatically adding a 1 as the first dimension. Can someone please clarify?
import numpy as np
import tensorflow as tf
from tf_agents.bandits.environments import bandit_py_environment
from tf_agents.specs.tensor_spec import BoundedTensorSpec


class SampleEnvironment(bandit_py_environment.BanditPyEnvironment):

    def __init__(self, batch_size):
        self.batchsize = batch_size
        observation_spec = BoundedTensorSpec(
            (2,), np.int32, minimum=[1, 1], maximum=[5, 2], name='observation')
        action_spec = BoundedTensorSpec(
            shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
        super(SampleEnvironment, self).__init__(observation_spec, action_spec)

    def _observe(self):
        # Build one observation per element of the intended batch.
        batch = []
        for _ in range(self.batchsize):
            each = tf.cast(
                np.array([np.random.choice([1, 2, 3, 4, 5]),
                          np.random.choice([1, 2])]), 'int32')
            batch.append(each)
        self.observation = np.array(batch)
        print("in observe", self.observation)
        return self.observation
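My _apply_action would then also have to take the whole batch of actions and return one reward per action; roughly something like this, with a random placeholder where the real reward logic would go:

    def _apply_action(self, action):
        # 'action' arrives with shape (batchsize,): one action per observation.
        rewards = []
        for _ in range(self.batchsize):
            rewards.append(np.random.rand())  # placeholder reward logic
        return np.array(rewards, dtype=np.float32)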
When I try to account for the batch size in the _observe method as above (using a for loop over the batch size), the TF environment again adds a 1 as the first dimension for the batch size. Is there a way to tell the environment that the batch size is, say, 3, instead of it automatically adding 1? At the same time, how would I account for this batch size in the replay buffer and agents?
This can be done using the BatchedPyEnvironment class, as shown in the example below. BatchedPyEnvironment expects each wrapped environment to be non-batched, i.e. to return a single observation per step; it then takes care of stacking the individual observations into a batch before they reach the TFPyEnvironment.
SampleEnvironment below is therefore the single-observation version of the BanditPyEnvironment shown in the question (no internal batch loop); a sketch of that version comes first, followed by the wrapping code.
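For reference, a minimal sketch of what that single-observation version could look like, using array specs (the usual choice for py environments) and a random placeholder reward in _apply_action:

import numpy as np
from tf_agents.bandits.environments import bandit_py_environment
from tf_agents.specs import array_spec


class SampleEnvironment(bandit_py_environment.BanditPyEnvironment):

    def __init__(self):
        observation_spec = array_spec.BoundedArraySpec(
            shape=(2,), dtype=np.int32, minimum=[1, 1], maximum=[5, 2],
            name='observation')
        action_spec = array_spec.BoundedArraySpec(
            shape=(), dtype=np.int32, minimum=0, maximum=6, name='action')
        super(SampleEnvironment, self).__init__(observation_spec, action_spec)

    def _observe(self):
        # One observation at a time; BatchedPyEnvironment stacks them into a batch.
        return np.array(
            [np.random.choice([1, 2, 3, 4, 5]), np.random.choice([1, 2])],
            dtype=np.int32)

    def _apply_action(self, action):
        # Placeholder reward for the single action seen by this environment.
        return np.float32(np.random.rand())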
from tf_agents.environments import batched_py_environment, tf_py_environment

batch_size = 4
# Create one independent environment per batch element.
py_envs = [SampleEnvironment() for _ in range(batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tfenv = tf_py_environment.TFPyEnvironment(batched_env)
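The batch size then propagates on its own: tfenv.batch_size reports 4 and every time step carries an outer dimension of 4. For the replay buffer part of the question, the usual pattern is to pass that batch size when constructing the buffer; a rough sketch, with a random policy standing in for the agent's collect policy:

from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer

# Stand-in policy; in a real setup this would be agent.collect_policy.
policy = random_tf_policy.RandomTFPolicy(
    time_step_spec=tfenv.time_step_spec(), action_spec=tfenv.action_spec())

replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
    data_spec=policy.trajectory_spec,   # spec of the trajectories being stored
    batch_size=tfenv.batch_size,        # 4, coming from BatchedPyEnvironment
    max_length=1000)

Anything collected from tfenv already has the batch dimension in place, so replay_buffer.add_batch accepts it directly, and an agent built from tfenv's specs sees the same batch size.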