Tags: python, deep-learning, pytorch, distributed, transformer-model

How to remove a prediction head from a PyTorch model based on the output tensor?


I am working on a ViT (Vision Transformer) project, and some low-level definitions live deep inside the timm library, which I cannot change. That low-level definition includes a linear classification prediction head, which is not part of my network.

Everything was fine until I switched to a DistributedDataParallel (DDP) implementation. PyTorch complained that some parameters did not contribute to the loss and told me to use “find_unused_parameters=True”. This is a common scenario, and training works again if I add “find_unused_parameters=True” to the training routine. However, I am only allowed to change the model definition in our code base; I cannot modify anything related to training …

So I guess the only thing I can do right now is to “remove” the linear head from the model. Although I cannot dig into the low-level definition of ViT, I can get this tensor as an output:

encoder_output, linear_head_output = ViT(input)

Is it possible to remove this linear prediction head based on this linear_head_output tensor?
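A tensor cannot tell you which module produced it, but a single forward/backward pass can tell you which parameters never receive a gradient, which is exactly what DDP complains about. The sketch below assumes a hypothetical `ToyViT` stand-in that returns `(encoder_output, linear_head_output)` like the model in the question, and a placeholder loss that only uses `encoder_output`; the helper name `find_unused_parameters` is also made up for illustration.

```python
from typing import List

import torch
import torch.nn as nn


class ToyViT(nn.Module):
    """Hypothetical stand-in for a library model that returns (features, logits)."""

    def __init__(self, dim: int = 16, num_classes: int = 10):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)       # stands in for the transformer blocks
        self.head = nn.Linear(dim, num_classes)  # classification head that the loss ignores

    def forward(self, x):
        encoder_output = self.encoder(x)
        linear_head_output = self.head(encoder_output)
        return encoder_output, linear_head_output


def find_unused_parameters(model: nn.Module, x: torch.Tensor) -> List[str]:
    """Names of trainable parameters that get no gradient from a loss on encoder_output."""
    model.zero_grad(set_to_none=True)
    encoder_output, _ = model(x)
    loss = encoder_output.pow(2).mean()  # placeholder loss that never touches the head
    loss.backward()
    return [
        name
        for name, p in model.named_parameters()
        if p.requires_grad and p.grad is None
    ]


model = ToyViT()
print(find_unused_parameters(model, torch.randn(4, 16)))  # prints ['head.weight', 'head.bias']
```

The returned names point at the submodule (here `head`) that has to be removed or replaced in the model definition.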


Solution

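  • Just set num_classes=0 when you create your ViT model by calling timm.create_model(). With num_classes=0, timm builds the model with an identity layer (nn.Identity) in place of the linear classifier, so no head parameters are left unused under DDP.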

    Here is an example from the timm documentation on Feature Extraction (it uses resnet50, but the num_classes argument is the same for ViT models):

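    import torch
    import timm

    # num_classes=0 replaces the classifier with an identity layer;
    # global_pool='' also disables pooling, so the raw feature map is returned
    m = timm.create_model('resnet50', pretrained=True, num_classes=0, global_pool='')
    o = m(torch.randn(2, 3, 224, 224))
    print(f'Unpooled shape: {o.shape}')  # Unpooled shape: torch.Size([2, 2048, 7, 7])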
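If the model is already constructed inside code you cannot change, a similar effect can be had by swapping the classifier module for a parameter-free nn.Identity after construction. timm's VisionTransformer stores its classifier in an attribute named `head`, and timm models also expose `reset_classifier(0)`, which should do the same thing. The sketch below uses a hypothetical torch-only `ToyViT` so it runs without timm; the `head` attribute name is an assumption you should check against your actual model.

```python
import torch
import torch.nn as nn


class ToyViT(nn.Module):
    """Hypothetical stand-in for a timm-style model whose classifier lives in `self.head`."""

    def __init__(self, dim: int = 16, num_classes: int = 10):
        super().__init__()
        self.encoder = nn.Linear(dim, dim)       # stands in for the transformer blocks
        self.head = nn.Linear(dim, num_classes)  # classifier we want to get rid of

    def forward(self, x):
        encoder_output = self.encoder(x)
        return encoder_output, self.head(encoder_output)


model = ToyViT()
# Replace the classifier with a parameter-free identity; the forward signature is unchanged.
model.head = nn.Identity()

encoder_output, head_output = model(torch.randn(4, 16))
encoder_output.pow(2).mean().backward()

# Every remaining parameter now receives a gradient, so DDP has nothing unused to track.
unused = [name for name, p in model.named_parameters() if p.grad is None]
print(unused)  # prints []
```

Because the identity head has no parameters, wrapping this model in DDP should no longer require find_unused_parameters=True, and the training code stays untouched.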