Tags: tensorflow, neural-network, reinforcement-learning, openai-gym

What is the output of a TensorFlow dense layer when the same variable is used as both its input and its output while building a neural network?


I have been going through the implementation of the neural network in OpenAI's code for Vanilla Policy Gradient (as a matter of fact, this part is used nearly everywhere). The code looks something like this:

def mlp_categorical_policy(x, a, hidden_sizes, activation, output_activation, action_space):
    act_dim = action_space.n
    logits = mlp(x, list(hidden_sizes) + [act_dim], activation, None)
    logp_all = tf.nn.log_softmax(logits)
    pi = tf.squeeze(tf.random.categorical(logits, 1), axis=1)
    logp = tf.reduce_sum(tf.one_hot(a, depth=act_dim) * logp_all, axis=1)
    logp_pi = tf.reduce_sum(tf.one_hot(pi, depth=act_dim) * logp_all, axis=1)
    return pi, logp, logp_pi

The multi-layer perceptron network is defined as follows:

def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
    for h in hidden_sizes[:-1]:
        x = tf.layers.dense(inputs=x, units=h, activation=activation)
    return tf.layers.dense(inputs=x, units=hidden_sizes[-1], activation=output_activation)

My question is: what does this mlp function return? I mean its structure or shape. Is it an N-dimensional tensor? If so, how can it be given as an input to tf.random.categorical? If not, and it just has the shape [hidden_layer2, output], then what happened to the other layers? According to the website's description of tf.random.categorical, it only takes a 2-D input. The complete code of OpenAI's VPG algorithm can be found here. The mlp is implemented here. I would be very grateful if someone could tell me what this mlp_categorical_policy() is doing.

Note: the hidden sizes are [64, 64] and the action dimension is 3.

Thanks and cheers


Solution

  • Note that this is a discrete action space: there are action_space.n different possible actions at every step, and the agent chooses one.

    To do this, the MLP returns the logits (unnormalized log-probabilities) of the different actions. This is specified in the code by + [act_dim], which appends the number of actions to the list of layer sizes so that it becomes the width of the final MLP layer. Note that the last layer of an MLP is the output layer. The input layer is not specified in TensorFlow; its size is inferred from the inputs.

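    tf.random.categorical takes a 2-D tensor of logits with shape [batch_size, num_classes] and, for each row, samples num_samples class indices, so its output has shape [batch_size, num_samples]. Here num_samples is 1, and tf.squeeze(..., axis=1) drops that axis, so pi holds one integer action index per observation in the batch.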

    mlp_categorical_policy also returns logp, the log probability of the action a (used to assign credit), and logp_pi, the log probability of the policy action pi.
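
As a rough illustration of those three return values, here is a minimal pure-Python sketch of the same arithmetic on a small batch of logits. It is not the OpenAI code: the helper names (log_softmax, sample_categorical, categorical_policy_head), the example logits and the fixed seed are assumptions made for illustration, and nested lists stand in for tensors.

```python
import math
import random

def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_z = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_z for v in row]

def sample_categorical(row, rng):
    """Draw one class index with probability softmax(row)."""
    probs = [math.exp(v) for v in log_softmax(row)]
    return rng.choices(range(len(row)), weights=probs, k=1)[0]

def categorical_policy_head(logits, a, rng):
    """logits: [batch][act_dim] nested list; a: [batch] list of actions taken."""
    logp_all = [log_softmax(row) for row in logits]         # [batch][act_dim]
    pi = [sample_categorical(row, rng) for row in logits]   # [batch]
    # Picking index a[i] gives the same value as summing one_hot(a) * logp_all over axis 1.
    logp = [logp_all[i][a[i]] for i in range(len(a))]       # [batch]
    logp_pi = [logp_all[i][pi[i]] for i in range(len(pi))]  # [batch]
    return pi, logp, logp_pi

rng = random.Random(0)                         # fixed seed, illustration only
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]   # batch of 2, act_dim = 3
actions = [0, 2]                               # hypothetical actions taken
pi, logp, logp_pi = categorical_policy_head(logits, actions, rng)
print(pi, logp, logp_pi)
```

Each returned list has one entry per row of logits, which mirrors the [batch] shape of pi, logp and logp_pi in the TensorFlow version.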


    It seems your question is more about the return from the mlp.

    The mlp creates a series of fully connected layers in a loop. In each iteration of the loop, the mlp creates a new layer that takes the previous layer x as its input and assigns its output back to x, with this line: x = tf.layers.dense(inputs=x, units=h, activation=activation).

    So the output is not the same as the input: on each iteration, x is overwritten with the output of the new layer. This is the same kind of coding idiom as x = x + 1, which increments x by 1. It effectively chains the layers together.
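
The chaining can be traced with a small pure-Python stand-in for the dense layer. This is only a sketch under stated assumptions: the dense helper below is hypothetical and uses random placeholder weights (it is not tf.layers.dense), nested lists stand in for tensors, and the shapes list is added purely to record what x looks like after each reassignment.

```python
import math
import random

def dense(x, units, activation, rng):
    """Toy fully connected layer: [batch][in_dim] nested list -> [batch][units]."""
    in_dim = len(x[0])
    # Placeholder weights; a real layer would create trainable variables here.
    w = [[rng.uniform(-0.1, 0.1) for _ in range(units)] for _ in range(in_dim)]
    b = [0.0] * units
    out = []
    for row in x:
        z = [sum(row[i] * w[i][j] for i in range(in_dim)) + b[j] for j in range(units)]
        out.append([activation(v) for v in z] if activation else z)
    return out

def mlp(x, hidden_sizes, activation=math.tanh, output_activation=None, rng=None):
    rng = rng or random.Random(0)
    shapes = [(len(x), len(x[0]))]           # shape of the input batch
    for h in hidden_sizes[:-1]:
        x = dense(x, h, activation, rng)     # x now refers to the new layer's output
        shapes.append((len(x), len(x[0])))
    x = dense(x, hidden_sizes[-1], output_activation, rng)
    shapes.append((len(x), len(x[0])))
    return x, shapes

obs = [[1.0, 2.0]]                           # one observation in a batch of size 1
logits, shapes = mlp(obs, [64, 64, 3])
print(shapes)  # [(1, 2), (1, 64), (1, 64), (1, 3)]
```

Only the last entry describes what the function returns; the earlier shapes belong to the intermediate layers that feed into it.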

    The output of tf.layers.dense is a tensor of shape [:, h], where : is the batch dimension (and can usually be ignored). Only the last layer's output is returned; the hidden layers are not lost, they are the intermediate computations that feed it. The last layer is created outside the loop, and its number of nodes is act_dim (so the shape is [:, 3]). You can check the shape by doing this:

    import tensorflow.compat.v1 as tf
    import numpy as np
    
    def mlp(x, hidden_sizes=(32,), activation=tf.tanh, output_activation=None):
        # Hidden layers: each call consumes the previous output and rebinds x.
        for h in hidden_sizes[:-1]:
            x = tf.layers.dense(x, units=h, activation=activation)
        # Output layer: hidden_sizes[-1] units (here, act_dim).
        return tf.layers.dense(x, units=hidden_sizes[-1], activation=output_activation)
    
    # One observation with 2 features, wrapped in a batch of size 1.
    obs = np.array([[1.0,2.0]])
    logits = mlp(obs, [64, 64, 3], tf.nn.relu, None)
    print(logits.shape)
    

    result: TensorShape([1, 3])

    Note that the observation in this case is [1., 2.]; it is nested inside a batch of size 1, which is why the leading dimension of the result is 1.