python tensorflow neural-network text-classification tflearn

How to import (restore) a neural network model built with tflearn from saved files


I am following this tutorial on text classification and have built a custom training set for it.

I save the model with the code below.

# Define model and setup tensorboard
model = tflearn.DNN(net, tensorboard_dir='tflearn_logs')
# Start training (apply gradient descent algorithm)
model.fit(train_x, train_y, n_epoch=1000, batch_size=8, show_metric=True)
model.save('model.tflearn')

This generates the following files:

model.tflearn.data-00000-of-00001
model.tflearn.index
model.tflearn.meta
tflearn_logs folder

I want to reuse the trained model in a later run for testing, without training it again.

I tried:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('model.tflearn.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))

but I get:

KeyError: "The name 'adam' refers to an Operation not in the graph."

I know from the documentation that tflearn.DNN(network).load('file_name') loads a model, but it requires creating and passing a network instance. To build that network I would have to run the same code from scratch, including training, which takes time and which I want to avoid.

Code for building the network:

net = tflearn.input_data(shape=[None, len(train_x[0])])
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, 8)
net = tflearn.fully_connected(net, len(train_y[0]), activation='softmax')
net = tflearn.regression(net)

tflearn.input_data takes shape as a mandatory argument, so the training data would have to be loaded again just to rebuild the model. I checked the documentation but could not find what I need: 2-3 lines of code that import an already built neural network model, so that I can skip retraining.

Please let me know if you know a solution for this.

Similar question, but this is not a duplicate:

  • The OP there was facing an issue while building the neural net (while building the tree), whereas I am facing an issue importing an already built model.
  • The tutorial mentioned in that answer does not cover importing a tflearn NN model.

Solution

  • I was able to restore the saved model with the code below.

    tflearn can restore a model from the saved log and model files.

    Create a dummy neural net with the same layer sizes as the saved model.

    Note: You need to keep track of the previously saved model's dimensions (the size of one training input and the number of classes).

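    import tflearn

    # The shapes must be integers: the input width and the number of classes.
    # restore=False is omitted here: per the tflearn docs it marks a layer's
    # weights as NOT to be restored when loading a model.
    net = tflearn.input_data(shape=[None, len(train_x[0])])
    net = tflearn.fully_connected(net, 8)
    net = tflearn.fully_connected(net, 8)
    net = tflearn.fully_connected(net, len(train_y[0]), activation='softmax')
    dnn = tflearn.DNN(net, tensorboard_dir='tflearn_logs')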
    

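    Load the saved model into the DNN (load() restores the weights in place and returns nothing, so keep using dnn):

    dnn.load('./model.tflearn')

    Use the loaded model for predictions:

    test_data = ...  # placeholder: converted data, in the same format as train_x
    dnn.predict(test_data)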
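
One way to avoid loading the training data just to get the two sizes is to write them to a small file when the model is saved, and read them back before rebuilding the net. The following is only a sketch under assumptions: the helper names (save_model_dims, load_model_dims, build_net) and the file name model.tflearn.dims.json are made up for illustration and are not part of tflearn. build_net repeats the architecture from the question, so it has to be changed whenever the trained network changes.

```python
import json


def save_model_dims(path, train_x, train_y):
    """Write the input width and class count next to the saved model."""
    dims = {"n_features": len(train_x[0]), "n_classes": len(train_y[0])}
    with open(path, "w") as f:
        json.dump(dims, f)


def load_model_dims(path):
    """Read back (n_features, n_classes) without touching the training data."""
    with open(path) as f:
        dims = json.load(f)
    return dims["n_features"], dims["n_classes"]


def build_net(n_features, n_classes):
    """Rebuild the architecture from the question; no training happens here."""
    # Imported lazily so the two helpers above work without tflearn installed.
    import tflearn

    net = tflearn.input_data(shape=[None, n_features])
    net = tflearn.fully_connected(net, 8)
    net = tflearn.fully_connected(net, 8)
    net = tflearn.fully_connected(net, n_classes, activation="softmax")
    net = tflearn.regression(net)
    return tflearn.DNN(net)


# Training script, right after model.save('model.tflearn'):
#     save_model_dims('model.tflearn.dims.json', train_x, train_y)
#
# Prediction script (hypothetical file name, same as above):
#     dnn = build_net(*load_model_dims('model.tflearn.dims.json'))
#     dnn.load('./model.tflearn')
```

Building the graph this way is cheap. The slow part in the question is model.fit, and that is never called in the prediction script.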