Search code examples
python nlp pytorch classification bert-language-model

BERT text classification using PyTorch


I am trying to build a BERT model for text classification with the help of this code [https://towardsdatascience.com/bert-text-classification-using-pytorch-723dfb8b6b5b]. My dataset contains two columns (label, text). The labels can take one of three values (0, 1, 2). The code runs without any error, but all values of the confusion matrix are 0. Is there something wrong with my code?

import matplotlib.pyplot as plt
import pandas as pd
import torch
from torchtext.data import Field, TabularDataset, BucketIterator, Iterator
import torch.nn as nn
from transformers import BertTokenizer, BertForSequenceClassification
import torch.optim as optim
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

import seaborn as sns

# Fix the PyTorch RNG seed and pick the first CUDA device when one is available.
torch.manual_seed(42)
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Tokenizer matching the 'bert-base-uncased' checkpoint loaded by the BERT class.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Fixed sequence length for text_field, plus the tokenizer's ids for its
# padding and unknown tokens.
MAX_SEQ_LEN = 128
PAD_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.pad_token)
UNK_INDEX = tokenizer.convert_tokens_to_ids(tokenizer.unk_token)


# Labels are read as plain (non-sequential) values without a vocabulary and
# stored as floats; the training/evaluation loops cast them to LongTensor.
label_field = Field(sequential=False, use_vocab=False, batch_first=True, dtype=torch.float)
# Text is tokenized with tokenizer.encode, fixed to MAX_SEQ_LEN and padded with PAD_INDEX.
# NOTE(review): the next line is cut off in this copy of the file (it ends in
# "unk_t>"); the remaining arguments and the closing parenthesis are not visible.
text_field = Field(use_vocab=False, tokenize=tokenizer.encode, lower=False, include_lengths=False, batch_first=True, fix_length=MAX_SEQ_LEN, pad_token=PAD_INDEX, unk_t>
# Field order given to TabularDataset: label first, then text.
fields = [('label', label_field), ('text', text_field)]
CLASSIFICATION_REPORT = "classification_report.jsonl"


# Load the three CSV splits from the current directory, skipping each file's header row.
train, valid, test = TabularDataset.splits(path='', train='train.csv', validation='validate.csv', test='test.csv', format='CSV', fields=fields, skip_header=True)

# Batches of 16, sorted by tokenized text length.
# NOTE(review): valid_iter is also built with train=True — confirm this is intended.
train_iter = BucketIterator(train, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
valid_iter = BucketIterator(valid, batch_size=16, sort_key=lambda x: len(x.text), device=device, train=True, sort=True, sort_within_batch=True)
# The test split is iterated in file order (no shuffling, no sorting).
test_iter = Iterator(test, batch_size=16, device=device, train=False, shuffle=False, sort=False)

class BERT(nn.Module):
        """Thin wrapper around a three-label `BertForSequenceClassification`."""

        def __init__(self):
                super().__init__()
                self.encoder = BertForSequenceClassification.from_pretrained(
                        "bert-base-uncased",
                        num_labels=3,
                )

        def forward(self, text, label):
                """Call the encoder with `labels=label` and return its first two outputs.

                The two leading outputs are unpacked and returned as
                `(loss, text_fea)`.
                """
                leading_pair = self.encoder(text, labels=label)[:2]
                loss, text_fea = leading_pair
                return loss, text_fea

def train(model, optimizer, criterion = nn.BCELoss(), train_loader = train_iter, valid_loader = valid_iter, num_epochs = 5, eval_every = len(train_iter) // 2, file_pat>        running_loss = 0.0
        # NOTE(review): the signature line above is cut off in this copy of the
        # file (it ends in "file_pat>") and is fused with the first body statement.
        # Fine-tunes `model` on train_loader and, every `eval_every` optimizer
        # steps, measures the average loss on valid_loader and prints progress.
        # NOTE(review): `criterion` is never referenced in the visible body; the
        # loss that is back-propagated is the one returned by `model(text, label)`.
        # NOTE(review): the default eval_every is len(train_iter) // 2, which is 0
        # when there are fewer than two training batches; `global_step % eval_every`
        # below would then raise ZeroDivisionError.
        valid_running_loss = 0.0
        global_step = 0
        train_loss_list = []
        valid_loss_list = []
        global_steps_list = []

        model.train()

        for epoch in range(num_epochs):
                for (label, text), _ in train_loader:
                        # Cast both tensors to integer type and move them to `device`.
                        label = label.type(torch.LongTensor)
                        label = label.to(device)
                        text = text.type(torch.LongTensor)
                        text = text.to(device)
                        output = model(text, label)
                        # Only the loss is needed for the optimization step.
                        loss, _ = output
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        running_loss += loss.item()
                        global_step += 1
                        if global_step % eval_every == 0:
                                # Periodic validation pass: eval mode, no gradient tracking.
                                model.eval()
                                with torch.no_grad():
                                        for (label, text), _ in valid_loader:
                                                label = label.type(torch.LongTensor)
                                                label = label.to(device)
                                                text = text.type(torch.LongTensor)
                                                text = text.to(device)
                                                output = model(text, label)
                                                loss, _ = output
                                                valid_running_loss += loss.item()

                                # Train loss is averaged over the steps since the last
                                # evaluation; validation loss over all validation batches.
                                average_train_loss = running_loss / eval_every
                                average_valid_loss = valid_running_loss / len(valid_loader)
                                train_loss_list.append(average_train_loss)
                                valid_loss_list.append(average_valid_loss)
                                global_steps_list.append(global_step)


                                # resetting running values
                                running_loss = 0.0
                                valid_running_loss = 0.0
                                model.train()

                                # print progress
                                # NOTE(review): the print call below is cut off in this copy of the file.
                                print('Epoch [{}/{}], Step [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'.format(epoch+1, num_epochs, global_step, num_epochs*len(tra>
                                # NOTE(review): best_valid_loss has no visible binding before this
                                # comparison; presumably it is a parameter in the truncated
                                # signature — confirm. The visible code only records the new best
                                # value; it does not save the model.
                                if best_valid_loss > average_valid_loss:
                                        best_valid_loss = average_valid_loss
        print('Finished Training!')

# Build the classifier on the selected device and fine-tune it with Adam
# (learning rate 2e-5), using train()'s default loaders and epoch count.
model = BERT().to(device)
optimizer = optim.Adam(model.parameters(), lr=2e-5)

train(model=model, optimizer=optimizer)


def evaluate(model, test_loader):
        """Run `model` over `test_loader` and print a classification report.

        Args:
                model: module whose call `model(text, label)` returns a
                        `(loss, logits)` pair, as `BERT.forward` does.
                test_loader: iterable yielding `((label, text), _)` batches.

        Returns:
                None. The sklearn classification report for labels 0, 1 and 2
                is printed to stdout.
        """
        y_pred = []
        y_true = []
        model.eval()
        with torch.no_grad():
                for (label, text), _ in test_loader:
                        label = label.type(torch.LongTensor).to(device)
                        text = text.type(torch.LongTensor).to(device)
                        _, logits = model(text, label)
                        # BertForSequenceClassification returns logits shaped
                        # (batch_size, num_labels), so the class axis is dim 1.
                        # Reducing over dim 2 is out of range for a 2-D tensor.
                        y_pred.extend(torch.argmax(logits, dim=1).tolist())
                        y_true.extend(label.tolist())
        print('Classification Report:')
        print(classification_report(y_true, y_pred, labels=[0,1,2], digits=4))
# Evaluate the model that was fine-tuned by train() above. Constructing a new
# BERT() here would score a classifier whose classification head has never been
# trained, because no saved weights are loaded into it.
best_model = model
evaluate(best_model, test_iter)

Solution

  • You are using `criterion = nn.BCELoss()` (binary cross-entropy) for a multi-class classification problem ("the labels can have three values of (0,1,2)"). Use a loss function suitable for multi-class classification.