Tags: python, tensorflow, google-colaboratory, checkpoint, machine-translation

Saving tensorflow encoder, decoder and attention


I started training a simple NMT (neural machine translation) model with attention, using an encoder and a decoder. Training was done on Colab:

encoder = Encoder(vocab_inp_size, embedding_dim, units, BATCH_SIZE)
decoder = Decoder(vocab_tar_size, embedding_dim, units, BATCH_SIZE)

Then I use checkpoints to save the model:

import os
import tensorflow as tf

# On the local machine the dir was changed to 'training_checkpoints/' to fit the location
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
# Track the optimizer, encoder and decoder so all three are saved together
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

And I save during training using:

checkpoint.save(file_prefix = checkpoint_prefix)

After training, restoring the checkpoint works fine on Colab, even after I save the whole checkpoint folder to Google Drive and restore from it again. But when I try to restore it on my local machine, the model returns different, rubbish results. I restore the checkpoint before training using:

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))
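One thing that may be worth ruling out first: tf.train.latest_checkpoint returns None when it finds no checkpoint in the directory, and restore then leaves the model with its freshly initialised weights. The sketch below uses only the standard library to list the checkpoint prefixes present in a folder. The helper name list_checkpoint_prefixes is hypothetical, and it assumes the usual "<prefix>-<number>.index" file naming produced by checkpoint.save.

```python
import os
import re


def list_checkpoint_prefixes(checkpoint_dir):
    """Return checkpoint prefixes found in checkpoint_dir, lowest number first.

    Hypothetical helper: assumes files named like 'ckpt-3.index'.
    """
    if not os.path.isdir(checkpoint_dir):
        return []
    found = []
    for name in os.listdir(checkpoint_dir):
        # Capture the prefix (e.g. 'ckpt-3') and its save counter (e.g. 3)
        match = re.match(r"^(.*-(\d+))\.index$", name)
        if match:
            prefix = os.path.join(checkpoint_dir, match.group(1))
            found.append((int(match.group(2)), prefix))
    # Sort numerically so that ckpt-10 comes after ckpt-2
    return [prefix for _, prefix in sorted(found)]


if __name__ == "__main__":
    prefixes = list_checkpoint_prefixes("./training_checkpoints")
    if not prefixes:
        print("No checkpoint files found; restore would be a no-op.")
    else:
        print("Newest checkpoint prefix:", prefixes[-1])
```

An empty list here would point to a path problem rather than a version problem.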

Colab notebook output:

Input: <start> يلعبون الكرة <end>
Predicted translation: he played soccer . <end> 

Local machine output:

Input: <start> يلعبون الكرة <end>
Predicted translation: take either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either either

Colab tensorflow version: 1.13.0-rc1

Local machine tensorflow version: 1.12.0

How can I save the model without running into this issue, given that it is caused by the different TensorFlow versions?

For reference, the NMT notebook I used: Neural Machine Translation with Attention


Solution

  • TensorFlow only guarantees that checkpoints written by an older version can be read by a newer one, not the other way round: https://www.tensorflow.org/guide/version_compat#compatibility_of_graphs_and_checkpoints So it's not surprising that 1.13 saves a file that 1.12 cannot restore. Could you upgrade TensorFlow on your local machine?
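A rough way to catch this kind of mismatch early is to compare the version that wrote the checkpoint with the version trying to read it. The sketch below is only an illustration: parse_major_minor and reader_is_older are hypothetical helpers, not part of TensorFlow, and they compare only the major and minor numbers of the version strings.

```python
import re


def parse_major_minor(version):
    """Extract (major, minor) from a version string such as '1.13.0-rc1'."""
    match = re.match(r"^(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError("unrecognised version string: %r" % version)
    return int(match.group(1)), int(match.group(2))


def reader_is_older(saved_with, restoring_with):
    """True if the restoring version is older than the one that saved."""
    return parse_major_minor(restoring_with) < parse_major_minor(saved_with)


# Versions taken from the question: saved on Colab, restored locally
print(reader_is_older("1.13.0-rc1", "1.12.0"))  # True
```

In practice one could write tf.__version__ to a small text file next to the checkpoint at save time, then pass that string as saved_with and the local tf.__version__ as restoring_with before calling restore.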