
How can I get the number of trainable parameters of a model in Keras?


I am setting trainable=False in all my layers, implemented through the Model API, but I want to verify whether that is working. model.count_params() returns the total number of parameters, but is there any way in which I can get the total number of trainable parameters, other than looking at the last few lines of model.summary()?


Solution

    import numpy as np
    from keras import backend as K

    # set() removes duplicate weight objects so each weight is counted once
    trainable_count = int(
        np.sum([K.count_params(p) for p in set(model.trainable_weights)]))
    non_trainable_count = int(
        np.sum([K.count_params(p) for p in set(model.non_trainable_weights)]))

    print('Total params: {:,}'.format(trainable_count + non_trainable_count))
    print('Trainable params: {:,}'.format(trainable_count))
    print('Non-trainable params: {:,}'.format(non_trainable_count))
    

The snippet above comes from the end of the layer_utils.print_summary() definition, which is what model.summary() calls.
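K.count_params(p) returns the number of scalar entries in a weight, i.e. the product of its shape dimensions, and set() makes sure a weight object that appears more than once in the list (for example, from a shared layer) is only counted once. The same arithmetic can be sketched with plain NumPy arrays standing in for Keras weights. This is an illustration under that assumption; count_unique_params is a hypothetical helper, not part of Keras.

```python
import numpy as np


def count_unique_params(weights):
    """Count scalar parameters, counting each weight object only once."""
    # Deduplicate by identity, like set() over Keras variables; NumPy arrays
    # are unhashable, so key a dict on id() instead.
    unique = {id(w): w for w in weights}.values()
    # Each weight contributes the product of its shape dimensions (w.size).
    return int(sum(w.size for w in unique))


kernel = np.zeros((4, 8))  # kernel of a Dense(8) layer on a 4-dim input
bias = np.zeros((8,))      # bias of the same layer
# The same kernel listed twice, as can happen when a layer is shared.
weights = [kernel, bias, kernel]

print(count_unique_params(weights))  # 40 (32 + 8), not 72
```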


Edit: more recent versions of Keras have a helper function, count_params(), for this purpose:

    from keras.utils.layer_utils import count_params
    
    trainable_count = count_params(model.trainable_weights)
    non_trainable_count = count_params(model.non_trainable_weights)
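Since the question is really about confirming that trainable=False took effect, here is a minimal sketch that freezes every layer and compares the counts before and after. It assumes TensorFlow 2.x with its bundled Keras (tensorflow.keras), not the standalone keras package used above. count_weights is a hypothetical helper, not a Keras API. It deduplicates by id() rather than set() because TensorFlow 2.x variables are unhashable, so set(model.trainable_weights) can raise TypeError there.

```python
import math

from tensorflow import keras


def count_weights(weights):
    """Count scalar parameters, counting each weight object only once."""
    # Deduplicate by id(): TF 2.x variables are unhashable, so set() may fail.
    unique = {id(w): w for w in weights}.values()
    # A weight holds prod(shape) scalars; math.prod(()) == 1 for scalar weights.
    return sum(math.prod(tuple(w.shape)) for w in unique)


model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),  # 4*8 + 8 = 40 params
    keras.layers.Dense(2),                     # 8*2 + 2 = 18 params
])

trainable_before = count_weights(model.trainable_weights)
non_trainable_before = count_weights(model.non_trainable_weights)

# Freeze every layer, as described in the question.
for layer in model.layers:
    layer.trainable = False

trainable_after = count_weights(model.trainable_weights)
non_trainable_after = count_weights(model.non_trainable_weights)

print("Before freezing:", trainable_before, "trainable,",
      non_trainable_before, "non-trainable")
print("After freezing:", trainable_after, "trainable,",
      non_trainable_after, "non-trainable")
```

If freezing worked, the trainable count drops to 0 while model.count_params() still reports the same total (58 here); only the trainable/non-trainable split changes.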