Search code examples
python tensorflow keras deep-learning reinforcement-learning

TypeError: __init__() missing 1 required positional argument: 'units' in LSTMCell


I'm trying to implement temporal attention in a reinforcement learning problem using Stable Baselines; however, I keep getting the error in the title from my custom policy. I am using TensorFlow version 1.14. In my policy.py I use an LSTMCell together with the RNN class from tf.keras, and I wrap each cell in an attention wrapper. Constructing that wrapper fails with the following error.

Traceback (most recent call last):
  File "run.py", line 60, in <module>
    trainedModel = model_training(featureMatrix, config['env_name'], config['number_of_cpus'], config['total_training_timesteps'], config['policy'])
  File "/code/src/util/utils.py", line 88, in model_training
    trained_model = trained_model.train()
  File "/code/src/util/model/model_training.py", line 103, in train
    tensorboard_log=self.tensorboard_path).learn(total_timesteps=self.total_training_timesteps, callback=self.callback)
  File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 119, in __init__
    self.setup_model()
  File "/venv/lib/python3.7/site-packages/stable_baselines/acktr/acktr.py", line 148, in setup_model
    1, n_batch_step, reuse=False, **self.policy_kwargs)
  File "/code/src/util/policy/policy.py", line 97, in __init__
    rnn = tf.keras.layers.RNN(self._build_rnn_cell())
  File "/code/src/util/policy/policy.py", line 165, in _build_rnn_cell
    return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
  File "/code/src/util/policy/policy.py", line 165, in <listcomp>
    return tf.keras.layers.StackedRNNCells([self._build_single_cell() for _ in range(3)])
  File "/code/src/util/policy/policy.py", line 158, in _build_single_cell
    128,
  File "/code/src/util/policy/attention_wrapper.py", line 123, in __init__
    super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse)
TypeError: __init__() missing 1 required positional argument: 'units'

My policy.py is as follows:

class CustomPolicy(ActorCriticPolicy):
    """Actor-critic policy whose features come from stacked attention-wrapped LSTM cells.

    The processed observations are passed through a Keras ``RNN`` layer built
    from three ``TemporalPatternAttentionCellWrapper`` cells; the result feeds
    two separate dense stacks, one for the policy and one for the value
    function.
    """

    def __init__(self, sess, ob_space, ac_space, n_env, n_steps, n_batch, reuse=False, **kwargs):
        """Build the policy graph inside the ``"model"`` variable scope.

        The positional arguments and ``reuse`` are forwarded unchanged to
        ``ActorCriticPolicy.__init__`` together with ``scale=True``.
        ``**kwargs`` is accepted but not used by this class.
        """
        super().__init__(sess, ob_space, ac_space, n_env, n_steps, n_batch,
                         reuse=reuse, scale=True)

        def regularised_dense(width, **extra):
            # Dense layer with the L2 kernel / L1 activity penalties shared by
            # every hidden layer of both heads.
            return Dense(width,
                         kernel_regularizer=regularizers.l2(0.01),
                         activity_regularizer=regularizers.l1(0.01),
                         **extra)

        def two_layer_head(width):
            # Dense -> ReLU -> Dense stack; the first layer declares a
            # 256-wide input.
            return Sequential([
                regularised_dense(width, input_shape=(256,)),
                Activation('relu'),
                regularised_dense(width),
            ])

        with tf.variable_scope("model", reuse=reuse):
            stacked_cell = self._build_rnn_cell()
            recurrent = tf.keras.layers.RNN(stacked_cell)
            features = recurrent(self.processed_obs)

            # Policy branch (128 units per layer).
            policy_head = two_layer_head(128)
            pi_latent = policy_head(features)

            # Value branch (32 units per layer) followed by a scalar output.
            value_head = two_layer_head(32)
            vf_latent = value_head(features)
            value_fn = Dense(1, input_shape=(32,))(vf_latent)

            distribution = self.pdtype.proba_distribution_from_latent(
                pi_latent, vf_latent, init_scale=0.01)
            self._proba_distribution, self._policy, self.q_value = distribution

        self._value_fn = value_fn
        self._setup_init()

    def step(self, obs, state=None, mask=None, deterministic=False):
        """Run one forward pass and return ``(action, value, initial_state, neglogp)``.

        ``state`` and ``mask`` are accepted but not used. When
        ``deterministic`` is truthy the ``deterministic_action`` tensor is
        evaluated instead of ``action``.
        """
        action_op = self.deterministic_action if deterministic else self.action
        fetches = [action_op, self.value_flat, self.neglogp]
        action, value, neglogp = self.sess.run(fetches, {self.obs_ph: obs})
        return action, value, self.initial_state, neglogp

    def proba_step(self, obs, state=None, mask=None):
        """Evaluate ``policy_proba`` for ``obs``; ``state`` and ``mask`` are unused."""
        feed = {self.obs_ph: obs}
        return self.sess.run(self.policy_proba, feed)

    def value(self, obs, state=None, mask=None):
        """Evaluate ``value_flat`` for ``obs``; ``state`` and ``mask`` are unused."""
        feed = {self.obs_ph: obs}
        return self.sess.run(self.value_flat, feed)

    def _build_single_cell(self):
        """Return a 256-unit LSTMCell wrapped with attention over a window of 128."""
        return TemporalPatternAttentionCellWrapper(
            tf.keras.layers.LSTMCell(256),
            128,
        )

    def _build_rnn_cell(self):
        """Return three attention-wrapped cells stacked into one RNN cell."""
        cells = [self._build_single_cell() for _ in range(3)]
        return tf.keras.layers.StackedRNNCells(cells)

