I need to count the number of elements in a tensor that satisfy a condition, such as counting the number of people with age == 60, or people with age >= 50. Is there a differentiable approximation to the counting function?
Use torch.Tensor.sum or torch.sum on the boolean tensor returned by a comparison such as age >= 50:
>>> import torch
>>> age = torch.arange(75)
>>> (age >= 50).sum()
tensor(25)
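A comparison like age >= 50 returns a boolean tensor, so the resulting count is exact but piecewise constant, and no gradient flows back through it. A common workaround, swapped in here as an assumption rather than taken from the answer, is to replace the hard indicator with a smooth surrogate: a sigmoid for >= and a narrow Gaussian bump for ==. The sketch below is illustrative only. The helper names soft_count_ge and soft_count_eq and their temperature and width parameters are hypothetical. The threshold is set to 49.5 so that the sigmoid's transition falls between the integer ages 49 and 50.

```python
import torch

def soft_count_ge(x: torch.Tensor, threshold: float, temperature: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: sigmoid((x - t) / T) is a smooth stand-in for the
    # step function 1[x >= t]. Smaller T gets closer to the hard count, but
    # gradients concentrate near the threshold and vanish far from it.
    return torch.sigmoid((x - threshold) / temperature).sum()

def soft_count_eq(x: torch.Tensor, value: float, width: float = 1.0) -> torch.Tensor:
    # Hypothetical helper: a Gaussian bump exp(-((x - v) / w)^2) approximates
    # the indicator 1[x == v]. It equals 1 at x == v and decays with distance.
    return torch.exp(-((x - value) / width) ** 2).sum()

# Float ages so autograd can track them.
age = torch.arange(75, dtype=torch.float32, requires_grad=True)

# Exact counts: the comparison yields a bool tensor, so no gradient flows through these.
hard = (age >= 50).sum()       # tensor(25)
hard_eq = (age == 60).sum()    # tensor(1)

# Differentiable approximations.
soft = soft_count_ge(age, 49.5, temperature=0.1)      # ~25.0
eq_count = soft_count_eq(age, 60.0, width=0.3)        # ~1.0

soft.backward()  # populates age.grad (nonzero mainly for ages near the threshold)
```

The temperature and width trade accuracy against gradient signal. Small values reproduce the hard count closely but give near-zero gradients for elements far from the boundary. Larger values spread gradients across more elements at the cost of a biased count.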