Search code examples
Tags: python, deep-learning, pytorch, classification

Loss remains unchanged in a multiclass classification deep model


The loss and accuracy metrics of my classification model remain unchanged unless I make the model produce more logits than there are actual classes.

I have checked several other questions here regarding unchanged or non-decreasing losses on different model architectures but none of them seem to apply to my specific issue.

The problem I'm currently facing is that the changes in my loss (which is CrossEntropyLoss(), by the way) are nearly non-existent (on the order of 0.0001) unless the model produces at least N_c + 1 logits in its output, where N_c is the number of classes in my dataset. I say "at least" because it works with literally any value N >= N_c + 1.

More specifically, I am attempting to implement the VTCNN2 model for signal classification on the RML2016.10a dataset (the implementation is largely borrowed from this repository, with minor modifications to change the dataset from 10b to 10a and the number of classes from 10 to 11).

This dataset has a total of 11 different modulations (read: "classes"), and each signal is labelled with one of them. However, if I initialize my VTCNN2 with 11 classes, I run into the issue I've explained, whereas if I set the number of classes to 12 or above (without touching anything else), my model manages to learn.

I'm using Poutyne to avoid typing out the boilerplate training code and to be sure I'm not doing anything silly (like forgetting to zero out the gradients before each optimization step and the like).

I would like to not clutter my question with snippets of my code but if I need to provide any specifics I'd be more than happy to update it.

EDIT: Here are the different pieces of my program for reference.

data.py (this assumes that the file RML2016.10a_dict.dat, downloaded from here, resides next to this script)

import gc
import itertools
import os
import pickle
from typing import Tuple

import numpy as np
import torch
from tqdm import tqdm

# Modulation types, mapped to the integer class label used as the training target.
MODULATIONS = {
    "8PSK": 0,
    "AM-DSB": 1,
    "AM-SSB": 2,
    "BPSK": 3,
    "CPFSK": 4,
    "GFSK": 5,
    "PAM4": 6,
    "QAM16": 7,
    "QAM64": 8,
    "QPSK": 9,
    "WBFM": 10,
}

# Signal-to-Noise Ratios; together with MODULATIONS these form the pickle's keys.
SNRS = [-20, -18, -16, -14, -12, -10, -8, -6, -4, -2, 0, 2, 4, 6, 8, 10, 12, 14, 16, 18]


