Search code examples
python · tensorflow · deep-learning · neural-network · jax

Adding a new layer to a stax.serial object


I'd like to convert the following TensorFlow code to JAX:

def mlp(L, n_list, activation, Cb, Cw):
    """Build a Keras MLP whose weights are drawn from width-scaled Gaussians.

    The model is a stack of ``L + 1`` ``Dense`` layers: layer ``k`` maps
    ``n_list[k]`` units to ``n_list[k + 1]`` units. The first and the last
    layer are linear (no activation); the ``L - 1`` layers in between use
    ``activation``.

    Args:
        L: Index of the last layer; ``n_list`` is read at indices ``0..L + 1``.
        n_list: Layer widths, ``n_list[0]`` being the input dimension.
        activation: Activation passed to the intermediate ``Dense`` layers.
        Cb: Variance of the zero-mean Gaussian used for every bias.
        Cw: Kernel variance numerator; the kernel of layer ``k`` is drawn
            with standard deviation ``sqrt(Cw / n_list[k])``.

    Returns:
        The assembled ``tf.keras.Sequential`` model. Its summary is printed
        as a side effect.
    """

    def kernel_init(fan_in):
        # Variance is scaled by the number of inputs feeding the layer.
        return tf.keras.initializers.RandomNormal(
            mean=0, stddev=math.sqrt(Cw / fan_in))

    def bias_init():
        # A fresh instance per layer: Keras documents that reusing one
        # initializer instance can return identical values on every call.
        return tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))

    model = tf.keras.Sequential()

    # Input layer: linear, also fixes the input shape of the model.
    model.add(tf.keras.layers.Dense(
        n_list[1],
        input_shape=[n_list[0]],
        use_bias=True,
        kernel_initializer=kernel_init(n_list[0]),
        bias_initializer=bias_init(),
    ))

    # Intermediate layers: the only ones that apply the activation.
    for layer in range(1, L):
        model.add(tf.keras.layers.Dense(
            n_list[layer + 1],
            activation=activation,
            use_bias=True,
            kernel_initializer=kernel_init(n_list[layer]),
            bias_initializer=bias_init(),
        ))

    # Output layer: linear.
    model.add(tf.keras.layers.Dense(
        n_list[L + 1],
        use_bias=True,
        kernel_initializer=kernel_init(n_list[L]),
        bias_initializer=bias_init(),
    ))

    # summary() prints the table itself and returns None, so wrapping it in
    # print() would emit a stray "None" line.
    model.summary()
    return model

In JAX, can I append a stax.Dense layer to what stax.serial() returns, using something equivalent to TensorFlow's model.add()? If so, how?


Solution

  • Yes, you can.

    # Create a new model with JAX's stax API: serial() composes the layers
    # into a single (init_fun, apply_fun) pair.
    net_init, net_apply = stax.serial(
        Conv(32, (3, 3), padding='SAME'),
        Relu,
        Conv(64, (3, 3), padding='SAME'),
        Relu,
        Conv(128, (3, 3), padding='SAME'),
        Relu,
        Conv(256, (3, 3), padding='SAME'),
        Relu,
        MaxPool((2, 2)),
        Flatten,
        Dense(128),
        Relu,
        Dense(10),
        LogSoftmax,
    )

    # The init function returns (output_shape, params); the parameters must
    # be kept because the apply function takes them explicitly.
    out_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

    # Feed-forward pass
    inputs, targets = batch_data
    net_apply(params, inputs)
    

    You can use this example as a reference.