I'd like to "convert" the following TensorFlow code to JAX:
import math
import tensorflow as tf

def mlp(L, n_list, activation, Cb, Cw):
    model = tf.keras.Sequential()
    kernel_initializers_list = []
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw / n_list[0])))
    for l in range(1, L):
        kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw / n_list[l])))
    kernel_initializers_list.append(tf.keras.initializers.RandomNormal(0, math.sqrt(Cw / n_list[L])))
    bias_initializer = tf.keras.initializers.RandomNormal(stddev=math.sqrt(Cb))
    model.add(tf.keras.layers.Dense(n_list[1], input_shape=[n_list[0]], use_bias=True,
                                    kernel_initializer=kernel_initializers_list[0],
                                    bias_initializer=bias_initializer))
    for l in range(1, L):
        model.add(tf.keras.layers.Dense(n_list[l + 1], activation=activation, use_bias=True,
                                        kernel_initializer=kernel_initializers_list[l],
                                        bias_initializer=bias_initializer))
    model.add(tf.keras.layers.Dense(n_list[L + 1], use_bias=True,
                                    kernel_initializer=kernel_initializers_list[L],
                                    bias_initializer=bias_initializer))
    model.summary()
    return model
In JAX, can I add a stax.Dense layer to the object I get from calling stax.serial(), with something equivalent to TensorFlow's model.add()? If so, how can I do it?
Yes, you can. In stax every layer, including the result of stax.serial() itself, is just an (init_fun, apply_fun) pair, so you can pass an existing serial model into another stax.serial() call together with a new stax.Dense.
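A minimal sketch of that (the layer sizes here are made up, and I'm importing from jax.example_libraries.stax; older JAX versions expose the same module as jax.experimental.stax):

from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu

# An existing "model": stax.serial returns an (init_fun, apply_fun) pair
base = stax.serial(Dense(64), Relu)

# "Adding" a layer: wrap the existing pair and the new Dense in another serial
net_init, net_apply = stax.serial(base, Dense(10))

# Initialize and use it like any other stax model
output_shape, params = net_init(random.PRNGKey(0), input_shape=(-1, 32))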
For a more complete reference, here is a full stax model, its initialization, and a forward pass:

# Create a new model with jax's stax
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import (Conv, Dense, Flatten, LogSoftmax,
                                         MaxPool, Relu)

net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'),
    Relu,
    Conv(64, (3, 3), padding='SAME'),
    Relu,
    Conv(128, (3, 3), padding='SAME'),
    Relu,
    Conv(256, (3, 3), padding='SAME'),
    Relu,
    MaxPool((2, 2)),
    Flatten,
    Dense(128),
    Relu,
    Dense(10),
    LogSoftmax,
)

# Initialize the parameters for batches of 32x32x3 inputs
output_shape, params = net_init(random.PRNGKey(111), input_shape=(-1, 32, 32, 3))

# Feedforward (batch_data is assumed to be an (inputs, targets) pair from your data pipeline)
inputs, targets = batch_data
outputs = net_apply(params, inputs)
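If you also want to reproduce the custom RandomNormal initializers from your Keras mlp, a rough sketch could look like this (assuming the same L, n_list, Cw, Cb arguments and a ReLU activation; the helper name make_mlp is mine, and I'm using jax.nn.initializers.normal for the stddev-scaled initializers; for an arbitrary activation you could use stax.elementwise instead of Relu):

import math
from jax import random
from jax.example_libraries import stax
from jax.example_libraries.stax import Dense, Relu
from jax.nn.initializers import normal

def make_mlp(L, n_list, Cb, Cw):
    b_init = normal(stddev=math.sqrt(Cb))
    layers = []
    # First layer: no activation, mirroring the Keras version
    layers.append(Dense(n_list[1],
                        W_init=normal(stddev=math.sqrt(Cw / n_list[0])),
                        b_init=b_init))
    # Hidden layers: Dense followed by the activation
    for l in range(1, L):
        layers.append(Dense(n_list[l + 1],
                            W_init=normal(stddev=math.sqrt(Cw / n_list[l])),
                            b_init=b_init))
        layers.append(Relu)
    # Output layer, again without activation
    layers.append(Dense(n_list[L + 1],
                        W_init=normal(stddev=math.sqrt(Cw / n_list[L])),
                        b_init=b_init))
    return stax.serial(*layers)

net_init, net_apply = make_mlp(L=3, n_list=[4, 32, 32, 32, 1], Cb=0.1, Cw=2.0)
output_shape, params = net_init(random.PRNGKey(0), input_shape=(-1, 4))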
I hope these examples help as a reference.