I'm trying to implement Temporal attention in a Reinforcement Learning problem using Stable baselines however, I keep getting the mentioned error in the customer policy. I am using TensorFlow version 1.14. While using an LSTMCell along with RNN class from TensorFlow in my policy.py, I am also initializing a wrapper for attention but I keep getting the following error.
Traceback (most recent call last):
File "run.py", line 60, in <module>
trainedModel = model_training(featureMatrix, config['env_name'], config['number_of_cpus'], config['total_training_timesteps'], config['policy'])
File "/code/src/util/utils.py", line 88, in model_training
trained_model = trained_model.train()
File "/code/src/util/model/model_training.py", line 103, in train
tensorboard_log=self.tensorboard_path).learn(total_timesteps=self.total_training_timesteps, callback=self.callback)
File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 119, in __init__
self.setup_model()
File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 148, in setup_model
1, n_batch_step, reuse=False, **self.policy_kwargs)
File "/code/src/util/policy/policy.py", line 97, in __init__
rnn = tf.keras.layers.RNN(self._build_rnn_cell())
File "/code/src/util/policy/policy.py", line 165, in _build_rnn_cell
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
File "/code/src/util/policy/policy.py", line 165, in <listcomp>
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
File "/code/src/util/policy/policy.py", line 158, in _build_single_cell
128,
File "/code/src/util/policy/attention_wrapper.py", line 123, in __init__
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
TypeError: __init__() missing 1 required positional argument: 'units'
My policy.py is as follows:
class CustomPolicy(ActorCriticPolicy):
def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
super(CustomPolicy, self).__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=reuse, scale=True)
with tf.variable_scope("model", reuse=reuse):
rnn = tf.keras.layers.RNN(self._build_rnn_cell())
feature_layer = rnn(self.processed_obs)
pi_layers = Sequential([
Dense(128, input_shape = (256,),
kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01)),
Activation('relu'),
Dense(128, kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01))
])
pi_latent = pi_layers(feature_layer)
vf_layers = Sequential([
Dense(32, input_shape = (256,),
kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01)),
Activation('relu'),
Dense(32, kernel_regularizer=regularizers.l2(0.01), activity_regularizer=regularizers.l1(0.01))
])
vf_latent = vf_layers(feature_layer)
temp_value_fn = Dense(1, input_shape=(32,))
value_fn = temp_value_fn(vf_latent)
self._proba_distribution, self._policy, self.q_value = \
self.pdtype.proba_distribution_from_latent(pi_latent, vf_latent, init_scale=0.01)
self._value_fn = value_fn
self._setup_init()
def step(self, obs, state=None, mask=None, deterministic=False):
if deterministic:
action, value, neglogp = self.sess.run([self.deterministic_action, self.value_flat, self.neglogp],
{self.obs_ph: obs})
else:
action, value, neglogp = self.sess.run([self.action, self.value_flat, self.neglogp],
{self.obs_ph: obs})
return action, value, self.initial_state, neglogp
def proba_step(self, obs, state=None, mask=None):
return self.sess.run(self.policy_proba, {self.obs_ph: obs})
def value(self, obs, state=None, mask=None):
return self.sess.run(self.value_flat, {self.obs_ph: obs})
def _build_single_cell(self):
cell = tf.keras.layers.LSTMCell(256)
cell = TemporalPatternAttentionCellWrapper(
cell,
128,
)
return cell
def _build_rnn_cell(self):
return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
and my Attention Wrapper is as follows:
class TemporalPatternAttentionCellWrapper(tf.keras.layers.LSTMCell):
def __init__(self,
cell,
attn_length,
units=256,
attn_size=None,
attn_vec_size=None,
input_size=None,
state_is_tuple=True,
reuse=None):
"""Create a cell with attention.
Args:
cell: an RNNCell, an attention is added to it.
attn_length: integer, the size of an attention window.
attn_size: integer, the size of an attention vector. Equal to
cell.output_size by default.
attn_vec_size: integer, the number of convolutional features
calculated on attention state and a size of the hidden layer
built from base cell state. Equal attn_size to by default.
input_size: integer, the size of a hidden linear layer, built from
inputs and attention. Derived from the input tensor by default.
state_is_tuple: If True, accepted and returned states are n-tuples,
where `n = len(cells)`. By default (False), the states are all
concatenated along the column axis.
reuse: (optional) Python boolean describing whether to reuse
variables in an existing scope. If not `True`, and the existing
scope already has the given variables, an error is raised.
Raises:
TypeError: if cell is not an RNNCell.
ValueError: if cell returns a state tuple but the flag
`state_is_tuple` is `False` or if attn_length is zero or less.
"""
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
if nest.is_sequence(cell.state_size) and not state_is_tuple:
raise ValueError("Cell returns tuple of states, but the flag "
"state_is_tuple is not set. State size is: %s" %
str(cell.state_size))
if attn_length <= 0:
raise ValueError("attn_length should be greater than zero, got %s"
% str(attn_length))
if not state_is_tuple:
logging.warn(
"%s: Using a concatenated state is slower and will soon be "
"deprecated. Use state_is_tuple=True.", self)
if attn_size is None:
attn_size = 2880
if attn_vec_size is None:
attn_vec_size = attn_size
self._state_is_tuple = state_is_tuple
self._cell = cell
self._attn_vec_size = attn_vec_size
self._input_size = input_size
self._attn_size = attn_size
self._attn_length = attn_length
self._reuse = reuse
self._attention_mech = TemporalPatternAttentionMechanism()
@property
def state_size(self):
size = (self._cell.state_size, self._attn_size,
self._attn_size * self._attn_length)
if self._state_is_tuple:
return size
else:
return sum(list(size))
@property
def output_size(self):
return self._attn_size
def call(self, inputs, state):
"""Long short-term memory cell with attention (LSTMA)."""
print("TPA Wrapper called")
if self._state_is_tuple:
state, attns, attn_states = state
else:
states = state
state = tf.slice(states, [0, 0], [-1, self._cell.state_size])
attns = tf.slice(states, [0, self._cell.state_size],
[-1, self._attn_size])
attn_states = tf.slice(
states, [0, self._cell.state_size + self._attn_size],
[-1, self._attn_size * self._attn_length])
attn_states = tf.reshape(attn_states,
[-1, self._attn_length, self._attn_size])
input_size = self._input_size
if input_size is None:
input_size = inputs.get_shape().as_list()[1]
temp_inputs = Dense(input_size, input_shape = (5760,), use_bias=True)
inputs = temp_inputs(tf.concat([inputs, attns], 1))
lstm_output, new_state = self._cell(inputs)
if self._state_is_tuple:
new_state_cat = tf.concat(nest.flatten(new_state), 1)
else:
new_state_cat = new_state
new_attns, new_attn_states = self._attention_mech(
new_state_cat, attn_states, self._attn_size, self._attn_length,
self._attn_vec_size)
with tf.variable_scope("attn_output_projection"):
temp_output = Sequential([
Dense(self._attn_size, input_shape = (2880,),
use_bias=True),
])
output = dense(tf.concat([lstm_output, new_attns], 1))
new_attn_states = tf.concat(
[new_attn_states, tf.expand_dims(output, 1)], 1)
new_attn_states = tf.reshape(new_attn_states,
[-1, self._attn_length * self._attn_size])
new_state = (new_state, new_attns, new_attn_states)
if not self._state_is_tuple:
new_state = tf.concat(list(new_state), 1)
return output, new_state
The error occurs in the line
super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse) in the init function of the wrapper.
Any help would be greatly appreciated and please let me know if more information is needed.
According to the doc of LSTMCell
, it requires a mandatory units
parameters first, that is the dimensionality of the output space.
When you call its __init__()
at the error line, you need to use __init__(units, ...)
.