tensorflow classification text-classification training-data

Can I retrain an old model with new data using TensorFlow?


I am new to TensorFlow and I am just trying to see if my idea is even possible.

I have trained a multi-class classification model (a CNN). I can now classify an input sentence, but I would like to adjust the CNN's output for a given sentence, for example to improve its classification score or to change the predicted class.

Is it possible to train the already-trained model on just a single sentence together with its class?


Solution

  • If I understand your question correctly, you want to reload a previously trained model so that you can train it for further iterations, test it on a new sentence, or fine-tune it a bit. If so, yes, you can do this. Look into saving and restoring models (https://www.tensorflow.org/api_guides/python/state_ops#Saving_and_Restoring_Variables).

    To give you a rough outline: when you initially train your model, create a saver once the network architecture (and therefore its variables) has been set up:

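    # TF1-style graph/session API; on TensorFlow 1.x, plain `import tensorflow as tf` also works
    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    # ... build your network architecture (placeholders, variables, loss, train_op) here ...

    saver = tf.train.Saver()  # must be created after the variables exist
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    # Run/train your model until some completion criteria is reached
    #....
    #....

    saver.save(sess, 'model.ckpt')  # writes model.ckpt.meta, model.ckpt.index, model.ckpt.data-*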

    Now, to reload your model:

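    sess = tf.Session()  # same TF1-style `tf` import as above
    saver = tf.train.import_meta_graph('model.ckpt.meta')  # rebuild the saved graph
    saver.restore(sess, 'model.ckpt')  # load the trained weights (no initializer run needed)
    # Note: if you have already defined all variables in code before restoring the model,
    # import_meta_graph is not necessary; create saver = tf.train.Saver() and restore with it.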

    This gives you access to all the trained variables, so you can feed in a new sentence for prediction, or keep running your training op on new (sentence, label) pairs to fine-tune the weights. Hope this helps.
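To make the "train on a single new sentence" part concrete, here is a framework-free sketch of the same save, restore, and continue-training workflow. It is not TensorFlow code and not the answer's method: it assumes a hypothetical bag-of-words softmax classifier standing in for the CNN, a JSON file standing in for the checkpoint, and a few plain SGD steps on one (sentence, label) pair standing in for running the training op again. `CLASSES`, the example sentences, and every helper name are illustrative only.

```python
import json
import math
import os
import tempfile

# Hypothetical toy stand-in for the trained CNN: a bag-of-words softmax classifier.
CLASSES = ["greeting", "question"]


def features(sentence):
    """Lower-cased word counts (a stand-in for real tokenization/embedding)."""
    counts = {}
    for word in sentence.lower().split():
        counts[word] = counts.get(word, 0) + 1
    return counts


def probabilities(weights, sentence):
    """Softmax over one linear score per class."""
    x = features(sentence)
    scores = [sum(weights[c].get(w, 0.0) * n for w, n in x.items()) for c in CLASSES]
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]  # subtract max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]


def predict(weights, sentence):
    probs = probabilities(weights, sentence)
    return CLASSES[probs.index(max(probs))]


def sgd_step(weights, sentence, label, lr=0.5):
    """One cross-entropy gradient step on a single (sentence, label) pair."""
    x = features(sentence)
    probs = probabilities(weights, sentence)
    for i, c in enumerate(CLASSES):
        grad = probs[i] - (1.0 if c == label else 0.0)  # dLoss/dscore_c
        for w, n in x.items():
            weights[c][w] = weights[c].get(w, 0.0) - lr * grad * n


def save(weights, path):
    with open(path, "w") as f:
        json.dump(weights, f)


def restore(path):
    with open(path) as f:
        return json.load(f)


def fine_tune_demo(steps=5):
    data = [("hello there", "greeting"), ("how are you", "question"),
            ("hi friend", "greeting"), ("what is this", "question")]
    weights = {c: {} for c in CLASSES}
    for _ in range(20):  # initial training run on the original data
        for sentence, label in data:
            sgd_step(weights, sentence, label)

    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "model.json")
        save(weights, path)       # plays the role of saver.save(...)
        restored = restore(path)  # plays the role of saver.restore(...)
    round_trip_ok = restored == weights

    # Continue training the restored model on one new labelled sentence.
    new_sentence, new_label = "what a nice day", "greeting"
    before = predict(restored, new_sentence)
    for _ in range(steps):
        sgd_step(restored, new_sentence, new_label)
    after = predict(restored, new_sentence)
    return round_trip_ok, before, after


if __name__ == "__main__":
    ok, before, after = fine_tune_demo()
    print("checkpoint round trip ok:", ok)
    print("before fine-tuning:", before)
    print("after fine-tuning: ", after)
```

In this toy run the restored model first labels "what a nice day" as a question (because "what" was only seen in a question) and labels it a greeting after a few steps on that one example. The same caveat applies to a real CNN: repeatedly training on a single sentence can overfit to it and shift predictions on the original data, so a small learning rate or mixing the new example with some original training data is usually safer.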