Search code examples
pythonpytorchlstmmulticlass-classification

PyTorch LSTM for multiclass classification: TypeError: '<' not supported between instances of 'Example' and 'Example'


I am trying to modify the code in this Tutorial to adapt it to a multiclass data (I have 55 distinct classes). An error is triggered and I am uncertain of the root cause. The changes I made to this tutorial have been annotated in same-line comments.

One of two solutions would satisfy this questions:

(A) Help identifying the root cause of the error, OR

(B) A boilerplate script for multiclass classification using PyTorch LSTM

import spacy
import torchtext
from torchtext import data
import re

TEXT = data.Field(tokenize = 'spacy', include_lengths = True)
LABEL = data.LabelField(dtype = torch.float)
fields = [(None,None),('text', TEXT), ('wage_label', LABEL)]

train_torch, test_torch = data.TabularDataset.splits(path='/Users/jdmoore7/Desktop/Python Projects/560_capstone/', 
                                            format='csv', 
                                            train='train_text_target.csv', 
                                            test='test_text_target.csv', 
                                            fields=fields,
                                            skip_header=True)


import random
train_data, valid_data = train_torch.split(random_state = random.seed(SEED)) 

MAX_VOCAB_SIZE = 25_000

TEXT.build_vocab(train_data, 
                 max_size = MAX_VOCAB_SIZE, 
                 vectors = "glove.6B.100d", 
                 unk_init = torch.Tensor.normal_)

LABEL.build_vocab(train_data)

BATCH_SIZE = 64

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
    (train_data, valid_data, test_torch), 
    batch_size = BATCH_SIZE,
    sort_within_batch = True,
    device = device)

import torch.nn as nn

class RNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, output_dim, n_layers, 
                 bidirectional, dropout, pad_idx):

        super().__init__()

        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx = pad_idx)

        self.rnn = nn.LSTM(embedding_dim, 
                           hidden_dim, 
                           num_layers=n_layers, 
                           bidirectional=bidirectional, 
                           dropout=dropout)

        self.fc = nn.Linear(hidden_dim * 2, output_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, text, text_lengths):

        #text = [sent len, batch size]

        embedded = self.dropout(self.embedding(text))

        #embedded = [sent len, batch size, emb dim]

        #pack sequence
        packed_embedded = nn.utils.rnn.pack_padded_sequence(embedded, text_lengths)

        packed_output, (hidden, cell) = self.rnn(packed_embedded)

        #unpack sequence
        output, output_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)

        #output = [sent len, batch size, hid dim * num directions]
        #output over padding tokens are zero tensors

        #hidden = [num layers * num directions, batch size, hid dim]
        #cell = [num layers * num directions, batch size, hid dim]

        #concat the final forward (hidden[-2,:,:]) and backward (hidden[-1,:,:]) hidden layers
        #and apply dropout

        hidden = self.dropout(torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1))

        #hidden = [batch size, hid dim * num directions]

        return self.fc(hidden)    

INPUT_DIM = len(TEXT.vocab)
EMBEDDING_DIM = 100
HIDDEN_DIM = 256
OUTPUT_DIM = len(LABEL.vocab) ### changed from previous value (1)
N_LAYERS = 2
BIDIRECTIONAL = True
DROPOUT = 0.5
PAD_IDX = TEXT.vocab.stoi[TEXT.pad_token]

model = RNN(INPUT_DIM, 
            EMBEDDING_DIM, 
            HIDDEN_DIM, 
            OUTPUT_DIM, 
            N_LAYERS, 
            BIDIRECTIONAL, 
            DROPOUT, 
            PAD_IDX)

import torch.optim as optim
optimizer = optim.Adam(model.parameters())


criterion = nn.CrossEntropyLoss() # Previously: criterion = nn.BCEWithLogitsLoss()
model = model.to(device)
criterion = criterion.to(device)

def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """

    #round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float() #convert into float for division 
    acc = correct.sum() / len(correct)
    return acc
def train(model, iterator, optimizer, criterion):

    epoch_loss = 0
    epoch_acc = 0

    model.train()

    for batch in iterator:

        optimizer.zero_grad()

        text, text_lengths = batch.text

        predictions = model(text, text_lengths).squeeze(1)

        loss = criterion(predictions, batch.label)

        acc = binary_accuracy(predictions, batch.label)

        loss.backward()

        optimizer.step()

        epoch_loss += loss.item()
        epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

def evaluate(model, iterator, criterion):

    epoch_loss = 0
    epoch_acc = 0

    model.eval()

    with torch.no_grad():

        for batch in iterator:

            text, text_lengths = batch.text

            predictions = model(text, text_lengths).squeeze(1)

            loss = criterion(predictions, batch.label)

            acc = binary_accuracy(predictions, batch.label)

            epoch_loss += loss.item()
            epoch_acc += acc.item()

    return epoch_loss / len(iterator), epoch_acc / len(iterator)

import time

def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

All the above ran smoothly, it's the next code block which triggers the error:

N_EPOCHS = 5

best_valid_loss = float('inf')

for epoch in range(N_EPOCHS):

    start_time = time.time()

    train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
    valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)

    end_time = time.time()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)

    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), 'tut2-model.pt')

    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc: {train_acc*100:.2f}%')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. Acc: {valid_acc*100:.2f}%')
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-888-c1b298b1eeea> in <module>
      7     start_time = time.time()
      8 
----> 9     train_loss, train_acc = train(model, train_iterator, optimizer, criterion)
     10     valid_loss, valid_acc = evaluate(model, valid_iterator, criterion)
     11 

<ipython-input-885-9a57198441ec> in train(model, iterator, optimizer, criterion)
      6     model.train()
      7 
----> 8     for batch in iterator:
      9 
     10         optimizer.zero_grad()

~/opt/anaconda3/lib/python3.7/site-packages/torchtext/data/iterator.py in __iter__(self)
    140         while True:
    141             self.init_epoch()
--> 142             for idx, minibatch in enumerate(self.batches):
    143                 # fast-forward if loaded from state
    144                 if self._iterations_this_epoch > idx:

~/opt/anaconda3/lib/python3.7/site-packages/torchtext/data/iterator.py in pool(data, batch_size, key, batch_size_fn, random_shuffler, shuffle, sort_within_batch)
    284     for p in batch(data, batch_size * 100, batch_size_fn):
    285         p_batch = batch(sorted(p, key=key), batch_size, batch_size_fn) \
--> 286             if sort_within_batch \
    287             else batch(p, batch_size, batch_size_fn)
    288         if shuffle:

TypeError: '<' not supported between instances of 'Example' and 'Example'

Lastly, the PyTorch forum has an issue opened for this error, however, the code that produced it is not similar so I understand that to be a separate issue.


Solution

  • The BucketIterator sorts the data to make batches with examples of similar length to avoid having too much padding. For that it needs to know what the sorting criterion is, which should be the text length. Since it is not fixed to a specific data layout, you can freely choose which field it should use, but that also means you must provide that information to sort_key.

    In your case, there are two possible fields, text and wage_label, and you want to sort it based on the length of the text.

    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_torch), 
        batch_size = BATCH_SIZE,
        sort_within_batch = True,
        sort_key = lambda x: len(x.text),
        device = device)
    

    You might be wondering why it worked in the tutorial but doesn't in your example. The reason is that if sort_key is not specified, it defers it to the underlying dataset. In the tutorial they used the IMDB dataset, which defines the sort_key to be x.text. Your custom dataset did not define that, so you need to specify it manually.