I am using AllenNLP 2.1 and I would like to pass class weights to the PyTorch CrossEntropyLoss function that I use.
from typing import List, Union

import torch
from allennlp.data import Vocabulary
from allennlp.models.heads import Head
from allennlp.training.metrics import CategoricalAccuracy, FBetaMeasure


@Head.register('model_head_two_layers')
class ModelHeadTwoLayers(Head):

    default_predictor = 'head_predictor'

    def __init__(self, vocab: Vocabulary, input_dim: int, output_dim: int, dropout: float = 0.0,
                 class_weights: Union[List[float], None] = None):
        super().__init__(vocab=vocab)
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.layers = torch.nn.Sequential(
            torch.nn.Dropout(dropout),
            torch.nn.Linear(self.input_dim, self.input_dim),
            torch.nn.ReLU(inplace=True),
            torch.nn.Linear(self.input_dim, output_dim)
        )
        self.metrics = {
            'accuracy': CategoricalAccuracy(),
            'f1_macro': FBetaMeasure(average='macro')
        }
        if class_weights:
            self.class_weights = torch.FloatTensor(class_weights)
            self.cross_ent = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            self.cross_ent = torch.nn.CrossEntropyLoss()
In the configuration file I pass the class weights as follows:
"heads": {
"task_name": {
"type": "model_head_two_layers",
"input_dim": embedding_dim,
"output_dim": 4,
"dropout": dropout,
"class_weights": [0.25, 0.90, 0.91, 0.94]
}
}
In order for the class weights to be in the correct order, I need to know which index of the output tensor corresponds to which class. The only way I have found so far is to first train a model without class weights, then go into the model's vocabulary directory and check the order in which the class names are written to the labels file.
While that seems to work, is there an easier way to get that mapping without having to train a model first?
You can generate a vocabulary without training a model by using the allennlp build-vocab command (an example invocation is sketched at the end of this answer). But I think the better solution here would be to pass the class_weights to your model as a mapping from label -> weight, and then build the array of weights within the __init__ method. Something like this:
from typing import Dict, List, Optional

import torch
from allennlp.data import Vocabulary
from allennlp.models.heads import Head


class ModelHeadTwoLayers(Head):
    def __init__(
        self,
        vocab: Vocabulary,
        input_dim: int,
        output_dim: int,
        dropout: float = 0.0,
        class_weights: Optional[Dict[str, float]] = None,
        label_namespace: str = "labels",
    ):
        super().__init__(vocab=vocab)
        # ...
        if class_weights:
            # Order the weights by each label's index in the vocabulary so they
            # line up with the output tensor. Assumes a weight is given for
            # every label in the namespace.
            weights: List[float] = [0.0] * len(class_weights)
            for label, weight in class_weights.items():
                label_idx = self.vocab.get_token_index(label, namespace=label_namespace)
                weights[label_idx] = weight
            self.class_weights = torch.FloatTensor(weights)
            self.cross_ent = torch.nn.CrossEntropyLoss(weight=self.class_weights)
        else:
            self.cross_ent = torch.nn.CrossEntropyLoss()
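With that change, the class weights in your configuration file are keyed by label name instead of index order, so you no longer need to know the mapping in advance. A sketch of what the config could look like; the label names here are made up for illustration:

"heads": {
    "task_name": {
        "type": "model_head_two_layers",
        "input_dim": embedding_dim,
        "output_dim": 4,
        "dropout": dropout,
        "class_weights": {
            // placeholder label names; use the labels from your own data
            "class_a": 0.25,
            "class_b": 0.90,
            "class_c": 0.91,
            "class_d": 0.94
        }
    }
}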
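And if you still want to inspect the label -> index mapping up front, build-vocab creates the vocabulary from a training config without training a model. A minimal sketch, assuming your config lives at my_experiment.jsonnet (both paths are placeholders):

allennlp build-vocab my_experiment.jsonnet vocab.tar.gz

The resulting archive contains the labels file with one class name per line, in index order. You can also load it in Python; as far as I know, Vocabulary.from_files accepts the .tar.gz archive directly:

from allennlp.data import Vocabulary

# Load the vocabulary produced by `allennlp build-vocab`.
vocab = Vocabulary.from_files("vocab.tar.gz")

# Print the index -> label mapping for the default "labels" namespace.
print(vocab.get_index_to_token_vocabulary("labels"))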