Tags: python, bert-language-model, allennlp

Cannot register text_classifier as Model; name already in use for TextClassifier


I'm trying to use the text classifier model shared at https://github.com/allenai/scibert/blob/master/scibert/models/text_classifier.py

Everything used to work, but suddenly I keep getting this error: Cannot register text_classifier as Model; name already in use for TextClassifier

What might be the reason? Any suggestions?

    from typing import Dict, Optional, List, Any
    
    import torch
    import torch.nn.functional as F
    from allennlp.data import Vocabulary
    from allennlp.models.model import Model
    from allennlp.modules import FeedForward, TextFieldEmbedder, Seq2SeqEncoder
    from allennlp.nn import InitializerApplicator, RegularizerApplicator
    from allennlp.nn import util
    from allennlp.training.metrics import CategoricalAccuracy, F1Measure
    from overrides import overrides
    
    
    @Model.register("text_classifier")
    class TextClassifier(Model):
        """
        Implements a basic text classifier:
        1) Embed tokens using `text_field_embedder`
        2) Seq2SeqEncoder, e.g. BiLSTM
        3) Append the first and last encoder states
        4) Final feedforward layer
        Optimized with CrossEntropyLoss.  Evaluated with CategoricalAccuracy & F1.
        """
        def __init__(self, vocab: Vocabulary,
                     text_field_embedder: TextFieldEmbedder,
                     text_encoder: Seq2SeqEncoder,
                     classifier_feedforward: FeedForward,
                     verbose_metrics: bool = False,
                     initializer: InitializerApplicator = InitializerApplicator(),
                     regularizer: Optional[RegularizerApplicator] = None,
                     ) -> None:
            super(TextClassifier, self).__init__(vocab, regularizer)

            self.text_field_embedder = text_field_embedder
            self.num_classes = self.vocab.get_vocab_size("labels")
            self.text_encoder = text_encoder
            self.classifier_feedforward = classifier_feedforward
            self.prediction_layer = torch.nn.Linear(self.classifier_feedforward.get_output_dim(), self.num_classes)

            self.label_accuracy = CategoricalAccuracy()
            self.label_f1_metrics = {}

            self.verbose_metrics = verbose_metrics

            for i in range(self.num_classes):
                self.label_f1_metrics[vocab.get_token_from_index(index=i, namespace="labels")] = F1Measure(positive_label=i)
            self.loss = torch.nn.CrossEntropyLoss()

            self.pool = lambda text, mask: util.get_final_encoder_states(text, mask, bidirectional=True)

            initializer(self)

        @overrides
        def forward(self,
                    text: Dict[str, torch.LongTensor],
                    label: torch.IntTensor = None,
                    metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
            """
            Parameters
            ----------
            text : Dict[str, torch.LongTensor]
                From a ``TextField``
            label : torch.IntTensor, optional (default = None)
                From a ``LabelField``
            metadata : ``List[Dict[str, Any]]``, optional, (default = None)
                Metadata containing the original tokenization of the premise and
                hypothesis with 'premise_tokens' and 'hypothesis_tokens' keys respectively.
            Returns
            -------
            An output dictionary consisting of:
            label_logits : torch.FloatTensor
                A tensor of shape ``(batch_size, num_labels)`` representing unnormalised log probabilities of the label.
            label_probs : torch.FloatTensor
                A tensor of shape ``(batch_size, num_labels)`` representing probabilities of the label.
            loss : torch.FloatTensor, optional
                A scalar loss to be optimised.
            """
            embedded_text = self.text_field_embedder(text)

            mask = util.get_text_field_mask(text)
            encoded_text = self.text_encoder(embedded_text, mask)
            pooled = self.pool(encoded_text, mask)
            ff_hidden = self.classifier_feedforward(pooled)
            logits = self.prediction_layer(ff_hidden)
            class_probs = F.softmax(logits, dim=1)

            output_dict = {"logits": logits}
            if label is not None:
                loss = self.loss(logits, label)
                output_dict["loss"] = loss

                # compute F1 per label
                for i in range(self.num_classes):
                    metric = self.label_f1_metrics[self.vocab.get_token_from_index(index=i, namespace="labels")]
                    metric(class_probs, label)
                self.label_accuracy(logits, label)
            return output_dict

        # @overrides
        def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
            class_probabilities = F.softmax(output_dict['logits'], dim=-1)
            output_dict['class_probs'] = class_probabilities
            return output_dict

        def get_metrics(self, reset: bool = False) -> Dict[str, float]:
            metric_dict = {}

            sum_f1 = 0.0
            for name, metric in self.label_f1_metrics.items():
                metric_val = metric.get_metric(reset)
                if self.verbose_metrics:
                    metric_dict[name + '_P'] = metric_val[0]
                    metric_dict[name + '_R'] = metric_val[1]
                    metric_dict[name + '_F1'] = metric_val[2]
                sum_f1 += metric_val[2]

            names = list(self.label_f1_metrics.keys())
            total_len = len(names)
            average_f1 = sum_f1 / total_len
            metric_dict['average_F1'] = average_f1
            metric_dict['accuracy'] = self.label_accuracy.get_metric(reset)
            return metric_dict

Solution

  • The name is already taken: a model that ships with AllenNLP now uses it, so you need to pick a different one.

    For the curious, AllenNLP keeps a registry of models so that you can select a model by name at the command line; that's what the decorator is doing, and it requires every name to be unique. A simplified sketch of that registry check appears after this answer.

    The external package you're using claimed the name text_classifier before AllenNLP did: the scibert file was last updated in May 2019, and it worked at that time. About 17 months ago, though, AllenNLP added its own model under the same name. So it's not your fault; it's a mismatch between the two packages (at least in their current versions). A rename workaround is sketched below.
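
    Roughly speaking, the register decorator adds the class to a dictionary keyed by name and refuses duplicates. Here is a simplified illustration of that behaviour, not AllenNLP's actual implementation (the real logic lives in allennlp.common.registrable):

        # Simplified, illustrative sketch of a Registrable-style registry.
        _model_registry = {}

        def register(name):
            def decorator(cls):
                if name in _model_registry:
                    # This is the situation producing the error in the question:
                    # two classes trying to claim the same name.
                    raise ValueError(f"Cannot register {name} as Model; "
                                     f"name already in use for {_model_registry[name].__name__}")
                _model_registry[name] = cls
                return cls
            return decorator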
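
    If you control the copy of text_classifier.py you're running, the simplest fix is to register the class under a name AllenNLP doesn't already use and point your configuration at that name. A minimal sketch, assuming you rename it to scibert_text_classifier (the new name is arbitrary, not something scibert itself defines):

        # Local copy of text_classifier.py: register under a name that does not
        # clash with AllenNLP's built-in "text_classifier".
        from allennlp.models.model import Model

        @Model.register("scibert_text_classifier")  # any unused name works
        class TextClassifier(Model):
            ...

    Remember to change "type": "text_classifier" to the new name in your experiment config. Recent AllenNLP versions also accept an exist_ok=True argument to register (check your installed version), which overwrites an existing entry, but renaming avoids any ambiguity about which class a config will resolve to.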