I'm trying to train and run an image classification model on Colab using a TPU (no PyTorch).
I know the TPU only works with files from GCS buckets, so I load the dataset from a bucket, and I've commented out the checkpoint and logging functions to avoid that class of errors. For now I just want to see the model train on the TPU without errors.
The code works on CPU and GPU, but the problem appears when I wrap the model creation in with strategy.scope():
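My strategy setup follows the usual Colab TPU pattern, roughly like this (a sketch, with a fallback to the default single-device strategy so it also runs off-TPU; the model here is just a placeholder):

```python
import tensorflow as tf

try:
    # On Colab with a TPU runtime selected, this resolves the TPU address.
    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.TPUStrategy(resolver)
except Exception:
    # No TPU available: fall back to the default strategy (CPU/GPU).
    strategy = tf.distribute.get_strategy()

with strategy.scope():
    # The model (and optimizer) must be created inside the strategy scope.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
```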
This is the function that causes the problem during training:
def train_step(self, images, labels):
    with tf.GradientTape() as tape:
        predictionProbs = self(images, training=True)
        loss = self.loss_fn(labels, predictionProbs)
    grads = tape.gradient(loss, self.trainable_weights)
    predictionLabels = tf.squeeze(tf.cast(predictionProbs > PROB_THRESHOLD_POSITIVE, tf.float32), axis=1)
    acc = tf.reduce_mean(tf.cast(predictionLabels == labels, tf.float32))
    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))  # here is the problem
    return loss, acc
And this is the error I encounter:
RuntimeError: apply_gradients() cannot be called in cross-replica context. Use `tf.distribute.Strategy.run` to enter replica context.
I've looked at https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy and I think the solution is there, but I've never done this before and I don't know where to start.
Can somebody please give me some advice on this problem?
You have to call this procedure through strategy.run(), which enters replica context before executing it:

strategy.run(train_step, args=(images, labels))
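To make this concrete, here is a minimal runnable sketch (using standalone names instead of your self.* attributes, the default strategy for illustration, and an assumed threshold of 0.5; on Colab you would use your TPUStrategy instead). strategy.run executes train_step on each replica, and strategy.reduce averages the per-replica results back to one value:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in; on Colab use your TPUStrategy

PROB_THRESHOLD_POSITIVE = 0.5  # illustrative value

with strategy.scope():
    # Model, optimizer and loss must be created inside the strategy scope.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ])
    optimizer = tf.keras.optimizers.Adam()
    loss_fn = tf.keras.losses.BinaryCrossentropy()

def train_step(images, labels):
    # Runs on each replica, in replica context, so apply_gradients is allowed.
    with tf.GradientTape() as tape:
        predictionProbs = model(images, training=True)
        loss = loss_fn(labels, predictionProbs)
    grads = tape.gradient(loss, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    predictionLabels = tf.squeeze(
        tf.cast(predictionProbs > PROB_THRESHOLD_POSITIVE, tf.float32), axis=1)
    acc = tf.reduce_mean(tf.cast(predictionLabels == labels, tf.float32))
    return loss, acc

@tf.function
def distributed_train_step(images, labels):
    # Enter replica context on every replica, then combine the results.
    per_replica_loss, per_replica_acc = strategy.run(
        train_step, args=(images, labels))
    loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
    acc = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_acc, axis=None)
    return loss, acc
```

In your training loop you then call distributed_train_step(images, labels) instead of train_step directly; everything that touches the optimizer stays inside the function passed to strategy.run.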