
Passing tf.data.Dataset elements as the keys of a dictionary


I have a problem using the elements of a tf.data.Dataset as the keys of a Python dictionary. I have reduced it to the following minimal example:

import tensorflow as tf

def example(x, d):
    w = tf.vectorized_map(lambda y: d[y], tf.cast(x, tf.string))
    return w


dataset = tf.data.Dataset.from_tensor_slices([['a', 'd', 's'], ['b', 'e', 'a'], ['c', 'f', 'd']])
d = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 6, 'f': 5, 's': 1}
dataset.map(lambda x: example(x, d))  # raises the TypeError shown below

I get the error:

TypeError: Failed to convert object of type <class 'tensorflow.python.util.object_identity.Reference'> to Tensor. Contents: <Reference wrapping <tf.Tensor 'args_0:0' shape=(3,) dtype=string>>. Consider casting elements to a supported type.

I have tried removing the tf.cast(x, tf.string) call and replacing tf.vectorized_map with tf.map_fn. In both cases I get the same error.

How can I make this code run?


Solution

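  • A Python dict cannot be indexed with a tensor inside dataset.map: the mapped function is traced into a graph, so x is a symbolic tensor rather than a Python string. You can use tf.lookup.StaticHashTable instead, which performs the lookup as a graph operation.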

    import tensorflow as tf
    keys_tensor = tf.constant(['a', 'b', 'c', 'd', 'e', 'f', 's'])
    vals_tensor = tf.constant([1, 2, 3, 4, 6, 5, 1])
    table = tf.lookup.StaticHashTable(
        tf.lookup.KeyValueTensorInitializer(keys_tensor, vals_tensor),
        default_value=-1)
    
    dataset = tf.data.Dataset.from_tensor_slices([['a','d','s'],['b','e','a'],['c','f','d']])
    # table[x] is shorthand for table.lookup(x); keys missing from the table map to default_value
    ds = dataset.map(lambda x: table[x])
    
    for x in ds:
      print(x)
    '''
    tf.Tensor([1 4 1], shape=(3,), dtype=int32)
    tf.Tensor([2 6 1], shape=(3,), dtype=int32)
    tf.Tensor([3 5 4], shape=(3,), dtype=int32)
    '''
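
If the mapping already exists as a Python dict, as in the question, the key and value tensors do not have to be typed out by hand. The following is a minimal sketch under that assumption; table_from_dict is a hypothetical helper name introduced here for illustration, not part of TensorFlow, and it assumes every key is a string and every value is an integer.

```python
import tensorflow as tf

# The dictionary from the question.
d = {'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 6, 'f': 5, 's': 1}


def table_from_dict(mapping, default_value=-1):
    """Hypothetical helper: build a StaticHashTable from a plain dict."""
    # keys() and values() iterate in the same order, so the pairs stay aligned.
    keys = tf.constant(list(mapping.keys()))
    vals = tf.constant(list(mapping.values()))
    init = tf.lookup.KeyValueTensorInitializer(keys, vals)
    return tf.lookup.StaticHashTable(init, default_value=default_value)


table = table_from_dict(d)

dataset = tf.data.Dataset.from_tensor_slices(
    [['a', 'd', 's'], ['b', 'e', 'a'], ['c', 'f', 'd']])
ds = dataset.map(lambda x: table.lookup(x))

# Collect the mapped rows as plain Python lists for inspection.
result = [row.numpy().tolist() for row in ds]

# A key that is not in the table falls back to default_value.
unknown = table.lookup(tf.constant(['z'])).numpy().tolist()
```

The table is built once outside dataset.map and only looked up inside it, so the Python dict is never touched while the function is traced.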