Tags: python, tensorflow, tensorflow2.0, tensorflow-serving, sentence-similarity

Tensorflow session error in universal sentence encoder


I have the following code for the Universal Sentence Encoder, and it raises the error below once I load the model into a Flask API and send a request to it:

```
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
import tensorflow_hub as hub
import numpy as np

module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
model_2 = hub.load(module_url)
print("module %s loaded" % module_url)

def embed(input):
    return model_2(input)


def universalModel(messages):
    accuracy = []
    similarity_input_placeholder = tf.placeholder(tf.string, shape=(None))
    similarity_message_encodings = embed(similarity_input_placeholder)
    with tf.Session() as session:
        session.run(tf.global_variables_initializer())
        session.run(tf.tables_initializer())
        message_embeddings_ = session.run(similarity_message_encodings, feed_dict={similarity_input_placeholder: messages})

        corr = np.inner(message_embeddings_, message_embeddings_)
        accuracy.append(corr[0, 1])
    # print(corr[0, 1])
    return "%.2f" % accuracy[0]
```

While serving the model from the Flask API, it gives the following error:

    tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph is invalid, contains a cycle with 1 nodes, including: StatefulPartitionedCall

The same code runs without any error in a Colab notebook.

I am using tensorflow version 2.2.0.


Solution

  • import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()
    

    These two lines switch TensorFlow 2.x into TensorFlow 1.x compatibility mode.

    For TensorFlow 1.x, this is a common issue when serving with Flask, Django, etc. You have to define a graph and a session once and reuse them for inference:

    import tensorflow as tf
    import tensorflow_hub as hub

    # Create graph and finalize (finalizing optional but recommended).
    g = tf.Graph()
    with g.as_default():
      # We will be feeding 1D tensors of text into the graph.
      text_input = tf.placeholder(dtype=tf.string, shape=[None])
      embed = hub.Module("https://tfhub.dev/google/universal-sentence-encoder/2")
      embedded_text = embed(text_input)
      init_op = tf.group([tf.global_variables_initializer(), tf.tables_initializer()])
    g.finalize()
    
    # Create session and initialize.
    session = tf.Session(graph=g)
    session.run(init_op)
    

    The input request can be handled through

    result = session.run(embedded_text, feed_dict={text_input: ["Hello world"]})
    

    For details, see https://www.tensorflow.org/hub/common_issues
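
    As a rough sketch of how this "build the graph once, reuse the session per request" pattern fits into a Flask app (`fake_embed` below is a hypothetical stand-in for the `session.run(...)` call above, used only so the example stays self-contained):

```python
# Sketch: graph/session built once at import time; request handlers only
# reuse them. fake_embed is a placeholder for the real TF session call.
from flask import Flask, request, jsonify

app = Flask(__name__)

def fake_embed(sentences):
    # Stand-in for: session.run(embedded_text, feed_dict={text_input: sentences})
    # Returns one dummy 1-D "embedding" per sentence.
    return [[float(len(s))] for s in sentences]

@app.route("/embed", methods=["POST"])
def embed_route():
    sentences = request.get_json()["sentences"]
    vectors = fake_embed(sentences)  # reuse, never rebuild, the session here
    return jsonify({"embeddings": vectors})
```

    With a real model, only `fake_embed` changes; the key point is that the graph and session are created once per process, not inside the request handler.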

    For TensorFlow 2.x, a session and graph are not required:

    import tensorflow as tf
    import tensorflow_hub as hub
    import numpy as np

    module_url = "https://tfhub.dev/google/universal-sentence-encoder-large/5"
    model_2 = hub.load(module_url)
    print("module %s loaded" % module_url)
    
    def embed(input):
        return model_2(input)
    # pass messages as a list
    def universalModel(messages):
        accuracy = []
        message_embeddings_= embed(messages)
        corr = np.inner(message_embeddings_, message_embeddings_)
        accuracy.append(corr[0,1])
        # print(corr[0,1])
        return "%.2f" % accuracy[0]
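
    Since USE embeddings are approximately unit-length, `np.inner` over the embedding matrix yields the pairwise cosine-similarity matrix, so `corr[0, 1]` is the similarity between the first two messages. A toy illustration with hand-made unit vectors (not real USE output):

```python
import numpy as np

# Two toy unit-length "embeddings" standing in for USE output.
embeddings = np.array([[1.0, 0.0],
                       [0.6, 0.8]])
corr = np.inner(embeddings, embeddings)  # pairwise dot products, shape (2, 2)
similarity = corr[0, 1]                  # message 0 vs message 1
print("%.2f" % similarity)               # prints 0.60
```

    Note that `corr[0, 1]` only exists if at least two messages are passed, which is why the function expects `messages` as a list.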