python, pytorch

Difference between all() and .all() for checking if an iterable is True everywhere


I believe there are two ways of checking whether a torch.Tensor has all values greater than or equal to 0: either with .all() or with all(). A minimal reproducible example will illustrate my idea:

import torch

walls = torch.tensor([-1, 0, 1, 2])

result1 = (walls >= 0.0).all()  # DIFFERENCE WITH BELOW???
result2 = all(walls >= 0.0)  # DIFFERENCE WITH ABOVE???

print(result1)  # Output: tensor(False)
print(result2)  # Output: False

all() is a builtin, so I think I would prefer using that one, but most code I see on the internet uses .all(), so I'm afraid there is unexpected behaviour.

Do they both behave exactly the same?


Solution

  • all is a Python builtin, meaning that it only works through extremely generic interfaces. In this case, all treats the tensor as an opaque iterable: it iterates over the elements of the tensor one by one, constructing a Python object for each and then checking the truthiness of that Python object. That is slow, with several added layers of unnecessary overhead.

    In contrast, Tensor.all knows what a Tensor object is and can operate on it directly. It only needs to scan the tensor's internal storage: no iterator-protocol function calls, no intermediate Python objects.

    Tensor.all will always be far more efficient, in both time and memory, than the builtin all.
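Beyond speed, the two are not always interchangeable: on a multi-dimensional tensor the builtin all iterates over the first dimension, yielding sub-tensors whose truthiness is ambiguous, while Tensor.all reduces over every element. A minimal sketch (the grid tensor here is just an illustrative example):

```python
import torch

walls = torch.tensor([-1, 0, 1, 2])

# On a 1-D tensor both agree: each element's truthiness is unambiguous.
assert (walls >= 0).all().item() == all(walls >= 0)  # both False

# On a 2-D tensor they diverge: Tensor.all reduces over every element,
# but the builtin all iterates over the first dimension, producing row
# tensors, and bool() of a multi-element tensor raises an error.
grid = torch.tensor([[1, 2], [3, 4]])
print(grid.bool().all())  # tensor(True)

try:
    all(grid.bool())
except RuntimeError as e:
    print("builtin all failed:", e)
```

So on anything other than a 1-D tensor, the builtin all either errors out or silently does something different from what you intended.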
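The performance gap is easy to see with a rough micro-benchmark; the tensor size and repeat count below are arbitrary choices for illustration:

```python
import timeit
import torch

t = torch.rand(10_000)  # all values in [0, 1), so neither call short-circuits

# Tensor.all: a single reduction over the tensor's storage, done in C++.
fast = timeit.timeit(lambda: (t >= 0).all(), number=10)

# builtin all: walks the tensor element by element via the iterator
# protocol, boxing each value into a Python-level scalar tensor before
# checking its truthiness.
slow = timeit.timeit(lambda: all(t >= 0), number=10)

print(f"Tensor.all: {fast:.5f}s  builtin all: {slow:.5f}s")
```

Exact numbers depend on hardware, but the builtin should come out orders of magnitude slower.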