I would like to keep track of the gradients in TensorBoard. However, since session run statements no longer exist and the write_grads argument of tf.keras.callbacks.TensorBoard is deprecated, I would like to know how to track gradients during training with Keras or TensorFlow 2.0.
My current approach is to create a new callback class for this purpose, but without success. Maybe someone else knows how to accomplish this.
The code I created for testing is shown below, but it runs into errors regardless of whether the gradient values are printed to the console or written to TensorBoard.
import tensorflow as tf
from tensorflow.python.keras import backend as K

mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(128, activation='relu', name='dense128'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax', name='dense10')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

class GradientCallback(tf.keras.callbacks.Callback):
    console = True

    def on_epoch_end(self, epoch, logs=None):
        weights = [w for w in self.model.trainable_weights if 'dense' in w.name and 'bias' in w.name]
        loss = self.model.total_loss
        optimizer = self.model.optimizer
        gradients = optimizer.get_gradients(loss, weights)
        for t in gradients:
            if self.console:
                print('Tensor: {}'.format(t.name))
                print('{}\n'.format(K.get_value(t)[:10]))
            else:
                tf.summary.histogram(t.name, data=t)

file_writer = tf.summary.create_file_writer("./metrics")
file_writer.set_as_default()

# write_grads has been removed
tensorboard_cb = tf.keras.callbacks.TensorBoard(histogram_freq=1, write_grads=True)
gradient_cb = GradientCallback()

model.fit(x_train, y_train, epochs=5, callbacks=[gradient_cb, tensorboard_cb])
The error message is: using a tf.Tensor as a Python bool is not allowed. Use if t is not None: instead of if t: to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

To compute the gradients of the loss with respect to the weights, use
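with tf.GradientTape() as tape:
    # x, y: one batch of inputs and labels; loss_fn: a Keras loss such as
    # tf.keras.losses.SparseCategoricalCrossentropy()
    loss = loss_fn(y, model(x, training=True))
grads = tape.gradient(loss, model.trainable_weights)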
This is (arguably poorly) documented on GradientTape. We do not need to tape.watch the variable because trainable parameters are watched by default.
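To make the "watched by default" point concrete, here is a minimal sketch. The variable w and the constant c are toy values made up for illustration. It shows that a tf.Variable is tracked automatically, while a plain tensor yields no gradient unless tape.watch is called on it.

```python
import tensorflow as tf

# Toy values, chosen only for illustration.
w = tf.Variable(3.0)   # trainable variable: watched automatically
c = tf.constant(2.0)   # plain tensor: not watched unless we ask

# persistent=True lets us call tape.gradient more than once.
with tf.GradientTape(persistent=True) as tape:
    y = w * w + c * c

grad_w = tape.gradient(y, w)   # d/dw (w^2 + c^2) = 2 * w
grad_c = tape.gradient(y, c)   # None, because c was never watched
del tape

with tf.GradientTape() as tape2:
    tape2.watch(c)             # explicitly track the plain tensor
    y2 = w * w + c * c
grad_c_watched = tape2.gradient(y2, c)   # d/dc (w^2 + c^2) = 2 * c

print(grad_w, grad_c, grad_c_watched)
```

So for model.trainable_weights no watch call is needed, whereas differentiating with respect to an input tensor does require one.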
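As a function, the gradient with respect to an input x can be written as
def gradient(model, x):
    x_tensor = tf.convert_to_tensor(x, dtype=tf.float32)
    with tf.GradientTape() as t:
        # x_tensor is a plain tensor, not a trainable variable, so it must be watched explicitly
        t.watch(x_tensor)
        # note: the model output itself is what gets differentiated here
        loss = model(x_tensor)
    return t.gradient(loss, x_tensor).numpy()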
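Coming back to the original question, the same GradientTape pattern can be wrapped in a callback that writes gradient histograms to TensorBoard. The following is only a sketch under some assumptions. The class name GradientHistogramCallback, its constructor arguments, and the last_gradients attribute are hypothetical names introduced here, not part of Keras. It also recomputes the gradients on a fixed sample batch at the end of each epoch, rather than capturing the gradients that the optimizer actually applied during training.

```python
import tensorflow as tf


class GradientHistogramCallback(tf.keras.callbacks.Callback):
    """Hypothetical helper: logs weight-gradient histograms once per epoch."""

    def __init__(self, x_sample, y_sample, loss_fn, log_dir="./metrics"):
        super().__init__()
        # A fixed batch used only for measuring gradients (an assumption of this sketch).
        self.x_sample = tf.convert_to_tensor(x_sample, dtype=tf.float32)
        self.y_sample = tf.convert_to_tensor(y_sample)
        self.loss_fn = loss_fn
        self.writer = tf.summary.create_file_writer(log_dir)
        self.last_gradients = {}  # name -> numpy array, handy for console printing

    def on_epoch_end(self, epoch, logs=None):
        weights = self.model.trainable_weights
        # Trainable weights are watched automatically, so no tape.watch is needed.
        with tf.GradientTape() as tape:
            predictions = self.model(self.x_sample, training=False)
            loss = self.loss_fn(self.y_sample, predictions)
        gradients = tape.gradient(loss, weights)

        self.last_gradients = {}
        with self.writer.as_default():
            for weight, grad in zip(weights, gradients):
                if grad is None:  # weight did not contribute to the loss
                    continue
                # Newer Keras versions expose a unique `path`; older ones only `name`.
                tag = getattr(weight, "path", weight.name).replace(":", "_")
                tf.summary.histogram("gradients/" + tag, grad, step=epoch)
                self.last_gradients[tag] = grad.numpy()
        self.writer.flush()
```

With the MNIST model from the question, this could be used as model.fit(x_train, y_train, epochs=5, callbacks=[GradientHistogramCallback(x_train[:256], y_train[:256], tf.keras.losses.SparseCategoricalCrossentropy())]), and the histograms viewed with tensorboard --logdir ./metrics. The batch size of 256 is an arbitrary choice.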