Search code examples

Problem with custom metric for custom T5 model

I have created a custom dataset and trained a custom T5ForConditionalGeneration model on it. The model predicts solutions to quadratic equations, like this:

Input: "4*x^2 + 4*x + 1" Output: D = 4 ^ 2 - 4 * 4 * 1 4 * 1 4 * 1 4 * 1 4 * 1 4

I need to get accuracy for this model but I get only loss when I use Trainer so I used a custom metric function (I didn't write it but took it from a similar project):

def compute_metrics4token(eval_pred):
    """Compute token-level and whole-answer accuracy for an evaluation run.

    Args:
        eval_pred: A ``(predictions, labels)`` pair as passed by the trainer.
            ``predictions`` may be an array of token ids, an array of logits
            with a trailing vocabulary axis, or a tuple whose first element
            is one of those. ``labels`` is an array of token ids in which
            ignored positions are marked with -100.

    Returns:
        A dict with ``'token_acc'`` (fraction of space-separated label tokens
        matched position-by-position by the prediction) and ``'answer_acc'``
        (fraction of examples whose prediction equals the label exactly),
        both rounded to 4 decimals. The unrounded values are logged to wandb.
    """
    predictions, labels = eval_pred

    # When the trainer does not generate, `predictions` arrives as a tuple of
    # raw model outputs with the logits first; decoding that directly raises
    # "TypeError: argument 'ids': 'list' object cannot be interpreted as an
    # integer". Keep only the logits.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    # Logits are (batch, seq_len, vocab); reduce them to token ids.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # Replace -100 as we can't decode it.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions != -100, predictions, pad_id)
    labels = np.where(labels != -100, labels, pad_id)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Inherited from the ROUGE-style snippet this was adapted from: put each
    # sentence on its own line. Kept so the tokens compared below are unchanged.
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]

    num_correct = 0  # tokens matching at the same position
    num_total = 0    # label tokens seen
    num_answer = 0   # examples predicted exactly
    number_eq = 0    # examples seen
    for pred_text, label_text in zip(decoded_preds, decoded_labels):
        pred_tokens = pred_text.split(' ')
        label_tokens = label_text.split(' ')
        if pred_tokens == label_tokens:
            num_answer += 1
        # zip stops at the shorter sequence, so missing tokens count as wrong.
        num_correct += sum(p == t for p, t in zip(pred_tokens, label_tokens))
        num_total += len(label_tokens)
        number_eq += 1

    # Guard against an empty evaluation set.
    result = {
        'token_acc': num_correct / num_total if num_total else 0.0,
        'answer_acc': num_answer / number_eq if number_eq else 0.0,
    }
    # Log both metrics in one call so they land on the same wandb step.
    wandb.log(result)
    return {k: round(v, 4) for k, v in result.items()}

The problem is that it doesn't work, and I don't really understand why or what I can do to get accuracy for my model. This is my setup, followed by the error I get when the function is used:

# Evaluate, log and save a checkpoint every 100 steps.
# NOTE: unless predict_with_generate=True is set, compute_metrics receives the
# raw model outputs (a tuple starting with the logits) rather than token ids.
args = Seq2SeqTrainingArguments(
    overwrite_output_dir = True,
    evaluation_strategy = 'steps',
    learning_rate = 1e-4,
    logging_steps = 100,
    eval_steps = 100,
    save_steps = 100,
    load_best_model_at_end = True,
    weight_decay = 0.01,
)

trainer = Seq2SeqTrainer(model=model, train_dataset=train_dataset, eval_dataset=eval_dataset, args=args,
                  data_collator=data_collator, tokenizer=tokenizer, compute_metrics=compute_metrics4token)
<ipython-input-48-ff7980f6dd66> in compute_metrics4token(eval_pred)
      4     # predictions = np.argmax(logits[0])
      5     # print(predictions)
----> 6     decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
      7     # Replace -100 in the labels as we can't decode them.
      8     labels = np.where(labels != -100, labels, tokenizer.pad_token_id)

