Tags: python, tensorflow, tensorflow2.0, distributed-computing, tpu

Train model on Colab TPU with distributed strategy


I'm trying to train and run an image classification model on Colab, using a TPU (no PyTorch).

I know that the TPU only works with files from GCS buckets, so I load the dataset from a bucket, and I've also commented out the checkpoint and logging functions so I don't get those kinds of errors. I just want to see it train without errors on the TPU.

The code works on CPU and GPU, but the problem appears when I use with strategy.scope(): before creating the model. This is the function that gives me problems while training the model:

def train_step(self, images, labels):
    with tf.GradientTape() as tape:
        predictionProbs = self(images, training=True)
        loss = self.loss_fn(labels, predictionProbs)

    grads = tape.gradient(loss, self.trainable_weights)

    # Threshold the predicted probabilities to get hard labels and compute batch accuracy
    predictionLabels = tf.squeeze(tf.cast(predictionProbs > PROB_THRESHOLD_POSITIVE, tf.float32), axis=1)
    acc = tf.reduce_mean(tf.cast(predictionLabels == labels, tf.float32))

    self.optimizer.apply_gradients(zip(grads, self.trainable_weights))  # here is the problem

    return loss, acc

And this is the error I encounter:

RuntimeError: `apply_gradients() cannot be called in cross-replica context. Use `tf.distribute.Strategy.run` to enter replica context.

I've looked at https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy and I think the solution is there, but I've never done this before and I don't know where to start.

Can somebody please give me some advice on this problem?


Solution

  • You have to call this procedure through strategy.run(), so that it executes in replica context rather than cross-replica context:

    strategy.run(train_step, args=(images, labels))
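
  In a custom training loop, that call usually sits inside an outer tf.function that iterates over a dataset distributed with the strategy. Here is a minimal sketch of the pattern, assuming strategy is your TPUStrategy, model is the model you built inside strategy.scope(), and train_dataset and EPOCHS are your own dataset and epoch count (those names are placeholders, not taken from your code):

    import tensorflow as tf

    @tf.function
    def distributed_train_step(images, labels):
        # strategy.run executes train_step once per replica (per TPU core),
        # so apply_gradients() runs in replica context as required.
        per_replica_loss, per_replica_acc = strategy.run(
            model.train_step, args=(images, labels))
        # Combine the per-replica results into single scalars for logging.
        loss = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_loss, axis=None)
        acc = strategy.reduce(tf.distribute.ReduceOp.MEAN, per_replica_acc, axis=None)
        return loss, acc

    # Let the strategy split each batch across the TPU cores.
    dist_dataset = strategy.experimental_distribute_dataset(train_dataset)

    for epoch in range(EPOCHS):
        for images, labels in dist_dataset:
            loss, acc = distributed_train_step(images, labels)

  strategy.reduce is only there to turn the per-replica loss and accuracy into single scalars you can print; if you drop it, strategy.run returns PerReplica objects with one value per core.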