
Check the total number of parameters in a PyTorch model


How do I count the total number of parameters in a PyTorch model? Something similar to model.count_params() in Keras.


Solution

  • PyTorch doesn't have a built-in function that counts the total number of parameters the way Keras's count_params() does, but you can sum the number of elements in every parameter tensor:

    pytorch_total_params = sum(p.numel() for p in model.parameters())
    

    If you want to count only the trainable parameters (those with requires_grad=True):

    pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    

    Answer inspired by this answer on PyTorch Forums.
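As a minimal sketch, the two one-liners above can be wrapped in a small helper and checked on a toy model. The helper name count_parameters and the two-layer model are hypothetical, chosen only for illustration. The first layer is frozen so that the total and trainable counts differ.

```python
import torch.nn as nn


def count_parameters(model: nn.Module, trainable_only: bool = False) -> int:
    """Hypothetical helper: sum numel() over the model's parameter tensors.

    When trainable_only is True, parameters with requires_grad=False are skipped.
    """
    return sum(
        p.numel()
        for p in model.parameters()
        if p.requires_grad or not trainable_only
    )


model = nn.Sequential(
    nn.Linear(10, 5),  # 5*10 weights + 5 biases = 55 parameters
    nn.ReLU(),         # no parameters
    nn.Linear(5, 2),   # 2*5 weights + 2 biases = 12 parameters
)

# Freeze the first layer so its parameters are no longer trainable.
for p in model[0].parameters():
    p.requires_grad = False

total = count_parameters(model)
trainable = count_parameters(model, trainable_only=True)
print(total, trainable)  # 67 12
```

Because model.parameters() yields each Parameter object only once, weights shared between layers are counted a single time.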