class RadioML2016(torch.utils.data.Dataset):
    """Map-style dataset over the RML2016.10a pickle.

    On construction every ``data[(modulation, snr)]`` array from the pickle is
    stacked into ``self.X``, and ``self.y`` holds the matching
    ``(modulation, snr)`` tuple for every row of ``self.X``. Items are returned
    as ``(x, label)``, where ``label`` is the integer from ``MODULATIONS``.
    """

    URL = "."
    modulations = MODULATIONS
    snrs = SNRS

    def __init__(self, data_dir: str = ".", file_name: str = "RML2016.10a_dict.dat"):
        self.file_name = file_name
        self.data_dir = data_dir
        self.n_classes = len(self.modulations)
        self.X, self.y = self.load_data()
        # Presumably here to release the raw pickle dict promptly once
        # load_data has returned and dropped its reference to it.
        gc.collect()

    def load_data(self) -> Tuple[np.ndarray, list]:
        """Read the pickle at ``data_dir/file_name`` and flatten it.

        Returns:
            ``(X, y)``: ``X`` is every per-key array stacked along axis 0, and
            ``y`` is a list with one ``(modulation, snr)`` tuple per row of ``X``.
        """
        print("Loading dataset from file...")
        path = os.path.join(self.data_dir, self.file_name)
        # SECURITY: pickle.load can execute arbitrary code; only open trusted files.
        with open(path, "rb") as f:
            data = pickle.load(f, encoding="latin1")

        print("Processing dataset")
        keys = tqdm(list(itertools.product(self.modulations, self.snrs)))
        return self._flatten(data, keys)

    @staticmethod
    def _flatten(data, keys) -> Tuple[np.ndarray, list]:
        """Stack ``data[key]`` for each key in order and label every row with its key."""
        X, y = [], []
        for key in keys:
            samples = data[key]
            X.append(samples)
            # One (modulation, snr) label per row of this key's array.
            y.extend([key] * samples.shape[0])
        return np.vstack(X), y

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return a single sample ``(x, label)``.

        ``x`` is ``self.X[idx]`` converted to a float32 tensor with a new leading
        dimension of size 1; ``label`` is the int64 class index of its modulation.
        """
        x, (mod, _snr) = self.X[idx], self.y[idx]
        y = self.modulations[mod]
        x, y = torch.from_numpy(x), torch.tensor(y, dtype=torch.long)
        x = x.to(torch.float).unsqueeze(0)
        return x, y

    def __len__(self) -> int:
        return self.X.shape[0]

models.py

import torch
import torch.nn as nn


class VT_CNN2(torch.nn.Module):
    """VT-CNN2 classifier: two zero-padded conv layers followed by two dense layers.

    ``forward`` returns raw logits of shape ``(batch, n_classes)``. No softmax is
    applied, so the output can be fed directly to ``nn.CrossEntropyLoss``.

    The first ``Linear`` layer hard-codes ``in_features=10560``. Working the
    shapes through the layers gives 10560 = 80 channels * 1 row * 132 columns,
    which corresponds to an input of shape ``(batch, 1, 2, 128)``. Other input
    sizes will generally fail at that layer.

    Args:
        n_classes: number of output logits.
        dropout: dropout probability applied after every hidden layer.
        device: stored on ``self.device``. The constructor does not move the
            module there.
    """

    def __init__(
        self,
        n_classes: int = 10,
        dropout: float = 0.5,
        device: str = "cuda" if torch.cuda.is_available() else "cpu",
    ):
        super(VT_CNN2, self).__init__()

        self.device = device
        # Not used inside this class; kept for callers that may read it.
        self.loss = nn.CrossEntropyLoss()

        self.model = nn.Sequential(
            # Pad the last (time) axis by 2 on each side, leaving the other axis untouched.
            nn.ZeroPad2d(padding=(2, 2, 0, 0)),
            nn.Conv2d(
                in_channels=1, out_channels=256, kernel_size=(1, 3), stride=1, padding=0
            ),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.ZeroPad2d(padding=(2, 2, 0, 0)),
            # A kernel height of 2 collapses a 2-row input to a single row.
            nn.Conv2d(
                in_channels=256,
                out_channels=80,
                kernel_size=(2, 3),
                stride=1,
                padding=0,
            ),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Flatten(),
            nn.Linear(in_features=10560, out_features=256, bias=True),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(in_features=256, out_features=n_classes, bias=True),
        )

    def forward(self, x):
        """Return the raw logits for ``x``."""
        return self.model(x)

    @torch.no_grad()
    def predict(self, x: torch.Tensor):
        """Return the arg-max class index for each row of ``x`` as a NumPy array.

        The module is switched to eval mode for the call, so dropout is disabled
        and predictions are deterministic. Its previous train/eval mode is then
        restored. ``x`` is moved to the device holding the model's parameters
        rather than ``self.device``, which can be stale if the module was moved
        after construction.
        """
        was_training = self.training
        self.eval()
        try:
            param_device = next(self.parameters()).device
            logits = self.model(x.to(param_device))
        finally:
            self.train(was_training)
        # Softmax is monotonic, so arg-max over the raw logits picks the same class.
        return logits.argmax(dim=-1).cpu().numpy()

train.ipynb (convert to a .py file with nbconvert)

#!/usr/bin/env python
# coding: utf-8

# In[1]:


import os
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from poutyne.framework import Model
from poutyne.framework.callbacks import ModelCheckpoint
from torchmetrics.classification.accuracy import MulticlassAccuracy
from models import *
from data import *


# In[2]:


# Seed the RNGs the run touches so the split and initialisation are reproducible.
np.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Fall back to CPU instead of crashing on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


# In[3]:


# Load the data first so the classifier size can be taken from it.
dataset = RadioML2016()

# NOTE(review): this used to be hard-coded to 12 as a workaround for a stalled
# loss, which leaves one logit with no matching label and skews the top-k
# metrics. Size the output to the real label space. If the loss stalls at this
# size, look at the optimisation setup (e.g. initialisation, learning rate,
# shuffling) rather than padding the output layer.
NUM_CLASSES = dataset.n_classes


# In[4]:


net = VT_CNN2(
    n_classes=NUM_CLASSES,
    dropout=0.2,
)


# In[5]:


# 50/50 train/validation split.
total = len(dataset)
n_train = total // 2
lengths = [n_train, total - n_train]
print("Splitting into {} train and {} val".format(lengths[0], lengths[1]))
train_set, val_set = random_split(dataset, lengths)


# In[6]:


# Reshuffle the training set every epoch; validation order does not matter.
train_dataloader = DataLoader(train_set, batch_size=512, shuffle=True)
val_dataloader = DataLoader(val_set, batch_size=512)


# In[7]:


os.makedirs("models", exist_ok=True)
checkpoint = ModelCheckpoint(
    filename=os.path.join("models", "vtcnn2.pt"),
    monitor="val_loss",
    save_best_only=True,
)
callbacks = [checkpoint]


# In[8]:


top3 = MulticlassAccuracy(num_classes=NUM_CLASSES, top_k=3)
top5 = MulticlassAccuracy(num_classes=NUM_CLASSES, top_k=5)
metrics = ["acc", top3, top5]


# In[9]:


model = Model(
    network=net,
    optimizer="AdamW",
    loss_function=nn.CrossEntropyLoss(),
    batch_metrics=metrics,
)


# In[10]:


model.to(DEVICE)
model.fit_generator(
    train_dataloader,
    val_dataloader,
    epochs=100,
    callbacks=callbacks,
)

Solution

  • Apparently the PyTorch port of the exact same TensorFlow model results in vanishing gradients unless bias is set to False in the first convolutional layer.

    I have no idea why setting bias=False alleviates the vanishing gradients, but it seems to have solved the issue.