Tags: python-3.x, tensorflow2.0, assertion

What is the purpose of the `to_list()` function on TensorFlow's TensorShape objects?


I came across a commit that uses to_list() to convert a tensor shape to a list inside an assertion, because the right-hand operand was a list. To understand why the function was needed, I tried to reproduce it, but I cannot see what goes wrong when to_list() is omitted. Here is my code:

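import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
xt = tf.transpose(x)                      # shape (3, 2)

# Compare the TensorShape directly with a list
assert xt.shape == [3, 2], "Assertion failed"  # this does not fail
print("Received the expected shape")

# Convert the shape to a list first, then compare
assert xt.shape.to_list() == [3, 2], "Assertion failed"
print("Received the expected shape")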

I thought the first assertion should fail, because xt.shape is a TensorShape object (isinstance(xt.shape, tf.TensorShape) returns True) and it is displayed like a tuple, (3, 2), not a list.

I created a plain Python list and tuple to understand what is going on:

a_list = [3, 2]
a_tuple = (3, 2)

assert a_list == a_tuple, "Assertion failed"

In this case, the assertion fails.

Since both comparisons use the same Python == operator, why do they behave differently? More importantly, is there a better use case for to_list()? It seems redundant, since I can compare a TensorShape directly with either a list or a tuple.


Solution

  • What happens when you use == on a TensorFlow TensorShape object is controlled by the TensorShape.__eq__ method, which overloads the equality operator so that a shape compares equal to any Python list (or tuple) describing the same dimensions. In other words, although xt.shape is a TensorShape object, its __eq__ method is implemented so that when you write:

    assert xt.shape == [3, 2]
    

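    TensorFlow converts the list [3, 2] to a TensorShape, checks that its dimensions match those of xt.shape, and returns True. This is why the assertion passes even though you might expect a type mismatch. Plain lists and tuples have no such conversion: list.__eq__ returns NotImplemented when the other operand is a tuple (and tuple.__eq__ does the same for a list), so Python falls back to an identity check and [3, 2] == (3, 2) is False.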

    You can find the details of this comparison in the TensorShape documentation, which describes the __eq__ method as follows:

    Returns True if self is equivalent to other.

    It first tries to convert other to TensorShape. TypeError is thrown when the conversion fails. Otherwise, it compares each element in the TensorShape dimensions.

    >>> t_a = tf.TensorShape([1,2])
    >>> a = [1, 2]
    >>> t_b = tf.TensorShape([1,2])
    >>> t_c = tf.TensorShape([1,2,3])
    >>> t_a.__eq__(a)
    True
    >>> t_a.__eq__(t_b)
    True
    >>> t_a.__eq__(t_c)
    False
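Both halves of the question can be reproduced without TensorFlow. The sketch below uses only the standard library: it first shows why [3, 2] == (3, 2) is False, then defines a small hypothetical ToyShape class (an illustration, not TensorFlow's implementation) whose __eq__ follows the two documented steps: convert the other operand, then compare element by element. Unlike the quoted docstring, this sketch returns NotImplemented instead of raising when conversion fails. Its as_list() method is named after TensorFlow's documented TensorShape.as_list(). The last part illustrates a plausible reason to keep an explicit conversion method: code such as json.dumps or an isinstance(..., list) check gains nothing from an overloaded ==.

```python
import json

# --- 1. Built-in types: why [3, 2] == (3, 2) is False ---------------------
a_list = [3, 2]
a_tuple = (3, 2)

# list.__eq__ only compares against lists and tuple.__eq__ only against
# tuples; for the other type each returns the NotImplemented sentinel.
print(a_list.__eq__(a_tuple))    # NotImplemented
print(a_tuple.__eq__(a_list))    # NotImplemented
# With both sides declining, == falls back to identity (a_list is a_tuple).
print(a_list == a_tuple)         # False


# --- 2. A hypothetical shape class that overloads __eq__ ------------------
class ToyShape:
    """Illustrative stand-in for a shape object; NOT TensorFlow's code."""

    def __init__(self, dims):
        # Store the dimensions as an immutable tuple of ints.
        self._dims = tuple(int(d) for d in dims)

    def __eq__(self, other):
        # Step 1: try to convert `other` (list, tuple, ...) to a ToyShape.
        if not isinstance(other, ToyShape):
            try:
                other = ToyShape(other)
            except (TypeError, ValueError):
                # Conversion failed: decline so Python can try other options.
                return NotImplemented
        # Step 2: compare the dimensions element by element.
        return self._dims == other._dims

    def as_list(self):
        # Return an ordinary Python list of the dimensions.
        return list(self._dims)


shape = ToyShape([3, 2])
print(shape == [3, 2])           # True: the list is converted first
print(shape == (3, 2))           # True: so is the tuple
print([3, 2] == shape)           # True: list declines, ToyShape.__eq__ runs
print(shape == [3, 2, 1])        # False: different number of dimensions

# --- 3. Why an explicit list conversion can still be useful ---------------
print(isinstance(shape, list))   # False: it only compares like a list
try:
    json.dumps(shape)            # json only serialises built-in types
except TypeError as err:
    print("json.dumps failed:", err)
print(json.dumps(shape.as_list()))  # [3, 2]
```

Note that [3, 2] == shape also works: list.__eq__ returns NotImplemented for a non-list, so Python retries with the reflected ToyShape.__eq__.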