python, scikit-learn, pytorch, nlp, python-polars

Keep training a PyTorch model on new data


I'm working on a text classification task and have decided to use a PyTorch model for this purpose. The process mainly involves the following steps:

  1. Load and process the text.
  2. Use a TF-IDF Vectorizer.
  3. Build the neural network and save the TF-IDF Vectorizer and model to predict new data.

However, every day I need to classify new comments and correct any wrong classifications.

Currently, my approach is to add the new comments, with their corrected classifications, to the dataset and retrain the entire model. This is time-consuming, and the new comments can end up in the validation split instead of being trained on. I would like to create a new dataset containing only the newly classified texts and continue training on that data (the new comments are classified manually, so every label is correct).

Using GPT and some code I found online, I wrote the process I want, but I'm not sure whether it works as expected or whether I'm making some silly mistakes that shouldn't happen.

So the main questions are:

  1. How can I check whether the proposed way of solving this problem works as I expect?
  2. What should I do with the vectorizer when it faces new tokens? Can I just call .fit_transform(), or would I lose the original vectorizer? (A sketch of the two options I'm weighing is right below.)
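
For example, this is the part I'm unsure about when the new comments contain words the vectorizer has never seen. A rough sketch of the two options I'm weighing (the pickle file is the one saved by the training script below):

import joblib

vectorizer = joblib.load("tfidf_vectorizer.pkl")
new_texts = ["a freshly labelled comment with some brand-new words"]

# option A: reuse the fitted vocabulary; words unseen at fit time are ignored
X_new = vectorizer.transform(new_texts)

# option B: refit on the new texts only; this rebuilds the vocabulary,
# so the feature indices no longer match what the model was trained on
X_new_refit = vectorizer.fit_transform(new_texts)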

Here is the full training process:

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader, random_split
from sklearn.preprocessing import LabelEncoder
import polars as pl
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
import joblib

set1 = (
    pl
    .read_csv(
        "set1.txt",
        separator=";",
        has_header=False,
        new_columns=["text","label"]
    )
)

# since the dataset is unbalanced, I'm going to downsample the larger classes to balance it

fear_df = set1.filter(pl.col("label") == "fear")
joy_df = set1.filter(pl.col("label") == "joy").sample(n=2500)
sadness_df = set1.filter(pl.col("label") == "sadness").sample(n=2500)
anger_df = set1.filter(pl.col("label") == "anger")

train_df = pl.concat([fear_df,joy_df,sadness_df,anger_df])

"""
The text its already clean, so im going to change the labels to numeric
and then split it on train, test ,val
"""

label_mapping = {
    "anger": 0,
    "fear": 1,
    "joy": 2,
    "sadness": 3
}

train_mapped = (
    train_df
    .with_columns(
        pl.col("label").replace_strict(label_mapping, default="other").cast(pl.Int16)
    )
   
)

train_set, pre_Test = train_test_split(train_mapped,
                                    test_size=0.4,
                                    random_state=42,
                                    stratify=train_mapped["label"])

test_set, val_set = train_test_split(pre_Test,
                                    test_size=0.5,
                                    random_state=42,
                                    stratify=pre_Test["label"]) 

# Vectorize text data using TF-IDF
vectorizer = TfidfVectorizer(max_features=30000, ngram_range=(1, 2))

X_train_tfidf = vectorizer.fit_transform(train_set['text']).toarray()
X_val_tfidf = vectorizer.transform(val_set['text']).toarray()
X_test_tfidf = vectorizer.transform(test_set['text']).toarray()

y_train = train_set['label']
y_val = val_set['label']
y_test = test_set['label']

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        return text, label
    
train_dataset = TextDataset(X_train_tfidf, y_train)
val_dataset = TextDataset(X_val_tfidf, y_val)
test_dataset = TextDataset(X_test_tfidf, y_test)

batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

class TextClassificationModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(TextClassificationModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 32)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        # return raw logits; CrossEntropyLoss applies log-softmax internally
        return self.fc3(x)
    
input_dim = X_train_tfidf.shape[1]
model = TextClassificationModel(input_dim, 4)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(model.parameters())

# Training loop
num_epochs = 17
best_val_acc = 0.0
best_model_path = "modelbest.pth"

for epoch in range(num_epochs):
    model.train()
    for texts, labels in train_loader:
        texts, labels = texts.float(), labels.long()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Validation
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for texts, labels in val_loader:
            texts, labels = texts.float(), labels.long()
            outputs = model(texts)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    val_acc = correct / total
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), best_model_path)

    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Val Acc: {val_acc:.4f}')

# Load the best model
model.load_state_dict(torch.load(best_model_path))

# Test the model
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for texts, labels in test_loader:
        texts, labels = texts.float(), labels.long()
        outputs = model(texts)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
test_acc = correct / total
print(f'Test Acc: {test_acc:.3f}')


# Save the TF-IDF vectorizer
vectorizer_path = "tfidf_vectorizer.pkl"
joblib.dump(vectorizer, vectorizer_path)

# Save the PyTorch model
model_path = "text_classification_model.pth"
torch.save(model.state_dict(), model_path)

Proposed code:

import torch
import joblib
import polars as pl
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Load the saved TF-IDF vectorizer
vectorizer_path = "tfidf_vectorizer.pkl"
vectorizer = joblib.load(vectorizer_path)

input_dim = len(vectorizer.get_feature_names_out())

class TextClassificationModel(nn.Module):
    def __init__(self, input_dim, num_classes):
        super(TextClassificationModel, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.dropout1 = nn.Dropout(0.5)
        self.fc2 = nn.Linear(64, 32)
        self.dropout2 = nn.Dropout(0.5)
        self.fc3 = nn.Linear(32, num_classes)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        # return raw logits; CrossEntropyLoss applies log-softmax internally
        return self.fc3(x)
    
# Load the saved PyTorch model
model_path = "text_classification_model.pth"
model = TextClassificationModel(input_dim, 4)
model.load_state_dict(torch.load(model_path))

# Map labels to numeric values
label_mapping = {"anger": 0, "fear": 1, "joy": 2, "sadness": 3}
sentiments = ["fear","joy","sadness","anger"]

new_data = (
    pl
    .read_csv(
        "set2.txt",
        separator=";",
        has_header=False,
        new_columns=["text","label"]
    )
    .filter(pl.col("label").is_in(sentiments))
    .with_columns(
        pl.col("label").replace_strict(label_mapping, default="other").cast(pl.Int16)
    )
    
)
# Vectorize the new text data using the loaded TF-IDF vectorizer
X_new = vectorizer.transform(new_data['text']).toarray()
y_new = new_data['label']

class TextDataset(Dataset):
    def __init__(self, texts, labels):
        self.texts = texts
        self.labels = labels
    
    def __len__(self):
        return len(self.texts)
    
    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        return text, label

batch_size = 10
   
# Create DataLoader for the new training data
new_train_dataset = TextDataset(X_new, y_new)
new_train_loader = DataLoader(new_train_dataset, batch_size=batch_size, shuffle=True)

# Define loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adamax(model.parameters())

num_epochs = 5
new_best_model_path = "modelbest.pth"
for epoch in range(num_epochs):
    model.train()
    for texts, labels in new_train_loader:
        texts, labels = texts.float(), labels.long()
        outputs = model(texts)
        loss = criterion(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # save a checkpoint and report the loss once per epoch, not after every batch
    torch.save(model.state_dict(), new_best_model_path)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

# Save the PyTorch model
new_best_model_path = "new_model.pth"
torch.save(model.state_dict(), new_best_model_path)
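
To make question 1 concrete, this is roughly how I was thinking of checking it (a sketch; it assumes I also save the original test split during training, e.g. with test_set.write_parquet("test_set.parquet"), which the training script above does not do yet): evaluate on the old held-out test set before and after the extra epochs, so I can see whether training on the new comments hurts performance on the old data.

import torch
import joblib
import polars as pl

# hypothetical file: the original test split saved during training
old_test = pl.read_parquet("test_set.parquet")

vectorizer = joblib.load("tfidf_vectorizer.pkl")
X_old = torch.tensor(vectorizer.transform(old_test["text"]).toarray()).float()
y_old = torch.tensor(old_test["label"].to_list()).long()

def accuracy_on(model, X, y):
    model.eval()
    with torch.no_grad():
        preds = model(X).argmax(dim=1)
    return (preds == y).float().mean().item()

# call once before and once after the extra epochs;
# a large drop on the old test set would mean the model is forgetting
print(f"Accuracy on original test set: {accuracy_on(model, X_old, y_old):.4f}")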

The dataset can be found here


Solution

  • Use a pre-trained transformer model such as BertForSequenceClassification. Its subword tokenizer and pre-trained embeddings handle unseen tokens more gracefully, since words are mapped to continuous vectors based on semantic meaning, which reduces the impact of words that were not in the original training data.

    Model Training with BERT

    import torch
    from torch import nn, optim
    from torch.utils.data import DataLoader, Dataset
    from transformers import BertTokenizer, BertModel, BertForSequenceClassification
    from transformers import Trainer, TrainingArguments
    from sklearn.model_selection import train_test_split
    import polars as pl
    
    # Load and prepare data
    set1 = pl.read_csv("set1.txt", separator=";", has_header=False, new_columns=["text", "label"])
    
    # Balance dataset
    fear_df = set1.filter(pl.col("label") == "fear")
    joy_df = set1.filter(pl.col("label") == "joy").sample(n=2500)
    sadness_df = set1.filter(pl.col("label") == "sadness").sample(n=2500)
    anger_df = set1.filter(pl.col("label") == "anger")
    train_df = pl.concat([fear_df, joy_df, sadness_df, anger_df])
    
    label_mapping = {"anger": 0, "fear": 1, "joy": 2, "sadness": 3}
    train_df = train_df.with_columns(pl.col("label").replace_strict(label_mapping, default="other").cast(pl.Int16))
    
    # Split dataset
    train_set, test_val_set = train_test_split(train_df, test_size=0.4, random_state=42, stratify=train_df["label"])
    test_set, val_set = train_test_split(test_val_set, test_size=0.5, random_state=42, stratify=test_val_set["label"])
    
    # Dataset class
    class TextDataset(Dataset):
        def __init__(self, texts, labels, tokenizer, max_length=128):
            self.texts = texts
            self.labels = labels
            self.tokenizer = tokenizer
            self.max_length = max_length
    
        def __len__(self):
            return len(self.texts)
    
        def __getitem__(self, idx):
            text = self.texts[idx]
            label = self.labels[idx]
            encoding = self.tokenizer.encode_plus(
                text,
                add_special_tokens=True,
                max_length=self.max_length,
                padding='max_length',
                truncation=True,
                return_tensors='pt'
            )
            return {
                'input_ids': encoding['input_ids'].flatten(),
                'attention_mask': encoding['attention_mask'].flatten(),
                'labels': torch.tensor(label, dtype=torch.long)
            }
    
    # Initialize tokenizer and datasets
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    train_dataset = TextDataset(train_set['text'], train_set['label'], tokenizer)
    val_dataset = TextDataset(val_set['text'], val_set['label'], tokenizer)
    test_dataset = TextDataset(test_set['text'], test_set['label'], tokenizer)
    
    # Initialize BERT model for classification
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=4)
    
    # Training arguments
    training_args = TrainingArguments(
        output_dir='./results',
        num_train_epochs=3,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        evaluation_strategy='epoch',
        save_strategy='epoch',
        logging_dir='./logs',
        learning_rate=2e-5,
        load_best_model_at_end=True
    )
    
    # Accuracy metric so that trainer.evaluate() reports eval_accuracy
    import numpy as np
    def compute_metrics(eval_pred):
        logits, labels = eval_pred
        preds = np.argmax(logits, axis=-1)
        return {"accuracy": float((preds == labels).mean())}

    # Define Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics
    )
    
    # Train model
    trainer.train()
    
    # Evaluate model
    results = trainer.evaluate(test_dataset)
    print(f"Test Accuracy: {results['eval_accuracy']:.4f}")
    
    # Save the model and tokenizer
    model.save_pretrained("saved_model")
    tokenizer.save_pretrained("saved_tokenizer")
    

    Incremental training with least effort

    # Load the saved model and tokenizer
    model = BertForSequenceClassification.from_pretrained("saved_model")
    tokenizer = BertTokenizer.from_pretrained("saved_tokenizer")
    
    # Load new data
    new_data = (
        pl.read_csv("set2.txt", separator=";", has_header=False, new_columns=["text", "label"])
        .filter(pl.col("label").is_in(["fear", "joy", "sadness", "anger"]))
        .with_columns(pl.col("label").replace_strict(label_mapping, default="other").cast(pl.Int16))
    )
    
    # Create new dataset
    new_dataset = TextDataset(new_data['text'], new_data['label'], tokenizer)
    
    # Update training arguments for incremental training
    new_training_args = TrainingArguments(
        output_dir='./results_incremental',
        num_train_epochs=2,  # Fewer epochs since it's incremental
        per_device_train_batch_size=16,
        evaluation_strategy='epoch',
        save_strategy='epoch',  # must match evaluation_strategy when load_best_model_at_end=True
        logging_dir='./logs_incremental',
        learning_rate=2e-5,
        load_best_model_at_end=True
    )
    
    # Define new trainer (reuse compute_metrics so eval_accuracy is reported again)
    new_trainer = Trainer(
        model=model,
        args=new_training_args,
        train_dataset=new_dataset,
        eval_dataset=val_dataset,  # Validate on previous validation set
        compute_metrics=compute_metrics
    )
    
    # Train on new data
    new_trainer.train()
    
    # Evaluate after retraining
    new_results = new_trainer.evaluate(test_dataset)
    print(f"Test Accuracy After Incremental Training: {new_results['eval_accuracy']:.4f}")
    
    # Save the updated model
    model.save_pretrained("saved_model_incremental")
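
  • If you decide to stay with the TF-IDF + feed-forward setup instead: never call .fit_transform() on the new comments. Refitting rebuilds the vocabulary, so the feature indices no longer line up with the input layer of the saved model. Keep using .transform() with the saved vectorizer (tokens it has never seen are simply ignored), and if the vocabulary really has to grow, refit on old + new data and retrain the network from scratch. A minimal sketch along those lines, reusing the file names and the new_data frame from your scripts:

    import joblib

    vectorizer = joblib.load("tfidf_vectorizer.pkl")
    new_texts = new_data["text"]  # the manually labelled comments

    # reuse the fitted vocabulary: the feature count stays fixed,
    # so the tensor still matches the model's first Linear layer
    X_new = vectorizer.transform(new_texts).toarray()
    assert X_new.shape[1] == len(vectorizer.get_feature_names_out())

    # do NOT do this for incremental training: it refits the vectorizer
    # on the new texts only and changes the vocabulary / feature order
    # X_new = vectorizer.fit_transform(new_texts).toarray()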