Tags: tensorflow, tensorflow2.0, embedding, tensorflow2.x

How to access embedding layer's variables in tensorflow?


Suppose I have the embedding layer e like this:

import tensorflow as tf
e = tf.keras.layers.Embedding(5,3)

How can I print its weights as NumPy values?


Solution

  • Thanks to @vald for his answer. I think e.embeddings is more Pythonic and possibly more efficient.

    import tensorflow as tf
    e = tf.keras.layers.Embedding(5,3)
    
    e.build(())  # Build the layer first so that its weights are created.
    
    print(e.embeddings)
    

    >>>
    <tf.Variable 'embeddings:0' shape=(5, 3) dtype=float32, numpy=
    array([[ 0.02099125,  0.01865673,  0.03652272],
           [ 0.02714007, -0.00316695, -0.00252246],
           [-0.02411103,  0.02043924, -0.01297874],
           [ 0.00766286, -0.03511617,  0.03460207],
           [ 0.00256425, -0.03659264, -0.01796588]], dtype=float32)>
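
The output above is a tf.Variable rather than a plain NumPy array. If only the NumPy values are wanted, something like the following sketch should work. The names values and weights are just illustrative, and as far as I know the shape passed to build() is not used by Embedding, so an empty tuple is enough.

```python
import numpy as np
import tensorflow as tf

e = tf.keras.layers.Embedding(5, 3)
e.build(())  # create the weights before reading them

# Option 1: convert the embeddings variable to a NumPy array.
values = e.embeddings.numpy()

# Option 2: get_weights() returns a list of NumPy arrays;
# an Embedding layer has a single weight, so take the first entry.
weights = e.get_weights()[0]

print(type(values), values.shape)
print(np.array_equal(values, weights))
```

Both options read the same underlying weights, so which one to use is mostly a matter of taste. The actual numbers differ from run to run because the layer is randomly initialized.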