I came across a commit that calls to_list() to convert a tensor shape to a list in an assertion, because the right-hand operand was a list. I wanted to understand the reason for using this function, so I reproduced it. However, I cannot figure out what goes wrong in the code without to_list(). Here is my code.
import tensorflow as tf

x = tf.constant([[1, 2, 3], [4, 5, 6]])
xt = tf.transpose(x)
assert xt.shape == [3, 2], "Assertion failed"  # this does not fail
print("Received the expected shape")
# Note: tf.TensorShape exposes this conversion as as_list(); it has no to_list() method.
assert xt.shape.as_list() == [3, 2], "Assertion failed"
print("Received the expected shape")
I thought the first assertion should fail, as xt.shape is a TensorShape object (if I check isinstance(xt.shape, TensorShape), it returns True) and xt.shape displays as a tuple.
I created a Python tuple and a list to understand what is going on.
a_list = [3, 2]
a_tuple = (3, 2)
assert a_list == a_tuple, "Assertion failed"
In this case, the assertion fails.
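As a quick sanity check of the plain-Python behaviour, the sketch below uses only built-in types (no TensorFlow). It shows that a list and a tuple never compare equal, even with matching elements, unless one side is converted to the other's type first.

```python
a_list = [3, 2]
a_tuple = (3, 2)

# Built-in list and tuple are different types, so == is False
# even though the elements match.
print(a_list == a_tuple)

# Converting one side to the other's type makes the comparison element-wise.
print(a_list == list(a_tuple))
print(tuple(a_list) == a_tuple)
```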
Since both use Python operators, why do they behave differently? More importantly, is there a better use case for to_list()? It seems the function is redundant, as I can compare the TensorShape with either lists or tuples.
What happens when you use == with a TensorFlow TensorShape object is controlled by TensorShape.__eq__, which overloads the equality operator so that a shape compares equal to a Python list (or tuple) that represents the same dimensions. In other words, although xt.shape is a TensorShape object, its __eq__ method is implemented such that when you do:
assert xt.shape == [3, 2]
TensorFlow internally checks that the shape's dimensions match the elements of the list [3, 2] and returns True. This is why the assertion passes even though you might expect a type mismatch.
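To make the mechanism concrete, here is a minimal pure-Python sketch of a class that overloads == in the same spirit. MiniShape and its methods are hypothetical names invented for illustration; this is not TensorFlow's actual implementation, only an assumption-laden model of "convert the other operand, then compare dimensions".

```python
class MiniShape:
    """Hypothetical stand-in for a shape object; not TensorFlow's code."""

    def __init__(self, dims):
        # Store the dimensions as a plain list of ints.
        self._dims = [int(d) for d in dims]

    def __eq__(self, other):
        # Try to interpret the other operand as a shape first.
        if not isinstance(other, MiniShape):
            try:
                other = MiniShape(other)
            except (TypeError, ValueError):
                # Not convertible: let Python fall back to its default handling.
                return NotImplemented
        return self._dims == other._dims

    def as_list(self):
        # Return a plain Python list, detached from the wrapper type.
        return list(self._dims)


shape = MiniShape([3, 2])
print(shape == [3, 2])            # True: the list is converted, then compared
print(shape == (3, 2))            # True: the tuple is converted the same way
print([3, 2] == shape)            # True: Python retries with MiniShape.__eq__
print(shape.as_list() == (3, 2))  # False: a plain list never equals a tuple
```

Once the shape has been converted to a plain list, the ordinary list rules apply again, which is why the result of the conversion no longer compares equal to a tuple.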
You can check the details of how this comparison is made in the TensorShape documentation. There, it says the __eq__ method:
Returns True if self is equivalent to other.
It first tries to convert other to TensorShape. TypeError is thrown when the conversion fails. Otherwise, it compares each element in the TensorShape dimensions.
>>> t_a = tf.TensorShape([1,2])
>>> a = [1, 2]
>>> t_b = tf.TensorShape([1,2])
>>> t_c = tf.TensorShape([1,2,3])
>>> t_a.__eq__(a)
True
>>> t_a.__eq__(t_b)
True
>>> t_a.__eq__(t_c)
False