PyTorch: Differentiable counting


I need to count the elements of a tensor that satisfy a condition, for example the number of people with age == 60 or with age >= 50. Is there a differentiable approximation to this counting function?


Solution

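  • Use torch.Tensor.sum or torch.sum on the boolean tensor returned by a comparison such as age >= 50 (or age == 60). This gives the exact count, but the result carries no gradient: comparison operators return a boolean tensor that autograd does not track, so nothing flows back to age.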

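    import torch

    age = torch.arange(75)          # ages 0, 1, ..., 74
    print((age >= 50).sum())        # ages 50..74 -> tensor(25)
    print((age == 60).sum())        # exactly one match -> tensor(1)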
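If the count has to be differentiable (for example, as a term in a loss), one common workaround is to replace the hard 0/1 indicator with a smooth relaxation and sum that instead. PyTorch has no built-in "soft count"; the sketch below assumes a sigmoid relaxation for `>=` and a Gaussian bump for `==`. The helper names `soft_count_ge` and `soft_count_eq` and the `temperature` parameter are hypothetical, chosen only for illustration.

```python
import torch


def soft_count_ge(x: torch.Tensor, threshold: float, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: smooth approximation of (x >= threshold).sum().

    Each element contributes sigmoid((x - threshold) / temperature), a value
    in (0, 1) that approaches the hard 0/1 indicator as temperature -> 0.
    """
    return torch.sigmoid((x - threshold) / temperature).sum()


def soft_count_eq(x: torch.Tensor, target: float, temperature: float = 1.0) -> torch.Tensor:
    """Hypothetical helper: smooth approximation of (x == target).sum().

    Each element contributes exp(-((x - target) / temperature) ** 2), which is
    1 at x == target and decays towards 0 away from it.
    """
    return torch.exp(-((x - target) / temperature) ** 2).sum()


# Ages must be floating point for autograd to track them.
age = torch.arange(75, dtype=torch.float32, requires_grad=True)

hard_ge = (age >= 50).sum()                          # exact, but no gradient
soft_ge = soft_count_ge(age, 49.5, temperature=0.1)  # close to 25, differentiable
hard_eq = (age == 60).sum()                          # exact, but no gradient
soft_eq = soft_count_eq(age, 60.0, temperature=0.3)  # close to 1, differentiable

print(hard_ge.item(), soft_ge.item())
print(hard_eq.item(), soft_eq.item())

# Gradients now flow back to `age`; they are largest for elements near the threshold.
soft_ge.backward()
print(age.grad[47:53])
```

The temperature trades accuracy for gradient signal: as it shrinks, the soft count approaches the exact count, but the gradient concentrates on elements within a few temperatures of the threshold and is essentially zero elsewhere. With integer-valued ages, placing the threshold halfway between integers (49.5 rather than 50) keeps the boundary element from contributing 0.5 to the count.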