python tensorflow tensorflow-datasets

Converting a tf.Tensor from one int value to another via lookup


I have the following labels: [15, 76, 34]. I am trying to map them to be [0, 1, 2] inside of a tf.data.Dataset using the map function.

So I need a function that can do the following:

def relabel(label: tf.Tensor) -> tf.Tensor:
    # TODO: convert 15 --> 0, 76 --> 1, 34 --> 2
    return new_label

dataset: tf.data.Dataset
dataset = dataset.map(lambda x, y: (x, relabel(y)))

I am having a tough time working with tf.Tensor. Can anyone complete this implementation?


Solution

  • You can create a lookup table that maps the old labels to the new labels:

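    import tensorflow as tf

    label_tensor = tf.constant([15, 76, 34], tf.int32)
    new_label_tensor = tf.constant([0, 1, 2], tf.int32)
    # Labels missing from the table map to the default value, -1.
    initializer = tf.lookup.KeyValueTensorInitializer(
        label_tensor, new_label_tensor, key_dtype=tf.int32, value_dtype=tf.int32)
    table = tf.lookup.StaticHashTable(initializer, default_value=-1)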
    

    Sample inputs to test with:

    X = tf.constant([0.1, 0.2, 0.3], dtype=tf.float32)
    Y = tf.constant([15, 76, 34], dtype=tf.int32)
    dataset = tf.data.Dataset.from_tensor_slices((X, Y))
    

    Relabeling can then be done with:

    def relabel(x, y):
        return x, table.lookup(y)
    dataset = dataset.map(relabel)
    

    Checking the output:

    for x, y in dataset:
        print(x.numpy(), y.numpy())
    # Output:
    # 0.1 0
    # 0.2 1
    # 0.3 2
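
    One thing worth checking with this approach is what happens to a label that is not in the table. The sketch below is a self-contained variant of the code above. It assumes a hypothetical extra label, 99, that was never registered, and shows one possible way to drop such examples with Dataset.filter. The names old_labels, new_labels, relabeled and kept are illustrative only.

```python
import tensorflow as tf

# Old labels and their replacements, as in the question.
old_labels = tf.constant([15, 76, 34], dtype=tf.int32)
new_labels = tf.constant([0, 1, 2], dtype=tf.int32)

# Any key that is not in old_labels is looked up as the default value, -1.
table = tf.lookup.StaticHashTable(
    tf.lookup.KeyValueTensorInitializer(old_labels, new_labels),
    default_value=-1,
)

# 99 is a hypothetical label that is absent from the table.
xs = tf.constant([0.1, 0.2, 0.3], dtype=tf.float32)
ys = tf.constant([76, 99, 15], dtype=tf.int32)

# Looking up a whole batch at once: the unknown label becomes -1.
relabeled = table.lookup(ys)

# Inside a dataset: relabel, then drop examples whose label was unknown.
dataset = tf.data.Dataset.from_tensor_slices((xs, ys))
dataset = dataset.map(lambda x, y: (x, table.lookup(y)))
dataset = dataset.filter(lambda x, y: y >= 0)

# Collect the labels that survived the filter.
kept = [int(y.numpy()) for _, y in dataset]
```

    Whether to filter unknown labels, raise an error, or map them to a dedicated "other" class depends on the task. Filtering is only one option.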