and my Attention Wrapper is as follows:

class TemporalPatternAttentionCellWrapper(tf.keras.layers.LSTMCell):
    """RNN cell wrapper that adds temporal-pattern attention to another cell.

    The wrapper keeps a window of past attention states alongside the wrapped
    cell's own state and mixes the attention read-out into both the cell input
    and the cell output.
    """

    def __init__(self,
                 cell,
                 attn_length,
                 units=256,
                 attn_size=None,
                 attn_vec_size=None,
                 input_size=None,
                 state_is_tuple=True,
                 reuse=None):
        """Create a cell with attention.
        Args:
            cell: an RNN cell; attention is added to it.
            attn_length: integer, the size of an attention window.
            units: integer, forwarded to ``LSTMCell.__init__`` as its required
                ``units`` argument. Defaults to 256.
            attn_size: integer, the size of an attention vector. Defaults to
                2880.
            attn_vec_size: integer, passed to the attention mechanism as the
                attention vector size. Equal to attn_size by default.
            input_size: integer, the size of a hidden linear layer, built from
                inputs and attention. Derived from the input tensor by default.
            state_is_tuple: If True (the default), accepted and returned
                states are tuples ``(cell_state, attns, attn_states)``. If
                False, the states are all concatenated along the column axis.
            reuse: (optional) stored on the instance as ``_reuse``; it is not
                forwarded to the parent constructor.
        Raises:
            ValueError: if cell returns a state tuple but the flag
                `state_is_tuple` is `False` or if attn_length is zero or less.
        """
        # LSTMCell.__init__ requires `units` as its first argument. `_reuse`
        # is deliberately not forwarded: NOTE(review) the Keras Layer
        # constructor in TF 1.14 appears to reject unknown keywords - confirm.
        super(TemporalPatternAttentionCellWrapper, self).__init__(units)
        if nest.is_sequence(cell.state_size) and not state_is_tuple:
            raise ValueError("Cell returns tuple of states, but the flag "
                             "state_is_tuple is not set. State size is: %s" %
                             str(cell.state_size))
        if attn_length <= 0:
            raise ValueError("attn_length should be greater than zero, got %s"
                             % str(attn_length))
        if not state_is_tuple:
            # `warn` is a deprecated alias of `warning`.
            logging.warning(
                "%s: Using a concatenated state is slower and will soon be "
                "deprecated.    Use state_is_tuple=True.", self)
        if attn_size is None:
            # NOTE(review): hard-coded default; presumably tied to the
            # caller's data dimensions - confirm.
            attn_size = 2880
        if attn_vec_size is None:
            attn_vec_size = attn_size
        self._state_is_tuple = state_is_tuple
        self._cell = cell
        self._attn_vec_size = attn_vec_size
        self._input_size = input_size
        self._attn_size = attn_size
        self._attn_length = attn_length
        self._reuse = reuse
        self._attention_mech = TemporalPatternAttentionMechanism()


    @property
    def state_size(self):
        """Sizes of (wrapped cell state, attention vector, attention window).

        Returned as a tuple when ``state_is_tuple`` is set, otherwise as the
        sum of the three sizes.
        """
        size = (self._cell.state_size, self._attn_size,
                self._attn_size * self._attn_length)
        if self._state_is_tuple:
            return size
        else:
            return sum(list(size))

    @property
    def output_size(self):
        """Size of the cell output, equal to the attention vector size."""
        return self._attn_size

    def call(self, inputs, state):
        """Long short-term memory cell with attention (LSTMA).

        Args:
            inputs: input tensor for this step; its second dimension is used
                as the projection size when ``input_size`` was not given.
            state: ``(cell_state, attns, attn_states)`` when
                ``state_is_tuple`` is set, otherwise one concatenated tensor
                that is sliced into those three parts.
        Returns:
            ``(output, new_state)`` where ``new_state`` has the same layout as
            ``state``.
        """
        print("TPA Wrapper called")
        if self._state_is_tuple:
            state, attns, attn_states = state
        else:
            states = state
            state = tf.slice(states, [0, 0], [-1, self._cell.state_size])
            attns = tf.slice(states, [0, self._cell.state_size],
                             [-1, self._attn_size])
            attn_states = tf.slice(
                states, [0, self._cell.state_size + self._attn_size],
                [-1, self._attn_size * self._attn_length])
        attn_states = tf.reshape(attn_states,
                                 [-1, self._attn_length, self._attn_size])
        input_size = self._input_size
        if input_size is None:
            input_size = inputs.get_shape().as_list()[1]

        # Project [inputs, attns] back to the input width.
        # NOTE(review): input_shape=(5760,) is a hard-coded hint - confirm it
        # matches the actual concatenated width.
        temp_inputs = Dense(input_size, input_shape = (5760,), use_bias=True)

        inputs = temp_inputs(tf.concat([inputs, attns], 1))

        # The wrapped cell needs its own previous state as well as the inputs.
        lstm_output, new_state = self._cell(inputs, state)

        if self._state_is_tuple:
            new_state_cat = tf.concat(nest.flatten(new_state), 1)
        else:
            new_state_cat = new_state
        new_attns, new_attn_states = self._attention_mech(
            new_state_cat, attn_states, self._attn_size, self._attn_length,
            self._attn_vec_size)

        with tf.variable_scope("attn_output_projection"):
            # NOTE(review): input_shape=(2880,) is a hard-coded hint - confirm
            # it matches the width of [lstm_output, new_attns].
            temp_output = Sequential([
                Dense(self._attn_size, input_shape = (2880,),
                      use_bias=True),
            ])

            # Apply the projection built just above (the original referenced
            # an undefined name here).
            output = temp_output(tf.concat([lstm_output, new_attns], 1))

        new_attn_states = tf.concat(
            [new_attn_states, tf.expand_dims(output, 1)], 1)
        new_attn_states = tf.reshape(new_attn_states,
                                     [-1, self._attn_length * self._attn_size])
        new_state = (new_state, new_attns, new_attn_states)
        if not self._state_is_tuple:
            new_state = tf.concat(list(new_state), 1)

        return output, new_state

The error occurs in the line

super(TemporalPatternAttentionCellWrapper, self).__init__(_reuse=reuse) in the init function of the wrapper.

Any help would be greatly appreciated and please let me know if more information is needed.


Solution

  • According to the documentation of LSTMCell, its constructor requires a mandatory `units` parameter as the first argument: the dimensionality of the output space.

    When you call its __init__() at the error line, you need to pass `units` as the first argument, e.g. super(TemporalPatternAttentionCellWrapper, self).__init__(units).