/usr/local/lib/python3.10/dist-packages/transformers/ in batch_decode(self, sequences, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3444             `List[str]`: The list of decoded sentences.
   3445         """
-> 3446         return [
   3447             self.decode(
   3448                 seq,

/usr/local/lib/python3.10/dist-packages/transformers/ in <listcomp>(.0)
   3445         """
   3446         return [
-> 3447             self.decode(
   3448                 seq,
   3449                 skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/ in decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
   3484         token_ids = to_py_obj(token_ids)
-> 3486         return self._decode(
   3487             token_ids=token_ids,
   3488             skip_special_tokens=skip_special_tokens,

/usr/local/lib/python3.10/dist-packages/transformers/ in _decode(self, token_ids, skip_special_tokens, clean_up_tokenization_spaces, **kwargs)
    547         if isinstance(token_ids, int):
    548             token_ids = [token_ids]
--> 549         text = self._tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
    551         clean_up_tokenization_spaces = (

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

When I print out predictions I get a tuple:

(array([[[-32.777344, -34.593437, -36.065685, ..., -34.78577 ,
         -34.77546 , -34.061115],
        [-58.633934, -32.23472 , -31.735909, ..., -40.335655,
         -40.28701 , -37.208904],
        [-56.650974, -33.564095, -34.409576, ..., -36.94467 ,
         -43.246735, -37.469246],
        [-56.62741 , -24.561722, -34.11228 , ..., -35.34798 ,
         -42.287125, -38.889412],
        [-56.632545, -24.470266, -34.0792  , ..., -35.313175,
         -42.235626, -38.891712],
        [-56.687027, -24.391508, -34.12526 , ..., -35.30828 ,
         -42.204193, -38.88395 ]],

       [[-29.79866 , -32.22621 , -32.689865, ..., -32.106445,
         -31.46681 , -31.706667],
        [-62.101192, -33.327423, -30.900173, ..., -38.046883,
         -42.26345 , -38.97748 ],
        [-54.726807, -29.13115 , -30.294558, ..., -28.370876,
         -41.23722 , -37.91609 ],
        [-57.279373, -23.954525, -34.066246, ..., -35.047447,
         -41.599922, -38.489853],
        [-57.31298 , -23.879845, -34.0837  , ..., -35.03614 ,
         -41.557755, -38.530064],
        [-57.39132 , -23.831306, -34.120094, ..., -35.039547,
         -41.525337, -38.55728 ]],

       [[-29.858566, -32.452713, -34.05892 , ..., -33.93065 ,
         -32.109177, -32.874695],
        [-61.375793, -33.656853, -32.95248 , ..., -42.28087 ,
         -42.637173, -39.21142 ],
        [-58.43721 , -32.496166, -36.44046 , ..., -39.33864 ,
         -42.139664, -38.695328],
        [-59.654663, -24.117435, -34.266438, ..., -35.734142,
         -40.55384 , -38.467537],
        [-38.54418 , -18.533113, -29.775307, ..., -26.856483,
         -33.07976 , -29.934727],
        [-27.716005, -14.610603, -23.752686, ..., -21.140053,
         -26.855148, -24.429493]],


       [[-33.252697, -34.72487 , -36.395184, ..., -36.87368 ,
         -35.207897, -34.468285],
        [-59.911736, -32.730076, -32.622803, ..., -43.382267,
         -42.25615 , -38.35135 ],
        [-54.982887, -31.847572, -32.773827, ..., -38.500675,
         -43.97969 , -37.41088 ],
        [-56.896988, -23.213766, -34.04734 , ..., -35.88832 ,
         -42.176086, -38.953568],
        [-56.994152, -23.141619, -34.054848, ..., -35.875816,
         -42.176453, -38.97729 ],
        [-57.076714, -23.05831 , -34.048904, ..., -35.888298,
         -42.165287, -39.020435]],

       [[-30.070187, -32.049232, -34.63928 , ..., -35.02118 ,
         -32.14465 , -32.891876],
        [-61.720093, -32.994057, -32.988144, ..., -42.054638,
         -42.18583 , -38.990112],
        [-57.74364 , -31.431454, -35.969643, ..., -38.593002,
         -42.276768, -38.895355],
        [-58.677704, -23.567434, -35.6751  , ..., -36.018696,
         -40.343582, -38.681267],
        [-58.682228, -23.563087, -35.668964, ..., -36.019753,
         -40.336178, -38.67661 ],
        [-58.718002, -23.609531, -35.67758 , ..., -36.001644,
         -40.366055, -38.67864 ]],

       [[-30.320919, -33.430378, -34.84311 , ..., -37.259563,
         -32.59662 , -33.03912 ],
        [-61.275875, -34.824192, -34.07767 , ..., -44.637024,
         -41.718002, -38.974827],
        [-54.49349 , -30.689342, -35.539658, ..., -39.984665,
         -39.87059 , -37.038437],
        [-58.939384, -23.831846, -34.525368, ..., -35.930893,
         -40.29633 , -37.637936],
        [-58.95117 , -23.824234, -34.520042, ..., -35.931396,
         -40.297188, -37.636852],
        [-58.966076, -23.795956, -34.519627, ..., -35.901787,
         -40.261116, -37.612514]]], dtype=float32), array([[[-1.43104442e-03, -2.98473001e-01,  9.49775204e-02, ...,
         -1.77978892e-02,  1.79805323e-01,  1.33578405e-01],
        [-2.35560730e-01,  1.53045550e-01,  5.15255742e-02, ...,
         -1.57466665e-01,  3.49459350e-01,  7.28092641e-02],
        [ 1.60562042e-02, -1.40354022e-01,  5.29232398e-02, ...,
         -2.38162443e-01, -7.72500336e-02,  6.80136457e-02],
        [ 7.33550191e-02, -3.35853845e-01,  2.25579832e-03, ...,
         -1.93636306e-02,  1.08121082e-01,  5.24416938e-02],
        [ 8.32231194e-02, -3.11688155e-01, -2.13681534e-02, ...,
          3.23344418e-03,  1.08062990e-01,  7.20862746e-02],
        [ 9.58326831e-02, -3.00361574e-01, -3.02627794e-02, ...,
          3.01265554e-03,  1.20107472e-01,  9.56629887e-02]],

       [[-1.16950013e-01, -3.43173921e-01,  1.87818244e-01, ...,
         -2.71256089e-01,  7.42092952e-02,  5.77520356e-02],
        [-1.62564963e-01, -3.87467295e-01,  1.71134964e-01, ...,
         -7.83916116e-02, -3.65173034e-02,  2.08234787e-01],
        [-3.71523261e-01, -8.74521434e-02,  1.39187068e-01, ...,
         -3.08779895e-01,  3.88156146e-01,  9.99216512e-02],
        [ 2.14628279e-02, -3.35561454e-01, -3.76663893e-03, ...,
         -1.29795140e-02,  1.44181430e-01,  1.15508482e-01],
        [ 3.47745977e-02, -3.30934107e-01,  1.10013550e-02, ...,
         -1.84394475e-02,  1.52143195e-01,  1.38157398e-01],
        [ 3.02720107e-02, -3.37626845e-01,  1.35379741e-02, ...,
         -3.80427912e-02,  1.50906458e-01,  1.38765752e-01]],

       [[-6.50129542e-02, -2.63762653e-01,  2.16862872e-01, ...,
         -1.66922837e-01,  1.09285273e-01, -6.40013069e-02],
        [-5.23199737e-01, -2.32228413e-01,  1.44963071e-01, ...,
         -1.41557693e-01,  1.90811172e-01, -2.22496167e-01],
        [-2.24985227e-01, -3.69372189e-01,  7.32450858e-02, ...,
          6.57786876e-02,  9.70033705e-02,  7.83021152e-02],
        [-1.93579309e-03, -3.92921537e-01, -1.28203649e-02, ...,
         -8.74079913e-02,  1.13596492e-01,  9.25250202e-02],
        [ 4.55581211e-03, -3.65802884e-01, -2.60831695e-02, ...,
         -4.12549600e-02,  1.17429778e-01,  1.05997331e-01],
        [ 2.46201381e-02, -3.47863257e-01, -4.48134281e-02, ...,
         -2.53352951e-02,  1.16753690e-01,  1.36296600e-01]],


       [[-6.47678748e-02, -3.45555365e-01,  7.19114989e-02, ...,
         -9.16809738e-02,  2.15520635e-01,  1.01671875e-01],
        [-7.61077851e-02, -1.51827012e-03,  9.52102616e-02, ...,
         -1.39335945e-01,  1.05894208e-01,  3.23191588e-03],
        [-3.24888170e-01, -2.17741728e-03,  5.32661797e-03, ...,
         -2.78430730e-01,  3.59415114e-01,  1.19439401e-01],
        [ 6.89201057e-02, -3.63149673e-01,  7.96841756e-02, ...,
         -3.25191446e-04,  1.26513481e-01,  1.36511743e-01],
        [ 8.16355348e-02, -3.54205281e-01,  7.69739375e-02, ...,
         -2.90949806e-03,  1.31863236e-01,  1.56503588e-01],
        [ 8.36645439e-02, -3.38536322e-01,  8.00612345e-02, ...,
         -9.39210225e-03,  1.29102767e-01,  1.64855778e-01]],

       [[-1.63163885e-01, -3.34902078e-01,  1.11728966e-01, ...,
         -1.10363133e-01,  1.19786285e-01, -9.18702483e-02],
        [-3.36889774e-01, -3.34888607e-01,  1.30680993e-01, ...,
          1.22191897e-03,  1.45059675e-01, -1.27688542e-01],
        [-5.92090450e-02, -2.07585752e-01,  2.05589265e-01, ...,
         -6.80094585e-02,  2.11224273e-01,  3.92790437e-01],
        [ 4.86238785e-02, -4.19503808e-01, -3.39424387e-02, ...,
         -1.76134892e-02,  1.00283481e-01,  1.38210282e-01],
        [ 5.81516996e-02, -4.04477298e-01, -4.19086292e-02, ...,
         -1.02474755e-02,  1.06062084e-01,  1.59754634e-01],
        [ 6.70261905e-02, -3.86263877e-01, -4.19785343e-02, ...,
          9.05385148e-03,  1.01594023e-01,  1.69663757e-01]],

       [[-1.22184128e-01, -3.67584258e-01,  3.60302597e-01, ...,
         -4.39502299e-02,  1.33717149e-01,  1.53699834e-02],
        [-3.37780178e-01, -4.05100137e-01,  2.02614054e-01, ...,
         -5.41410968e-02,  1.55447468e-01, -9.28792357e-02],
        [ 1.81227952e-01, -2.29236633e-01,  2.40814224e-01, ...,
          1.39913429e-02,  7.61386827e-02,  3.62152725e-01],
        [ 1.47830993e-02, -4.26465064e-01, -1.54972840e-02, ...,
          3.74358669e-02,  1.52016997e-01,  1.53155088e-01],
        [ 3.46656404e-02, -4.00052220e-01, -3.53843644e-02, ...,
          2.64652576e-02,  1.62517026e-01,  1.66649833e-01],
        [ 4.50411513e-02, -3.61773074e-01, -5.50217964e-02, ...,
          3.68298292e-02,  1.67936400e-01,  1.76781893e-01]]],

I thought that maybe I need to take argmax from these values but then I still get errors.

If something is unclear I would be happy to provide additional information. Thanks for any help.


I am adding an example of an item in the dataset:


{'text': ['3*x^2 + 9*x + 6 = 0',
'59*x^2 + -59*x + 14 = 0',
'-10*x^2 + 0*x + 0 = 0',
'3*x^2 + 63*x + 330 = 0',
'1*x^2 + -25*x + 156 = 0'],
'label': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3) 
= -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59) 
= 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) = 
-10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) = 
13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0'],
'__index_level_0__': [10803, 14170, 25757, 73733, 25059]}


  • It seems like the task you're trying to achieve is some sort of "translation" task, so the most appropriate model class to use is AutoModelForSeq2SeqLM.

    And since the output is an open-ended sequence, it might be more appropriate to use

    • BLEU / ChrF or newer neural-based metrics for translation
    • ROUGE for summarization

    You can take a look at various translation-related metrics on

    Treating it as a normal Machine Translation task

    To read the data, you'll have to make sure that the model's forward function

    • sees the data point as {"text": [0, 1, 2, ... ], "labels": [0, 9, 8, ...]} in your datasets.Dataset object
    • uses a collator to do the batching, e.g. DataCollatorForSeq2Seq

    And here's a working snippet showing how the code (in parts) can be run:

    Data processing part.

    from datasets import Dataset
    import evaluate
    from transformers import AutoModelForSeq2SeqLM, Trainer, AutoTokenizer, DataCollatorForSeq2Seq
    # Toy training set: quadratic equations ('text') and worked solutions ('target').
    math_data = {'text': ['3*x^2 + 9*x + 6 = 0',
      '59*x^2 + -59*x + 14 = 0',
      '-10*x^2 + 0*x + 0 = 0',
      '3*x^2 + 63*x + 330 = 0',
      '1*x^2 + -25*x + 156 = 0'],
     'target': ['D = 9^2 - 4 * 3 * 6 = 9; x1 = (-9 + (9)**0.5) // (2 * 3)  = -1.0; x2 = (-9 - (9)**0.5) // (2 * 3) = -2.0',
      'D = -59^2 - 4 * 59 * 14 = 177; x1 = (59 + (177)**0.5) // (2 * 59)  = 0.0; x2 = (59 - (177)**0.5) // (2 * 59) = 0.0',
      'D = 0^2 - 4 * -10 * 0 = 0; x = 0^2 // (2 * -10) = 0',
      'D = 63^2 - 4 * 3 * 330 = 9; x1 = (-63 + (9)**0.5) // (2 * 3) =  -10.0; x2 = (-63 - (9)**0.5) // (2 * 3) = -11.0',
      'D = -25^2 - 4 * 1 * 156 = 1; x1 = (25 + (1)**0.5) // (2 * 1) =  13.0; x2 = (25 - (1)**0.5) // (2 * 1) = 12.0']}
    # Single-example evaluation set.
    math_data_eval = {'text': ["10 + 9x(x+3y) - 3x^3"], "target": ["10 + 9x^2 + 27xy - 3x^3"]}
    ds_train = Dataset.from_dict(math_data)
    model = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
    tokenizer = AutoTokenizer.from_pretrained("t5-small")
    data_collator = DataCollatorForSeq2Seq(tokenizer)
    # Tokenize the inputs; the tokenizer output (input_ids, attention_mask)
    # is added to each example as new columns.
    ds_train = ds_train.map(lambda x: tokenizer(x["text"], truncation=True, padding="max_length", max_length=512))
    # Tokenize the targets and store their ids under "labels".
    ds_train = ds_train.map(lambda y:
        {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)['input_ids']})
    ds_eval = Dataset.from_dict(math_data_eval)
    # Same two tokenization passes for the evaluation set.
    ds_eval = ds_eval.map(lambda x: tokenizer(x["text"],
        truncation=True, padding="max_length", max_length=512))
    ds_eval = ds_eval.map(lambda y:
        {"labels": tokenizer(y["target"], truncation=True, padding="max_length", max_length=512)['input_ids']})

    Metric definition part.

    import numpy as np
    # Load the sacreBLEU metric through the `evaluate` library; compute_metrics
    # below calls metric.compute() on the decoded predictions and references.
    metric = evaluate.load("sacrebleu")
    def postprocess_text(preds, labels):
        """Strip surrounding whitespace from decoded predictions and labels.

        Returns a ``(preds, labels)`` pair: the stripped predictions as a flat
        list of strings, and each stripped label wrapped in its own
        single-element list.
        """
        cleaned_preds = []
        for text in preds:
            cleaned_preds.append(text.strip())

        wrapped_labels = []
        for text in labels:
            wrapped_labels.append([text.strip()])

        return cleaned_preds, wrapped_labels
    def compute_metrics(eval_preds):
        """Score decoded predictions against decoded labels.

        ``eval_preds`` is a ``(preds, labels)`` pair; when ``preds`` is a
        tuple only its first element is used. Returns a dict with ``"bleu"``
        (the ``"score"`` entry of the metric result) and ``"gen_len"`` (mean
        count of non-pad tokens per prediction), both rounded to 4 decimals.
        """
        generated, references = eval_preds
        if isinstance(generated, tuple):
            generated = generated[0]

        pad_id = tokenizer.pad_token_id

        # -100 is used for padding and can't be decoded; swap in the pad id.
        generated = np.where(generated != -100, generated, pad_id)
        pred_texts = tokenizer.batch_decode(generated, skip_special_tokens=True)
        references = np.where(references != -100, references, pad_id)
        ref_texts = tokenizer.batch_decode(references, skip_special_tokens=True)

        # Light clean-up before scoring.
        pred_texts, ref_texts = postprocess_text(pred_texts, ref_texts)
        bleu = metric.compute(predictions=pred_texts, references=ref_texts)

        # Length of each prediction, not counting pad tokens.
        lengths = [np.count_nonzero(row != pad_id) for row in generated]

        scores = {"bleu": bleu["score"], "gen_len": np.mean(lengths)}
        return {name: round(value, 4) for name, value in scores.items()}

    Trainer setup part.

    from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
    # set training arguments - these params are not really tuned, feel free to change
    training_args = Seq2SeqTrainingArguments(
        output_dir="./results",        # where checkpoints are written; change as needed
        evaluation_strategy="steps",   # evaluate every `eval_steps`, otherwise eval_steps has no effect
        predict_with_generate=True,    # hand generated token ids (not logits) to compute_metrics
        logging_steps=2,  # set to 1000 for full training
        save_steps=16,    # set to 500 for full training
        eval_steps=4,     # set to 8000 for full training
        warmup_steps=1,   # set to 2000 for full training
        max_steps=16,     # delete for full training
        # overwrite_output_dir=True,
    )
    # instantiate trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=ds_train,
        eval_dataset=ds_eval,
        data_collator=data_collator,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    That works, which is good. But why is the output of the model still so bad?

    • Most probably you need to tune some hyperparameters (e.g. batch_size or the learning rate), use more data, or increase the number of max_steps
    • It can also be that your vocab is pretrained for natural language but your data isn't; in that case, I'd suggest trying to modify the tokenizer before training, e.g. How to add new tokens to an existing Huggingface tokenizer?