Using BatchedPyEnvironment in tf_agents


I am trying to create a batched-environment version of the SAC agent example from the TensorFlow Agents (tf_agents) library. The script is adapted from the library's tf_agents/agents/sac/examples/v2/train_eval.py example, and I am using a custom environment.

I want a batched environment so that training makes better use of the GPU. My understanding is that passing batches of trajectories to the GPU incurs less overhead per transition when moving data from the host (CPU) to the device (GPU).

My custom environment is called SacEnv, and I attempt to create a batched environment like so:

py_envs = [SacEnv() for _ in range(0, batch_size)]
batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
tf_env = tf_py_environment.TFPyEnvironment(batched_env)

My hope is that this creates a batched environment consisting of a 'batch' of non-batched environments. However, I receive the following error when running the code:

ValueError: Cannot assign value to variable ' Accumulator:0': Shape mismatch.The variable shape (1,), and the assigned value shape (32,) are incompatible.

with the stack trace:

Traceback (most recent call last):
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 370, in <module>
    app.run(main)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 366, in main
    train_eval(FLAGS.root_dir)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/gary/Desktop/code/sac_test/sac_main2.py", line 274, in train_eval
    results = metric_utils.eager_compute(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
    utils.augment_exception_message_and_reraise(e, err_str)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
    raise proxy.with_traceback(exception.__traceback__) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
    return fn(*new_args, **new_kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/eval/metric_utils.py", line 163, in eager_compute
    common.function(driver.run)(time_step, policy_state)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
    raise e.with_traceback(filtered_tb) from None
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 211, in run
    return self._run_fn(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/utils/common.py", line 188, in with_check_resource_vars
    return fn(*fn_args, **fn_kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 238, in _run
    tf.while_loop(
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 154, in loop_body
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/drivers/dynamic_episode_driver.py", line 154, in <listcomp>
    observer_ops = [observer(traj) for observer in self._observers]
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metric.py", line 93, in __call__
    return self._update_state(*args, **kwargs)
  File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metric.py", line 81, in _update_state
    return self.call(*arg, **kwargs)
ValueError: in user code:

    File "/home/gary/anaconda3/envs/py39/lib/python3.9/site-packages/tf_agents/metrics/tf_metrics.py", line 176, in call  *
        self._return_accumulator.assign(

    ValueError: Cannot assign value to variable ' Accumulator:0': Shape mismatch.The variable shape (1,), and the assigned value shape (32,) are incompatible.

  In call to configurable 'eager_compute' (<function eager_compute at 0x7fa4d6e5e040>)
  In call to configurable 'train_eval' (<function train_eval at 0x7fa4c8622dc0>)

I have dug through the tf_metric.py code to try to understand the error, without success. An earlier, similar error went away when I added the batch size (32) to the initializer of the AverageReturnMetric instance in train_metrics, so I suspect this one has the same kind of cause.
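For readers tracing the same error, here is a minimal pure-Python sketch of the mechanism as I understand it. It is not the tf_agents implementation: the class ToyReturnAccumulator and the helper try_update are hypothetical names. The assumption is that the metric allocates one accumulator slot per environment, sized by its batch_size argument (default 1), and that each driver step hands it one reward per environment.

```python
class ToyReturnAccumulator:
    """Hypothetical stand-in for a per-environment return accumulator."""

    def __init__(self, batch_size=1):
        # One running-return slot per environment in the batch.
        self.values = [0.0] * batch_size

    def update(self, rewards):
        # Mimic a fixed-shape variable: reject a reward vector whose length
        # differs from the number of slots allocated at construction time.
        if len(rewards) != len(self.values):
            raise ValueError(
                f"Shape mismatch: variable shape ({len(self.values)},) vs "
                f"assigned value shape ({len(rewards)},)")
        self.values = [v + r for v, r in zip(self.values, rewards)]


def try_update(accumulator, rewards):
    """Returns True if the update is accepted, False on a shape mismatch."""
    try:
        accumulator.update(rewards)
        return True
    except ValueError:
        return False


# A batched environment with 32 sub-environments yields 32 rewards per step.
rewards_from_32_envs = [1.0] * 32

default_ok = try_update(ToyReturnAccumulator(), rewards_from_32_envs)
batched_ok = try_update(ToyReturnAccumulator(batch_size=32), rewards_from_32_envs)
print(default_ok, batched_ok)
```

The accumulator built with the default size rejects the 32-element reward vector, while the one built with batch_size=32 accepts it. That matches the (1,) versus (32,) shapes in the error message.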

The full code is:

# coding=utf-8
# Copyright 2020 The TF-Agents Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python2, python3

r"""Train and Eval SAC.

All hyperparameters come from the SAC paper
https://arxiv.org/pdf/1812.05905.pdf

To run:

```bash
tensorboard --logdir $HOME/tmp/sac/gym/HalfCheetah-v2/ --port 2223 &

python tf_agents/agents/sac/examples/v2/train_eval.py \
  --root_dir=$HOME/tmp/sac/gym/HalfCheetah-v2/ \
  --alsologtostderr
```
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from sac_env import SacEnv

import os
import time

from absl import app
from absl import flags
from absl import logging

import gin
from six.moves import range
import tensorflow as tf  # pylint: disable=g-explicit-tensorflow-version-import

from tf_agents.agents.ddpg import critic_network
from tf_agents.agents.sac import sac_agent
from tf_agents.agents.sac import tanh_normal_projection_network
from tf_agents.drivers import dynamic_step_driver
#from tf_agents.environments import suite_mujoco
from tf_agents.environments import tf_py_environment
from tf_agents.environments import batched_py_environment
from tf_agents.eval import metric_utils
from tf_agents.metrics import tf_metrics
from tf_agents.networks import actor_distribution_network
from tf_agents.policies import greedy_policy
from tf_agents.policies import random_tf_policy
from tf_agents.replay_buffers import tf_uniform_replay_buffer
from tf_agents.utils import common
from tf_agents.train.utils import strategy_utils


flags.DEFINE_string('root_dir', os.getenv('TEST_UNDECLARED_OUTPUTS_DIR'),
                    'Root directory for writing logs/summaries/checkpoints.')
flags.DEFINE_multi_string('gin_file', None, 'Path to the trainer config files.')
flags.DEFINE_multi_string('gin_param', None, 'Gin binding to pass through.')

FLAGS = flags.FLAGS

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

@gin.configurable
def train_eval(
    root_dir,
    env_name='SacEnv',
    # The SAC paper reported:
    # Hopper and Cartpole results up to 1000000 iters,
    # Humanoid results up to 10000000 iters,
    # Other mujoco tasks up to 3000000 iters.
    num_iterations=3000000,
    actor_fc_layers=(256, 256),
    critic_obs_fc_layers=None,
    critic_action_fc_layers=None,
    critic_joint_fc_layers=(256, 256),
    # Params for collect
    # Follow https://github.com/haarnoja/sac/blob/master/examples/variants.py
    # HalfCheetah and Ant take 10000 initial collection steps.
    # Other mujoco tasks take 1000.
    # Different choices roughly keep the initial episodes about the same.
    #initial_collect_steps=10000,
    initial_collect_steps=2000,
    collect_steps_per_iteration=1,
    replay_buffer_capacity=31250, # 1000000 / 32
    # Params for target update
    target_update_tau=0.005,
    target_update_period=1,
    # Params for train
    train_steps_per_iteration=1,
    #batch_size=256,
    batch_size=32,
    actor_learning_rate=3e-4,
    critic_learning_rate=3e-4,
    alpha_learning_rate=3e-4,
    td_errors_loss_fn=tf.math.squared_difference,
    gamma=0.99,
    reward_scale_factor=0.1,
    gradient_clipping=None,
    use_tf_functions=True,
    # Params for eval
    num_eval_episodes=30,
    eval_interval=10000,
    # Params for summaries and logging
    train_checkpoint_interval=50000,
    policy_checkpoint_interval=50000,
    rb_checkpoint_interval=50000,
    log_interval=1000,
    summary_interval=1000,
    summaries_flush_secs=10,
    debug_summaries=False,
    summarize_grads_and_vars=False,
    eval_metrics_callback=None):
  """A simple train and eval for SAC."""
  root_dir = os.path.expanduser(root_dir)
  train_dir = os.path.join(root_dir, 'train')
  eval_dir = os.path.join(root_dir, 'eval')

  train_summary_writer = tf.compat.v2.summary.create_file_writer(
      train_dir, flush_millis=summaries_flush_secs * 1000)
  train_summary_writer.set_as_default()

  eval_summary_writer = tf.compat.v2.summary.create_file_writer(
      eval_dir, flush_millis=summaries_flush_secs * 1000)
  eval_metrics = [
      tf_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
      tf_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes)
  ]

  global_step = tf.compat.v1.train.get_or_create_global_step()
  with tf.compat.v2.summary.record_if(
      lambda: tf.math.equal(global_step % summary_interval, 0)):


    py_envs = [SacEnv() for _ in range(0, batch_size)]
    batched_env = batched_py_environment.BatchedPyEnvironment(envs=py_envs)
    tf_env = tf_py_environment.TFPyEnvironment(batched_env)
    
    eval_py_envs = [SacEnv() for _ in range(0, batch_size)]
    eval_batched_env = batched_py_environment.BatchedPyEnvironment(envs=eval_py_envs)
    eval_tf_env = tf_py_environment.TFPyEnvironment(eval_batched_env)

    time_step_spec = tf_env.time_step_spec()
    observation_spec = time_step_spec.observation
    action_spec = tf_env.action_spec()

    strategy = strategy_utils.get_strategy(tpu=False, use_gpu=True)

    with strategy.scope():
        actor_net = actor_distribution_network.ActorDistributionNetwork(
            observation_spec,
            action_spec,
            fc_layer_params=actor_fc_layers,
            continuous_projection_net=tanh_normal_projection_network
            .TanhNormalProjectionNetwork)
        critic_net = critic_network.CriticNetwork(
            (observation_spec, action_spec),
            observation_fc_layer_params=critic_obs_fc_layers,
            action_fc_layer_params=critic_action_fc_layers,
            joint_fc_layer_params=critic_joint_fc_layers,
            kernel_initializer='glorot_uniform',
            last_kernel_initializer='glorot_uniform')

        tf_agent = sac_agent.SacAgent(
            time_step_spec,
            action_spec,
            actor_network=actor_net,
            critic_network=critic_net,
            actor_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=actor_learning_rate),
            critic_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=critic_learning_rate),
            alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
                learning_rate=alpha_learning_rate),
            target_update_tau=target_update_tau,
            target_update_period=target_update_period,
            td_errors_loss_fn=td_errors_loss_fn,
            gamma=gamma,
            reward_scale_factor=reward_scale_factor,
            gradient_clipping=gradient_clipping,
            debug_summaries=debug_summaries,
            summarize_grads_and_vars=summarize_grads_and_vars,
            train_step_counter=global_step)
    tf_agent.initialize()

    # Make the replay buffer.
    replay_buffer = tf_uniform_replay_buffer.TFUniformReplayBuffer(
        data_spec=tf_agent.collect_data_spec,
        batch_size=batch_size,
        max_length=replay_buffer_capacity,
        device="/device:GPU:0")
    replay_observer = [replay_buffer.add_batch]

    train_metrics = [
        tf_metrics.NumberOfEpisodes(),
        tf_metrics.EnvironmentSteps(),
        tf_metrics.AverageReturnMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
        tf_metrics.AverageEpisodeLengthMetric(
            buffer_size=num_eval_episodes, batch_size=tf_env.batch_size),
    ]

    eval_policy = greedy_policy.GreedyPolicy(tf_agent.policy)
    initial_collect_policy = random_tf_policy.RandomTFPolicy(
        tf_env.time_step_spec(), tf_env.action_spec())
    collect_policy = tf_agent.collect_policy

    train_checkpointer = common.Checkpointer(
        ckpt_dir=train_dir,
        agent=tf_agent,
        global_step=global_step,
        metrics=metric_utils.MetricsGroup(train_metrics, 'train_metrics'))
    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=eval_policy,
        global_step=global_step)
    rb_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
        max_to_keep=1,
        replay_buffer=replay_buffer)

    train_checkpointer.initialize_or_restore()
    rb_checkpointer.initialize_or_restore()

    initial_collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        initial_collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=initial_collect_steps)

    collect_driver = dynamic_step_driver.DynamicStepDriver(
        tf_env,
        collect_policy,
        observers=replay_observer + train_metrics,
        num_steps=collect_steps_per_iteration)

    if use_tf_functions:
      initial_collect_driver.run = common.function(initial_collect_driver.run)
      collect_driver.run = common.function(collect_driver.run)
      tf_agent.train = common.function(tf_agent.train)

    if replay_buffer.num_frames() == 0:
      # Collect initial replay data.
      logging.info(
          'Initializing replay buffer by collecting experience for %d steps '
          'with a random policy.', initial_collect_steps)
      initial_collect_driver.run()

    results = metric_utils.eager_compute(
        eval_metrics,
        eval_tf_env,
        eval_policy,
        num_episodes=num_eval_episodes,
        train_step=global_step,
        summary_writer=eval_summary_writer,
        summary_prefix='Metrics',
    )
    if eval_metrics_callback is not None:
      eval_metrics_callback(results, global_step.numpy())
    metric_utils.log_metrics(eval_metrics)

    time_step = None
    policy_state = collect_policy.get_initial_state(tf_env.batch_size)

    timed_at_step = global_step.numpy()
    time_acc = 0

    # Prepare replay buffer as dataset with invalid transitions filtered.
    def _filter_invalid_transition(trajectories, unused_arg1):
      return ~trajectories.is_boundary()[0]
    dataset = replay_buffer.as_dataset(
        sample_batch_size=batch_size,
        num_steps=2).unbatch().filter(
            _filter_invalid_transition).batch(batch_size).prefetch(5)
    # Dataset generates trajectories with shape [Bx2x...]
    iterator = iter(dataset)

    def train_step():
      experience, _ = next(iterator)
      return tf_agent.train(experience)

    if use_tf_functions:
      train_step = common.function(train_step)

    global_step_val = global_step.numpy()
    while global_step_val < num_iterations:
      start_time = time.time()
      time_step, policy_state = collect_driver.run(
          time_step=time_step,
          policy_state=policy_state,
      )
      for _ in range(train_steps_per_iteration):
        train_loss = train_step()
      time_acc += time.time() - start_time

      global_step_val = global_step.numpy()

      if global_step_val % log_interval == 0:
        logging.info('step = %d, loss = %f', global_step_val,
                     train_loss.loss)
        steps_per_sec = (global_step_val - timed_at_step) / time_acc
        logging.info('%.3f steps/sec', steps_per_sec)
        tf.compat.v2.summary.scalar(
            name='global_steps_per_sec', data=steps_per_sec, step=global_step)
        timed_at_step = global_step_val
        time_acc = 0

      for train_metric in train_metrics:
        train_metric.tf_summaries(
            train_step=global_step, step_metrics=train_metrics[:2])

      if global_step_val % eval_interval == 0:
        results = metric_utils.eager_compute(
            eval_metrics,
            eval_tf_env,
            eval_policy,
            num_episodes=num_eval_episodes,
            train_step=global_step,
            summary_writer=eval_summary_writer,
            summary_prefix='Metrics',
        )
        if eval_metrics_callback is not None:
          eval_metrics_callback(results, global_step_val)
        metric_utils.log_metrics(eval_metrics)

      if global_step_val % train_checkpoint_interval == 0:
        train_checkpointer.save(global_step=global_step_val)

      if global_step_val % policy_checkpoint_interval == 0:
        policy_checkpointer.save(global_step=global_step_val)

      if global_step_val % rb_checkpoint_interval == 0:
        rb_checkpointer.save(global_step=global_step_val)
    return train_loss


def main(_):
  tf.compat.v1.enable_v2_behavior()
  logging.set_verbosity(logging.INFO)
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_param)
  train_eval(FLAGS.root_dir)

if __name__ == '__main__':
  flags.mark_flag_as_required('root_dir')
  app.run(main)

What is the appropriate way to create a batched environment from a custom, non-batched environment? I can share my custom environment, but I don't believe the issue lies there, as the code works fine with a batch size of 1.

Also, any tips on increasing GPU utilization in reinforcement learning scenarios would be greatly appreciated. I have looked at examples that use the TensorBoard profiler to measure GPU utilization, but they rely on callbacks and a fit function, which don't seem applicable to RL use cases.


Solution

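  • It turns out I neglected to pass batch_size when initializing the AverageReturnMetric and AverageEpisodeLengthMetric instances in eval_metrics. Both default to batch_size=1, so their accumulator has shape (1,), while the batched evaluation environment produces values of shape (32,).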