Search code examples
pythontensorflowreinforcement-learningopenai-gymtf-agent

tf_agents and reverb produce incompatible tensor


I'm trying to implement a DDPG using tf_agents and reverb but I can't figure out how both libraries to work together. For this, I'm trying to use the code from the DQL-Tutorial from tf_agents with my own agent and gym environment. The error occurs when I try to retrieve data from reverb and the tensor shape doesn't match. I've created the smallest possible example I could think of, to show the problem:

Imports

import gym
from gym import spaces
from gym.utils.env_checker import check_env
from gym.envs.registration import register

import tensorflow as tf
import numpy as np
import reverb

from tf_agents.agents import DdpgAgent
from tf_agents.drivers.py_driver import PyDriver
from tf_agents.environments import TFPyEnvironment, suite_gym, validate_py_environment
from tf_agents.networks import Sequential
from tf_agents.policies import PyTFEagerPolicy
from tf_agents.replay_buffers import ReverbReplayBuffer, ReverbAddTrajectoryObserver
from tf_agents.specs import tensor_spec, BoundedArraySpec

Example Gym environment

class TestGym(gym.Env):
    metadata = {"render_modes": ["human"]}
    def __init__(self):
        self.observation_space = spaces.Box(low=-1, high=1, shape=(30,), dtype=np.float32)
        self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32)
        self.__count = 0
    def step(self, action):
        self.__count += 1
        return np.zeros(30, dtype=np.float32), 0, self.__count >= 100, {}
    def render(self, mode="human"):
        return None
    def reset(self, seed=None, return_info=False, options=None):
        super().reset(seed=seed, options=options)
        self.__count = 0
        if return_info:
            return np.zeros(30, dtype=np.float32), {}
        else:
            return np.zeros(30, dtype=np.float32)

register(
    id="TestGym-v0",
    entry_point="reverb_test:TestGym",
    nondeterministic=False
)

Creating a TFAgent and use reverb to store and retrieve

def main():
    # make sure the gym environment is ok
    check_env(gym.make("TestGym-v0"))

    # create tf-py-environment
    env = TFPyEnvironment(suite_gym.load("TestGym-v0"))

    # make sure the py environment is ok
    validate_py_environment(env.pyenv, episodes=5)

    # example actor network
    actor_network = Sequential([
        tf.keras.layers.Dense(40),
        tf.keras.layers.Dense(2, activation=None)
    ], input_spec=env.observation_spec())

    # example critic network
    n_actions = env.action_spec().shape[0]
    n_observ = env.observation_spec().shape[0]
    critic_input_spec: BoundedArraySpec = BoundedArraySpec((n_actions + n_observ,), "float32", minimum=-1, maximum=1)
    critic_network = Sequential([
        tf.keras.layers.Dense(40),
        tf.keras.layers.Dense(1, activation=None)
    ], input_spec=critic_input_spec)

    # example rl agent
    agent = DdpgAgent(
        time_step_spec=env.time_step_spec(),
        action_spec=env.action_spec(),
        actor_network=actor_network,
        critic_network=critic_network,
    )

    # create reverb table
    table_name = "uniform_table"
    replay_buffer_signature = tensor_spec.from_spec(agent.collect_data_spec)
    replay_buffer_signature = tensor_spec.add_outer_dim(replay_buffer_signature)
    table = reverb.Table(
        table_name,
        max_size=100_000,
        sampler=reverb.selectors.Uniform(),
        remover=reverb.selectors.Fifo(),
        rate_limiter=reverb.rate_limiters.MinSize(1),
        signature=replay_buffer_signature
    )

    # create reverb server
    reverb_server = reverb.Server([table])

    # create replay buffer for this table and server
    replay_buffer = ReverbReplayBuffer(
        agent.collect_data_spec,
        table_name=table_name,
        sequence_length=2,
        local_server=reverb_server
    )

    # create observer to store experiences
    observer = ReverbAddTrajectoryObserver(
        replay_buffer.py_client,
        table_name,
        sequence_length=2
    )

    # run a view steps to ill the replay buffer
    driver = PyDriver(env.pyenv, PyTFEagerPolicy(agent.collect_policy, use_tf_function=True), [observer], max_steps=100)
    driver.run(env.reset())

    # create a dataset to access the replay buffer
    dataset = replay_buffer.as_dataset(num_parallel_calls=3, sample_batch_size=20, num_steps=2).prefetch(3)
    iterator = iter(dataset)

    # retrieve a sample
    print(next(iterator)) # <===== ERROR


if __name__ == '__main__':
    main()

When I run this code, I get the following error Message:

tensorflow.python.framework.errors_impl.InvalidArgumentError:
{{function_node __wrapped__IteratorGetNext_output_types_11_device_/job:localhost/replica:0/task:0/device:CPU:0}}
Received incompatible tensor at flattened index 0 from table 'uniform_table'.
    Specification has (dtype, shape): (int32, [?]).
    Tensor has        (dtype, shape): (int32, [2,1]).
Table signature:
    0: Tensor<name: 'step_type/step_type', dtype: int32, shape: [?]>,
    1: Tensor<name: 'observation/observation', dtype: float, shape: [?,30]>,
    2: Tensor<name: 'action/action', dtype: float, shape: [?,2]>,
    3: Tensor<name: 'next_step_type/step_type', dtype: int32, shape: [?]>,
    4: Tensor<name: 'reward/reward', dtype: float, shape: [?]>,
    5: Tensor<name: 'discount/discount', dtype: float, shape: [?]>
[Op:IteratorGetNext]

In my gym environment, I defined the action space as a 2-element vector and I'm guessing that this action vector is somehow the problem. I've tried to use tensor specs for every input and output but I guess I made a mistake somewhere. Does anyone have an Idea what I'm doing wrong here?


Solution

  • I finally figured it out:

    PyDriver needs a PyEnvironment to work properly. In my code I used the pyenv attribute of my TFPyEnvironment which, despite of its name, doesn't return a regular PyEnvironment but a batched one insted.

    Changing the code in the following way fixes this issue:

    ...
    
    def main():
        # make sure the gym environment is ok
        check_env(gym.make("TestGym-v0"))
    
        # create py-environment
        pyenv = suite_gym.load("TestGym-v0")  # <=============
    
        # create tf-py-environment
        env = TFPyEnvironment(pyenv)
    
        ...
    
        driver = PyDriver(py_env, PyTFEagerPolicy(agent.collect_policy, use_tf_function=True), [observer], max_steps=100)
        driver.run(py_env.reset())
    
        ...