Tags: python, tensorflow, tensorflow-datasets

Add a unique id to every y of a TensorFlow dataset


I am training an autoencoder on MNIST with TensorFlow.

import tensorflow as tf
import tensorflow_datasets as tfds

(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

batch_size = 2014
def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label

I would like x to be the image and y to be a tuple containing the same image together with a unique index value (an int or a float). The reason is that I want to pass that id to my loss function. I would prefer not to iterate manually and build a new Dataset, but if that is the only solution, I'll go with it.

I have tried several approaches, such as using the map method with a global variable:

lab = -1
def add_label(x, _):
    global lab
    lab += 1
    return (x, (x, [lab]))

ds_train_original = ds_train_original.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train_original.cache()
ds_train = ds_train.shuffle(ds_info.splits["train"].num_examples)
# replace labels by image itself and unique id for decoder/encoder
ds_train = ds_train.map(add_label)

However, this returns 0 as the index for every input instead of a unique value.
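Why every id is 0: Dataset.map does not call the Python function once per element. It traces the function into a TensorFlow graph when the pipeline is built, so a Python side effect such as lab += 1 runs only during tracing, and the resulting Python int is frozen into the graph as a constant. The sketch below reproduces this on a tiny stand-in dataset (tf.data.Dataset.range instead of MNIST); the counter name calls is illustrative only.

```python
import tensorflow as tf

calls = [0]  # mutable counter: how many times the Python body actually runs

def add_label(x):
    calls[0] += 1            # Python side effect, executed only while tracing
    return x, calls[0] - 1   # plain Python int, baked into the graph as a constant

# Small stand-in for the MNIST pipeline: five scalar elements.
ds = tf.data.Dataset.range(5).map(add_label)

ids = [int(i) for _, i in ds]
print("Python body ran", calls[0], "time(s); ids:", ids)
```

Every element receives the same id, which matches the behaviour described above. Per-element values have to come from TensorFlow ops inside the pipeline (for example, a dataset of indices), not from Python state.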

I have also tried adding the label manually by enumerating the dataset, but that is extremely slow.

Is there an efficient way to modify a TensorFlow dataset when the function applied to it is not the same for every element (here, it depends on the element's position)?
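A minimal sketch of the pipeline the question asks for, using tf.data.Dataset.enumerate. This is a different technique from the ref() approach in the answer. Assumptions: a small random uint8 array stands in for MNIST so the example runs without a download, add_id is a hypothetical helper name, and the ids are attached before cache() and shuffle() so that each example keeps the same id in every epoch. enumerate yields (index, element) pairs with an int64 index, so no Python-side iteration is needed.

```python
import numpy as np
import tensorflow as tf

# Stand-in for MNIST (assumption): 10 random 28x28x1 uint8 images with digit labels.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(10, 28, 28, 1), dtype=np.uint8)
labels = rng.integers(0, 10, size=(10,), dtype=np.int64)
ds_train_original = tf.data.Dataset.from_tensor_slices((images, labels))

def normalize_img(image, label):
    """Normalizes images: `uint8` -> `float32`."""
    return tf.cast(image, tf.float32) / 255.0, label

def add_id(idx, example):
    """Hypothetical helper: (idx, (image, label)) -> (image, (image, idx))."""
    image, _ = example  # the digit label is not needed for the autoencoder
    return image, (image, idx)

ds_train = (
    ds_train_original
    .map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    .enumerate()  # (idx, (image, label)); idx is an int64 scalar starting at 0
    .cache()
    .shuffle(len(images))
    .map(add_id, num_parallel_calls=tf.data.AUTOTUNE)
)

for x, (y_img, idx) in ds_train.take(3):
    print(int(idx), x.shape, x.dtype)
```

If the loss expects a float or a shape-(1,) id like the [lab] in the question, add_id can return tf.cast(tf.reshape(idx, [1]), tf.float32) instead. Batching afterwards with ds_train.batch(batch_size) yields a batch of ids alongside each batch of images. Zipping with tf.data.Dataset.range(n), where n is the number of examples, gives the same pairing as enumerate.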


Solution

  • In this case, I would use the ref() method of the target tensors. Every tensor object has its own identity, and ref() returns a hashable reference to it that can be stored (for example, in a NumPy array or a set) and turned back into the tensor with deref().

    You can try:

    import tensorflow as tf
    import tensorflow_datasets as tfds
    import numpy as np
    
    (ds_train_original, ds_test_original), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    # save the references to your tensors
    ids = np.array([y.ref() for _, y in ds_train_original])
    
    # you can check that they are all unique
    print(ids.shape, np.unique(ids).shape)
    
    # retrieve the 42nd tensor with deref()
    t = ids[42].deref()
    print(t)
    
    # use np.where to find the index of a tensor reference
    print(np.where(ids == t.ref())[0])
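A caveat worth checking before relying on ref(): the reference identifies a Python tensor object, not a position in the dataset. Each pass over a tf.data pipeline creates new EagerTensor objects, so references collected in one pass do not match the tensors produced by a later pass (for example, the next training epoch). Building the list also requires iterating the whole dataset in Python. The sketch below uses a three-element tf.data.Dataset.range as a stand-in for MNIST to illustrate this.

```python
import tensorflow as tf

ds = tf.data.Dataset.range(3)

# Two passes over the same dataset produce two fresh sets of tensor objects.
first_pass = [t.ref() for t in ds]
second_pass = [t.ref() for t in ds]

# Within one pass, every reference is distinct (the references keep the tensors alive).
print(len(set(first_pass)))

# Across passes, the same element gets a different reference...
print(first_pass[0] == second_pass[0])

# ...even though the underlying values are identical.
print(first_pass[0].deref().numpy() == second_pass[0].deref().numpy())
```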