tensorflow, tensorflow-datasets

ref() of tensor not equal in dataset. Why?


I am very confused by the following behavior. Take this program:

import tensorflow_datasets as tfds

# %% Train dataset
(ds_train_original, ds_test_original), ds_info = tfds.load(
    "mnist",
    split=["train", "test"],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)

iterator = iter(ds_train_original)
el = iterator.get_next()[0]
el[0].ref() == el[0].ref()   # <- this should be True

In my opinion the last line should evaluate to True, but it evaluates to False, and I cannot understand why.

According to the ref documentation:

Returns a hashable reference object to this Tensor. The primary use case for this API is to put tensors in a set/dictionary.

My understanding is that you should be able to use ref() to check whether two tensors are the same. The problem no longer occurs once I have extracted the ref. For example, this is True:

a_ref = el[0].ref()
a_deref = a_ref.deref()
another_ref = a_deref.ref()
a_ref == another_ref

So the "problem" seems to be confined to calling ref() directly on el[0], the element obtained from the iterator.

Can anybody explain to me what is happening and why el[0].ref() == el[0].ref() is False?


Solution

  • After I posted an issue on GitHub, the conclusion was that the only viable solution is to compare the samples' values, since only weakrefs are created.

    Thus the solution is:

    import tensorflow_datasets as tfds
    
    # %% Train dataset
    (ds_train_original, ds_test_original), ds_info = tfds.load(
        "mnist",
        split=["train", "test"],
        shuffle_files=True,
        as_supervised=True,
        with_info=True,
    )
    
    iterator = iter(ds_train_original)
    el = iterator.get_next()[0]
    # Compare the values element-wise instead of comparing the refs
    (el[0].numpy() == el[0].numpy()).all()
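
A possible way to see what is going on, offered as a sketch rather than an official explanation: ref() appears to compare the wrapped tensors by object identity, and indexing an eager tensor with [0] builds a new tensor object on every call, so two separate el[0] expressions wrap two different objects. The snippet below assumes TensorFlow 2.x in eager mode and uses a small tf.constant as a stand-in for the MNIST sample (the variable names are chosen for illustration only).

```python
import tensorflow as tf

# Stand-in for the dataset element (assumption: any eager tensor behaves the same way).
t = tf.constant([[1, 2], [3, 4]])

# Same Python object on both sides, so the refs compare equal.
same_object = t.ref() == t.ref()          # True

# Each t[0] runs a slice op and returns a new tensor object,
# so the two refs wrap different objects.
fresh_slices = t[0].ref() == t[0].ref()   # False

# Index once, keep the result, and the refs agree again.
row = t[0]
kept_slice = row.ref() == row.ref()       # True

# Comparing values does not depend on object identity.
values_equal = bool(tf.reduce_all(tf.equal(t[0], t[0])))  # True

print(same_object, fresh_slices, kept_slice, values_equal)
```

If identity-based lookups are needed (for example as dictionary keys), index once, keep the resulting tensor in a variable, and take its ref() from that variable. Otherwise compare the values as in the solution above.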