python, pytorch, huggingface-transformers

Can you determine the number of output classes in a HuggingFace segmentation model?


I am loading a model like this:

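import torch
from transformers import Mask2FormerForUniversalSegmentation

id2label = {
    0: 'background',
    1: 'cake',
    2: 'donut',
}
# self.backbone holds the name or path of the pretrained Mask2Former checkpoint
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.backbone, id2label=id2label, ignore_mismatched_sizes=True)
model.load_state_dict(torch.load('checkpoint.pt', map_location=torch.device('cpu')))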

However, I don't actually know id2label; I only have the checkpoint. I don't care about the class names, I just want to know how many classes the checkpointed model has. I can see the count in the error message that appears, but I would like to get it without triggering that error:

RuntimeError: Error(s) in loading state_dict for Mask2FormerForUniversalSegmentation:
    size mismatch for class_predictor.weight: copying a param with shape torch.Size([8, 256]) from checkpoint, the shape in current model is torch.Size([20, 256]).
    size mismatch for class_predictor.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
    size mismatch for criterion.empty_weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).

Solution

  • You could check the state_dict:

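    import torch
    
    chk = torch.load("checkpoint.pt", map_location="cpu")
    # chk is a dict[str, torch.Tensor]
    # The first dimension is the number of labels + 1, because Mask2Former's
    # class head also predicts a "no object" class, so subtract 1
    num_labels = chk["class_predictor.weight"].shape[0] - 1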
    

    The only downside is that you need to know the layer's name, but that is easy when you only ever load checkpoints into Mask2FormerForUniversalSegmentation (the name class_predictor.weight also appears in the error message).
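If you need this repeatedly, the lookup can be wrapped in a small helper and the result used to rebuild id2label before loading. This is a minimal sketch, assuming the checkpoint is a plain state_dict saved from Mask2FormerForUniversalSegmentation (as the load_state_dict call in the question implies). count_labels and placeholder_id2label are hypothetical helper names, and the fallback for prefixed keys is an assumption, not something the question's checkpoint is known to need.

```python
from typing import Any, Dict, Mapping


def count_labels(state_dict: Mapping[str, Any], key: str = "class_predictor.weight") -> int:
    """Return the number of real labels stored in a Mask2Former state_dict.

    The class head has num_labels + 1 output rows (the extra row is the
    "no object" class), so one is subtracted from the first dimension.
    """
    if key in state_dict:
        name = key
    else:
        # Assumption: some training scripts save the weights with a prefix
        # such as "model.", so fall back to matching the key by suffix.
        matches = [k for k in state_dict if k.endswith("." + key)]
        if len(matches) != 1:
            raise KeyError(f"expected one key ending in {key!r}, found {matches}")
        name = matches[0]
    return int(state_dict[name].shape[0]) - 1


def placeholder_id2label(num_labels: int) -> Dict[int, str]:
    """Generic label names; the model only needs the number of entries."""
    return {i: f"label_{i}" for i in range(num_labels)}


# Usage sketch (needs torch and transformers installed; self.backbone is the
# same pretrained checkpoint name or path used in the question):
#
#   import torch
#   from transformers import Mask2FormerForUniversalSegmentation
#
#   state_dict = torch.load("checkpoint.pt", map_location="cpu")
#   num_labels = count_labels(state_dict)  # 8 rows in the question's checkpoint -> 7
#   model = Mask2FormerForUniversalSegmentation.from_pretrained(
#       self.backbone,
#       id2label=placeholder_id2label(num_labels),
#       ignore_mismatched_sizes=True,
#   )
#   model.load_state_dict(state_dict)  # shapes now match the checkpoint
```

Since only the count matters, the generated names are arbitrary; swap in the real class names if you ever recover them.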