What is the equivalent of the following TensorFlow snippet in CNTK?


I'm trying to implement DDPG in CNTK and came across the following code (using TensorFlow) that creates the critic network:

state_input = tf.placeholder("float",[None,state_dim])
action_input = tf.placeholder("float",[None,action_dim])

W1 = self.variable([state_dim,layer1_size],state_dim)
b1 = self.variable([layer1_size],state_dim)
W2 = self.variable([layer1_size,layer2_size],layer1_size+action_dim)
W2_action = self.variable([action_dim,layer2_size],layer1_size+action_dim)
b2 = self.variable([layer2_size],layer1_size+action_dim)
W3 = tf.Variable(tf.random_uniform([layer2_size,1],-3e-3,3e-3))
b3 = tf.Variable(tf.random_uniform([1],-3e-3,3e-3))

layer1 = tf.nn.relu(tf.matmul(state_input,W1) + b1)
layer2 = tf.nn.relu(tf.matmul(layer1,W2) + tf.matmul(action_input,W2_action) + b2)
q_value_output = tf.identity(tf.matmul(layer2,W3) + b3)

where self.variable is defined as:

def variable(self,shape,f):
    return tf.Variable(tf.random_uniform(shape,-1/math.sqrt(f),1/math.sqrt(f)))

Ignoring the random initialization (I just want the structure), I tried the following:

state_in = cntk.input(state_dim, dtype=np.float32)
action_in = cntk.input_variable(action_dim, dtype=np.float32)

W1 = cntk.parameter(shape=(state_dim, layer1_size))
b1 = cntk.parameter(shape=(layer1_size))
W2 = cntk.parameter(shape=(layer1_size, layer2_size))
W2a = cntk.parameter(shape=(action_dim, layer2_size))
b2 = cntk.parameter(shape=(layer2_size))
W3 = cntk.parameter(shape=(layer2_size, 1))
b3 = cntk.parameter(shape=(1))

l1 = cntk.relu(cntk.times(state_in, W1) + b1)
l2 = cntk.relu(cntk.times(l1, W2) + cntk.times(action_in, W2a) + b2)
Q = cntk.times(l2, W3) + b3

However, constructing the second layer (l2) failed with the following error (excerpt):

RuntimeError: Operation 'Plus': Operand 'Output('Times24_Output_0', [#, *], [300])' has dynamic axes, that do not match the dynamic axes '[#]' of the other operands.

I would like to know what I'm doing wrong and how to accurately recreate the same model.


Solution

  • The cause is that you have defined state_in with cntk.input and action_in with cntk.input_variable, which by default produce slightly different types: cntk.input creates a variable that cannot be bound to sequence data, while cntk.input_variable creates a variable that must be bound to sequence data. (N.B. input_variable is deprecated, and some IDEs such as PyCharm show it with a strikethrough; please use cntk.input() or cntk.sequence.input() instead.)

    The error says that the plus operation cannot add cntk.times(l1, W2), which has dynamic axes [#] (the minibatch axis), to cntk.times(action_in, W2a), which has dynamic axes [#, *] (the minibatch and sequence axes).

    The simplest fix is to declare action_in = cntk.input(action_dim, dtype=np.float32) which makes the rest of the operations typecheck.
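
To check that a port reproduces the same structure, it can help to have a framework-independent reference for the critic's forward pass. The following is a minimal NumPy sketch of the network from the question, not CNTK code. The helper names (fan_in_uniform, make_critic_params, critic_forward) and the sizes used in the example (state_dim=3, action_dim=1, layer1_size=400) are assumptions chosen for illustration; layer2_size=300 matches the [300] shape in the error message.

```python
import math
import numpy as np

def fan_in_uniform(rng, shape, f):
    """Uniform init in [-1/sqrt(f), 1/sqrt(f)], mirroring self.variable in the question."""
    bound = 1.0 / math.sqrt(f)
    return rng.uniform(-bound, bound, size=shape).astype(np.float32)

def make_critic_params(rng, state_dim, action_dim, layer1_size, layer2_size):
    """Create the same seven parameters as the TensorFlow snippet (hypothetical helper)."""
    f2 = layer1_size + action_dim  # fan-in used for the second layer in the question
    return {
        "W1": fan_in_uniform(rng, (state_dim, layer1_size), state_dim),
        "b1": fan_in_uniform(rng, (layer1_size,), state_dim),
        "W2": fan_in_uniform(rng, (layer1_size, layer2_size), f2),
        "W2_action": fan_in_uniform(rng, (action_dim, layer2_size), f2),
        "b2": fan_in_uniform(rng, (layer2_size,), f2),
        "W3": rng.uniform(-3e-3, 3e-3, size=(layer2_size, 1)).astype(np.float32),
        "b3": rng.uniform(-3e-3, 3e-3, size=(1,)).astype(np.float32),
    }

def critic_forward(params, state, action):
    """Forward pass: state -> layer1; (layer1, action) -> layer2 -> Q value."""
    layer1 = np.maximum(state @ params["W1"] + params["b1"], 0.0)  # ReLU
    layer2 = np.maximum(
        layer1 @ params["W2"] + action @ params["W2_action"] + params["b2"], 0.0
    )
    return layer2 @ params["W3"] + params["b3"]

# Example with illustrative sizes: a minibatch of 5 states and 5 actions.
rng = np.random.default_rng(0)
params = make_critic_params(rng, state_dim=3, action_dim=1,
                            layer1_size=400, layer2_size=300)
states = rng.standard_normal((5, 3)).astype(np.float32)
actions = rng.standard_normal((5, 1)).astype(np.float32)
q = critic_forward(params, states, actions)
print(q.shape)  # (5, 1): one Q value per minibatch entry
```

Both inputs here carry only a leading minibatch axis, which corresponds to the [#] dynamic axis in the error message. That is the situation the fix above restores for action_in.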