I am loading a model like this:
```python
import torch
from transformers import Mask2FormerForUniversalSegmentation

id2label = {
    0: 'background',
    1: 'cake',
    2: 'donut',
}
model = Mask2FormerForUniversalSegmentation.from_pretrained(self.backbone, id2label=id2label, ignore_mismatched_sizes=True)
model.load_state_dict(torch.load('checkpoint.pt', map_location=torch.device('cpu')))
```
However, I don't actually know `id2label` (I only have the checkpoint). I don't care about the class names; I just want to know how many classes the checkpointed model has. I can see the number in the error message that appears, but I would like to avoid triggering it:
```
RuntimeError: Error(s) in loading state_dict for Mask2FormerForUniversalSegmentation:
    size mismatch for class_predictor.weight: copying a param with shape torch.Size([8, 256]) from checkpoint, the shape in current model is torch.Size([20, 256]).
    size mismatch for class_predictor.bias: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
    size mismatch for criterion.empty_weight: copying a param with shape torch.Size([8]) from checkpoint, the shape in current model is torch.Size([20]).
```
You could check the state_dict:
```python
import torch

chk = torch.load("checkpoint.pt", map_location="cpu")
# chk is a dict[str, torch.Tensor] (the model's state dict)
# class_predictor has num_labels + 1 outputs (the extra one is "no object"), so subtract 1
num_labels = chk["class_predictor.weight"].shape[0] - 1
```
The only downside is that you need to know the layer's name, but that is easy when you only ever load checkpoints of `Mask2FormerForUniversalSegmentation`, where the classification head is always `class_predictor`.
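If you would rather not hard-code the full key, for example because the checkpoint was saved from a wrapper module that prefixes every key, a small sketch like the following matches on the key suffix instead. The function name is hypothetical, and the `model.` prefix in the test data is only an assumed example of such a wrapper. The function works on any mapping whose values have a `.shape`, such as a loaded state dict.

```python
from typing import Any, Mapping


def infer_num_labels(state_dict: Mapping[str, Any], suffix: str = "class_predictor.weight") -> int:
    # Accept the exact key or any key ending in ".<suffix>" (e.g. a hypothetical "model." prefix).
    matches = [k for k in state_dict if k == suffix or k.endswith("." + suffix)]
    if len(matches) != 1:
        raise KeyError(f"expected exactly one key ending in {suffix!r}, found {matches}")
    # The first dimension is num_labels + 1 because of the extra "no object" class.
    return state_dict[matches[0]].shape[0] - 1
```

Usage would be `infer_num_labels(torch.load("checkpoint.pt", map_location="cpu"))`. Raising on zero or multiple matches avoids silently picking the wrong layer.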