tensorflow · tensorflow-estimator

How to initialize embeddings layer within Estimator API


I'm trying to use existing embeddings within a TensorFlow model. The embedding matrix is larger than 2 GB, which makes my original attempt fail:

embedding_var = tf.get_variable(
        "embeddings", 
        shape=GLOVE_MATRIX.shape, 
        initializer=tf.constant_initializer(np.array(GLOVE_MATRIX))
)

Which gave me this error:

 Cannot create a tensor proto whose content is larger than 2GB.

I'm using AWS SageMaker, which is based on the Estimator API, and the actual running of the graph in a session happens behind the scenes, so I'm not sure how to initialize placeholders for the embeddings in this setting. It would be helpful if someone could share how to do this kind of initialization with the Estimator API.


Solution

  • If you specify the initializer argument to tf.get_variable(), the initial value GLOVE_MATRIX will be stored in the graph def and push it over 2 GB. A good answer explains how to load embeddings in general; roughly, that pattern looks like the sketch below.
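
    A minimal sketch of that general pattern, assuming GLOVE_MATRIX is a NumPy array (plain-session version, variable names are illustrative): create the variable without passing the matrix to an initializer, then assign it through a placeholder so the values never end up in the graph def.

    W = tf.get_variable("embeddings", shape=GLOVE_MATRIX.shape)
    embedding_placeholder = tf.placeholder(tf.float32, shape=GLOVE_MATRIX.shape)
    embedding_init = W.assign(embedding_placeholder)

    with tf.Session() as sess:
        # Feed the large matrix only at run time; it is never baked into the graph.
        sess.run(embedding_init, feed_dict={embedding_placeholder: GLOVE_MATRIX})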


    Here is a first example where we use the initializer; the graph def is around 4 MB since it stores the (1000, 1000) matrix in it.

    import numpy as np
    import tensorflow as tf

    size = 1000
    initial_value = np.random.randn(size, size)
    x = tf.get_variable("x", [size, size], initializer=tf.constant_initializer(initial_value))

    sess = tf.Session()
    sess.run(x.initializer)

    assert np.allclose(sess.run(x), initial_value)

    graph = tf.get_default_graph()
    print(graph.as_graph_def().ByteSize())  # should be 4000394
    

    Here is a better version where we don't store it:

    tf.reset_default_graph()  # start from a fresh graph if the first example was run

    size = 1000
    initial_value = np.random.randn(size, size)
    x = tf.get_variable("x", [size, size])

    sess = tf.Session()
    sess.run(x.initializer, {x.initial_value: initial_value})

    assert np.allclose(sess.run(x), initial_value)

    graph = tf.get_default_graph()
    print(graph.as_graph_def().ByteSize())  # should be 1203
    

    In Estimators

    For Estimators, we don't have direct access to the Session. One way to initialize the embedding is to use tf.train.Scaffold: you can pass it an init_fn argument in which you initialize the embedding variable, without storing the actual value in the graph def.

    def model_fn(features, labels, mode):
        size = 10
        initial_value = np.random.randn(size, size).astype(np.float32)
        x = tf.get_variable("x", [size, size])
    
        def init_fn(scaffold, sess):
            # Feed the value at initialization time only; it never enters the graph def.
            sess.run(x.initializer, {x.initial_value: initial_value})
        scaffold = tf.train.Scaffold(init_fn=init_fn)
    
        loss = ...
        train_op = ...
    
        return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, scaffold=scaffold)
    

    A nice property of the built-in Scaffold is that it only initializes the embedding the first time train_input_fn is called; on subsequent calls it will not run init_fn again.
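
    To make the last point concrete, a usage sketch (the model_dir, steps, and input functions below are placeholders of mine, not from the question): init_fn runs when training starts for the first time; later calls restore the variable from the checkpoint instead.

    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="/tmp/glove_model")
    estimator.train(input_fn=train_input_fn, steps=1000)  # init_fn runs once here
    estimator.evaluate(input_fn=eval_input_fn)  # restored from the checkpoint, init_fn not run again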