Why do I keep running into NaN loss when training the CIBHash model?


Recently I have been trying to reproduce the results of https://github.com/zexuanqiu/CIBHash, but the loss explodes every time right after an evaluation. I am using the CIFAR-10 dataset from the official site.

For instance,

  • When I run python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda (with the default validate_frequency=20), the code evaluates its performance after epoch 20 and then continues training, and the loss explodes in epoch 21.
  • When I run python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda --validate_frequency=3, the loss explosion consistently occurs in epoch 4.
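
Because the explosion starts in the first training epoch after an evaluation, it helps to know which layer first produces a non-finite value. Below is a minimal debugging sketch using standard PyTorch forward hooks; add_nonfinite_hooks is a hypothetical helper, not part of the CIBHash repository. Since run_training_session calls self.forward(...) directly, a hook on the top-level model would not fire, but hooks on submodules such as self.vgg.features, self.vgg.classifier and self.encoder do, because those are invoked through their __call__.

```python
import torch
import torch.nn as nn


def add_nonfinite_hooks(model: nn.Module):
    """Register forward hooks that raise as soon as a submodule outputs NaN or Inf.

    Hypothetical debugging helper. Returns the hook handles so they can be
    removed with handle.remove() once the culprit is found.
    """
    handles = []

    def make_hook(name):
        def hook(module, inputs, output):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for out in outputs:
                if (torch.is_tensor(out) and out.is_floating_point()
                        and not torch.isfinite(out).all()):
                    raise RuntimeError(
                        f"non-finite output from module '{name}' "
                        f"({module.__class__.__name__})")
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root; CIBHash calls self.forward() directly, so it would not fire
            handles.append(module.register_forward_hook(make_hook(name)))
    return handles


if __name__ == "__main__":
    net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 1))
    hooks = add_nonfinite_hooks(net)
    net(torch.randn(2, 4))  # finite input: no error
    try:
        net(torch.full((2, 4), float("nan")))
    except RuntimeError as err:
        print(err)  # non-finite output from module '0' (Linear)
    for h in hooks:
        h.remove()
```

For the backward pass, torch.autograd.set_detect_anomaly(True) serves a similar purpose and points at the forward operation whose gradient became NaN. Both slow training noticeably, so they are only meant for debugging runs.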

Sample output: [screenshot not reproduced here]

Here is the repository's run_training_session function:

import math
import random
from copy import deepcopy

import torch


def run_training_session(self, run_num, logger):
        self.train()
        
        # Scramble hyperparameters if number of runs is greater than 1.
        if self.hparams.num_runs > 1:
            logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
            for hparam, values in self.get_hparams_grid().items():
                assert hasattr(self.hparams, hparam)
                self.hparams.__dict__[hparam] = random.choice(values)
        
        random.seed(self.hparams.seed)
        torch.manual_seed(self.hparams.seed)

        self.define_parameters()

        # if encode_length is 16, then at least 80 epochs!
        if self.hparams.encode_length == 16:
            self.hparams.epochs = max(80, self.hparams.epochs)

        logger.log('hparams: %s' % self.flag_hparams())
        
        device = torch.device('cuda' if self.hparams.cuda else 'cpu')
        self.to(device)

        optimizer = self.configure_optimizers()
        train_loader, val_loader, _, database_loader = self.data.get_loaders(
            self.hparams.batch_size, self.hparams.num_workers,
            shuffle_train=True, get_test=False)
        best_val_perf = float('-inf')
        best_state_dict = None
        bad_epochs = 0

        try:
            for epoch in range(1, self.hparams.epochs + 1):
                forward_sum = {}
                num_steps = 0
                for batch_num, batch in enumerate(train_loader):
                    optimizer.zero_grad()

                    imgi, imgj, _ = batch
                    imgi = imgi.to(device)
                    imgj = imgj.to(device)

                    forward = self.forward(imgi, imgj, device)

                    for key in forward:
                        if key in forward_sum:
                            forward_sum[key] += forward[key]
                        else:
                            forward_sum[key] = forward[key]
                    num_steps += 1

                    if math.isnan(forward_sum['loss']):
                        logger.log('Stopping epoch because loss is NaN')
                        break

                    forward['loss'].backward()
                    optimizer.step()

                if math.isnan(forward_sum['loss']):
                    logger.log('Stopping training session because loss is NaN')
                    break
                
                logger.log('End of epoch {:3d}'.format(epoch), False)
                logger.log(' '.join([' | {:s} {:8.4f}'.format(
                    key, forward_sum[key] / num_steps)
                                     for key in forward_sum]), True)

                if epoch % self.hparams.validate_frequency == 0:
                    print('evaluating...')
                    val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
                    logger.log(' | val perf {:8.4f}'.format(val_perf), False)

                    if val_perf > best_val_perf:
                        best_val_perf = val_perf
                        bad_epochs = 0
                        logger.log('\t\t*Best model so far, deep copying*')
                        best_state_dict = deepcopy(self.state_dict())
                    else:
                        bad_epochs += 1
                        logger.log('\t\tBad epoch %d' % bad_epochs)

                    if bad_epochs > self.hparams.num_bad_epochs:
                        break

        except KeyboardInterrupt:
            logger.log('-' * 89)
            logger.log('Exiting from training early')

        return best_state_dict, best_val_perf
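
Two details of the loop above are worth double-checking while debugging. First, forward_sum accumulates the loss tensors themselves, so the running sums keep references into every batch's autograd graph until the epoch ends. Second, math.isnan only catches NaN, not infinities, and it is applied to the running sum rather than to the current step. A minimal sketch of a replacement helper, assuming forward returns a dict of scalar tensors as in the code here (accumulate_step is a hypothetical name):

```python
import math

import torch


def accumulate_step(forward_sum: dict, forward: dict) -> bool:
    """Add one step's outputs to running totals stored as Python floats.

    .item() copies each scalar off the autograd graph, so the totals no longer
    keep references to every batch's graph. Returns False if any value of this
    step is NaN or +/-Inf (math.isnan alone would miss Inf).
    """
    finite = True
    for key, value in forward.items():
        scalar = value.item() if torch.is_tensor(value) else float(value)
        if not math.isfinite(scalar):
            finite = False
        forward_sum[key] = forward_sum.get(key, 0.0) + scalar
    return finite


if __name__ == "__main__":
    sums = {}
    print(accumulate_step(sums, {"loss": torch.tensor(1.5), "kl_loss": torch.tensor(0.5)}))  # True
    print(accumulate_step(sums, {"loss": torch.tensor(float("inf")), "kl_loss": torch.tensor(0.5)}))  # False
    print(sums["kl_loss"])  # 1.0
```

In the batch loop, a call like this would replace both the manual accumulation and the NaN check, placed before forward['loss'].backward(); the epoch averages stay forward_sum[key] / num_steps.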

And here is the forward function of the CIBHash model:

def forward(self, imgi, imgj, device):
        imgi = self.vgg.features(imgi)
        imgi = imgi.view(imgi.size(0), -1)
        imgi = self.vgg.classifier(imgi)
        prob_i = torch.sigmoid(self.encoder(imgi))
        z_i = hash_layer(prob_i - torch.empty_like(prob_i).uniform_().to(prob_i.device))

        imgj = self.vgg.features(imgj)
        imgj = imgj.view(imgj.size(0), -1)
        imgj = self.vgg.classifier(imgj)
        prob_j = torch.sigmoid(self.encoder(imgj))
        z_j = hash_layer(prob_j - torch.empty_like(prob_j).uniform_().to(prob_j.device))

        kl_loss = (self.compute_kl(prob_i, prob_j) + self.compute_kl(prob_j, prob_i)) / 2
        contra_loss = self.criterion(z_i, z_j, device)
        loss = contra_loss + self.hparams.weight * kl_loss

        return {'loss': loss, 'contra_loss': contra_loss, 'kl_loss': kl_loss}
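
The body of compute_kl is not shown here, so the following is only a hypothesis to rule out. If it computes the KL divergence between Bernoulli codes as p * log(p/q) + (1 - p) * log((1 - p)/(1 - q)) directly on sigmoid outputs, a saturated sigmoid (exactly 0.0 or 1.0 in float32) gives 0 * (-inf) = NaN. A common remedy is to work in log space from the pre-sigmoid logits with logsigmoid. This sketch assumes you keep the encoder logits (self.encoder(imgi)) before torch.sigmoid is applied; bernoulli_kl_from_logits is a hypothetical name, and it computes a plain KL(p || q) without any detach or epsilon the repository's version might use.

```python
import torch
import torch.nn.functional as F


def bernoulli_kl_from_logits(logit_p: torch.Tensor, logit_q: torch.Tensor) -> torch.Tensor:
    """Batch mean of sum_k KL(Bern(p_k) || Bern(q_k)), with p = sigmoid(logit_p).

    log(p) and log(1 - p) come from logsigmoid(x) and logsigmoid(-x), which stay
    finite even when sigmoid(x) rounds to exactly 0 or 1 in float32.
    """
    log_p, log_1mp = F.logsigmoid(logit_p), F.logsigmoid(-logit_p)
    log_q, log_1mq = F.logsigmoid(logit_q), F.logsigmoid(-logit_q)
    p = torch.sigmoid(logit_p)
    kl = p * (log_p - log_q) + (1 - p) * (log_1mp - log_1mq)
    return kl.sum(dim=1).mean()


if __name__ == "__main__":
    a = torch.tensor([[100.0, -3.0]])  # first bit saturated: sigmoid(100.) == 1.0
    b = torch.tensor([[-100.0, 2.0]])
    p = torch.sigmoid(a)
    naive = p * p.log() + (1 - p) * (1 - p).log()  # 0 * log(0) -> NaN
    print(torch.isnan(naive).any().item())  # True
    print(torch.isfinite(bernoulli_kl_from_logits(a, b)).item())  # True
```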

I have tried replacing z_i and z_j as suggested in https://github.com/zexuanqiu/CIBHash/issues/6, but it did not prevent the NaN problem.

I have also tried gradient clipping, but it did not help.
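
For reference, gradient clipping in PyTorch is usually applied between backward() and optimizer.step() with torch.nn.utils.clip_grad_norm_. It can only rescale finite gradients: once the loss is NaN, the gradients that depend on it are NaN too and remain NaN after clipping, which may be why clipping did not help here. The sketch below clips and additionally skips any update whose loss or gradient norm is non-finite; clipped_step and max_norm=5.0 are illustrative choices, not taken from CIBHash.

```python
import math

import torch
import torch.nn as nn


def clipped_step(model: nn.Module, optimizer, loss: torch.Tensor, max_norm: float = 5.0) -> bool:
    """Backward pass with gradient-norm clipping; skip the update if anything is non-finite.

    Returns True if optimizer.step() was taken.
    """
    optimizer.zero_grad()
    if not math.isfinite(loss.item()):
        return False  # NaN/Inf loss: its gradients would be NaN/Inf as well
    loss.backward()
    # Total norm *before* clipping (a float or a tensor, depending on the PyTorch version).
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if not math.isfinite(float(total_norm)):
        optimizer.zero_grad()  # drop the bad gradients and leave the weights untouched
        return False
    optimizer.step()
    return True
```

In the training loop, clipped_step(self, optimizer, forward['loss']) would take the place of forward['loss'].backward() and optimizer.step(). A run of skipped steps right after evaluation would at least confirm that the forward pass, not the optimizer, is what turns non-finite.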

According to the author's reply (https://github.com/zexuanqiu/CIBHash/issues/7), they did not encounter any NaN problem when training the model.

I expect the code to finish the training session without running into NaN loss. Could anyone tell me what might cause this problem, or suggest a potential solution for the NaN loss?


Solution

  • It appears that this problem is caused by insufficient GPU memory combined with an unknown bug in some older CUDA versions, although I cannot confirm this.

    I have tried:

    • Setting Learning Rate to 0: Did not solve the issue.
    • Switching to CPU Execution: The loss explosion did not occur on the CPU, but training was far too slow.
    • Modifying the forward Function: Applied the suggested code modification, but the problem persisted.
    • Upgrading PyTorch and Related Libraries: Upgraded from torch==1.4.0 to a CUDA-suffixed build (torch==1.5.0+cuda102), but the loss explosion occurred even earlier (in the first epoch, before any evaluation).
    • Downgrading CUDA: Attempted a downgrade to CUDA 9.2 and ran into new errors, because my RTX 3070 (8 GB) has compute capability 8.6, while CUDA 9.2 only supports compute capabilities 3.x to 7.x.
    • Changing Data Type Precision to float64: Consumed too much memory, so the batch size had to be reduced to 16.
    • Decreasing Batch Size: A batch size of 2 mitigated the problem, but the model barely improved after evaluation. A batch size of 4 still led to NaN loss.
    • Migrating to Another Environment: Moved to a Colab environment, where training succeeded, which points to a possible GPU memory constraint.
    • Upgrading CUDA to 11.1: Encountered a MemoryError.
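
Since training stays finite on the CPU, and a learning rate of 0 still ends in NaN (which suggests the NaN does not come from the weight updates themselves), one more thing to rule out is a specific GPU kernel. The sketch below switches PyTorch to more conservative kernel choices. The flags are real torch.backends settings, but whether any of them matters for this model is untested; the TF32 switches only exist from PyTorch 1.7 on and only affect Ampere GPUs such as the RTX 3070. use_conservative_gpu_numerics is a hypothetical helper name.

```python
import torch


def use_conservative_gpu_numerics(disable_cudnn: bool = False) -> dict:
    """Select more conservative GPU kernels while hunting for a NaN (debug only).

    Flags missing from the installed PyTorch version are skipped. Returns the
    settings that were applied.
    """
    applied = {}
    torch.backends.cudnn.benchmark = False       # no autotuned conv algorithm choice
    torch.backends.cudnn.deterministic = True    # restrict cuDNN to deterministic algorithms
    applied["cudnn.benchmark"] = False
    applied["cudnn.deterministic"] = True
    if disable_cudnn:
        torch.backends.cudnn.enabled = False     # use PyTorch's own conv kernels instead
        applied["cudnn.enabled"] = False
    if hasattr(torch.backends.cudnn, "allow_tf32"):
        torch.backends.cudnn.allow_tf32 = False  # full FP32 convolutions on Ampere
        applied["cudnn.allow_tf32"] = False
    cuda_backend = getattr(torch.backends, "cuda", None)
    matmul = getattr(cuda_backend, "matmul", None)
    if matmul is not None and hasattr(matmul, "allow_tf32"):
        matmul.allow_tf32 = False                # full FP32 matmuls on Ampere
        applied["cuda.matmul.allow_tf32"] = False
    return applied


if __name__ == "__main__":
    print(use_conservative_gpu_numerics())
```

It would be called once at startup, before the model is built; if the NaN disappears with disable_cudnn=True, cuDNN's convolution kernels on this driver/CUDA combination become the main suspect.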

    Here is the information I have gathered, or tested myself, on how the code behaves on different GPUs:

    My Machine:

    • GPU: NVIDIA GeForce RTX 3070 8G
    • CUDA Version:
      • Known NaN issues occurred with CUDA 10.1;
      • Compute capability compatibility problem in CUDA 9.2;
      • MemoryError problem in CUDA 11.1;
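
Related to the CUDA 9.2 note: the RTX 3070 is an Ampere card with compute capability 8.6, and CUDA first added native sm_86 support in 11.1, so builds against CUDA 10.x (such as the CUDA 10.1 setup listed above) ship no machine code for this architecture and can only run on it through driver JIT compilation of embedded PTX, if any. Whether that is what produces the NaNs is not established, but it is cheap to check what the installed PyTorch build targets. A sketch, with hypothetical function names; torch.cuda.get_arch_list() only exists in newer PyTorch releases, so it is looked up defensively:

```python
import re

import torch


def has_compatible_sass(capability, arch_list):
    """True if arch_list contains a cubin the device can run natively.

    A cubin built for sm_XY runs on devices with the same major version X and a
    minor version >= Y. PTX entries (compute_XX), which the driver can
    JIT-compile, are ignored here.
    """
    major, minor = capability
    for arch in arch_list:
        m = re.fullmatch(r"sm_(\d+)(\d)[a-z]?", arch)
        if m and int(m.group(1)) == major and int(m.group(2)) <= minor:
            return True
    return False


def report_cuda_environment() -> dict:
    info = {
        "torch": torch.__version__,
        "cuda_build": torch.version.cuda,  # None for CPU-only builds
        "cuda_available": torch.cuda.is_available(),
    }
    if info["cuda_available"]:
        cap = torch.cuda.get_device_capability(0)
        info["device"] = torch.cuda.get_device_name(0)
        info["capability"] = cap
        if hasattr(torch.cuda, "get_arch_list"):  # absent in older releases
            arch_list = torch.cuda.get_arch_list()
            info["arch_list"] = arch_list
            info["native_sass"] = has_compatible_sass(cap, arch_list)
    return info


if __name__ == "__main__":
    for key, value in report_cuda_environment().items():
        print(f"{key}: {value}")
```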

    2080Ti Machine:

    • GPU: 8 × NVIDIA GeForce RTX 2080 Ti (11 GB each)
    • CUDA Version: 10.2
    • Driver Version: 440.33.01
    • Success, no NaN problem.

    Colab Environment:

    • GPU: NVIDIA T4 (Tesla T4)
    • VRAM: 15GB
    • CUDA Version: 12+
    • Colab's resource graph showed the model occupying 9.2 GB of GPU RAM.
    • Success, no NaN problem.
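
To compare memory use across machines with numbers from PyTorch itself, the peak can be measured around an epoch or around the evaluate(...) call. Figures from nvidia-smi-style monitors include the CUDA context and the caching allocator's reserved but unused blocks, so they are typically higher than max_memory_allocated(). A sketch, where track_peak_gpu_memory is a hypothetical helper that does nothing on a machine without CUDA:

```python
import contextlib

import torch


@contextlib.contextmanager
def track_peak_gpu_memory(device=None):
    """Yield a dict that is filled with peak GPU memory (GiB) for the enclosed block.

    On a machine without CUDA the dict stays empty.
    """
    stats = {}
    if not torch.cuda.is_available():
        yield stats
        return
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)
    try:
        yield stats
    finally:
        torch.cuda.synchronize(device)
        stats["peak_allocated_gib"] = torch.cuda.max_memory_allocated(device) / 2**30
        stats["peak_reserved_gib"] = torch.cuda.max_memory_reserved(device) / 2**30


if __name__ == "__main__":
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    with track_peak_gpu_memory() as stats:
        x = torch.randn(1024, 1024, device=dev)
        y = x @ x
    print(stats)  # {} on a CPU-only machine
```

Wrapping the evaluation call in run_training_session this way would show whether evaluation pushes the 8 GB card close to its limit just before the first NaN epoch.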

    A40 GPU Environment:

    • GPU: NVIDIA A40
    • VRAM: 48 GB (A40 specification)
    • CUDA Version: Not specified, but known to work fine in this environment.
    • Success, no NaN problem.

    In conclusion, I suspect that older CUDA versions do not handle out-of-memory conditions properly: as reported in https://github.com/ultralytics/ultralytics/issues/5294, CUDA errors such as "out of memory" can also lead to NaN results. However, I have little hard evidence for this. If anyone knows more specific details about this problem, please let me know!
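
If the hypothesis is that an out-of-memory condition goes unreported and only shows up later as NaN, one way to test it is to make CUDA errors surface where they happen. CUDA_LAUNCH_BLOCKING=1 is a real CUDA environment variable that makes kernel launches synchronous, so an error is raised at the offending call instead of a later one; it must be set before CUDA is initialized and it slows training, so it is for debugging only. The OOM check below is a sketch: torch.cuda.OutOfMemoryError only exists in recent PyTorch releases, while older ones raise a plain RuntimeError whose message contains "out of memory". is_cuda_oom and run_and_flag_oom are hypothetical names.

```python
import os

# Set before torch initialises CUDA (safest: before importing torch). Debug only.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")

import torch  # noqa: E402


def is_cuda_oom(exc: BaseException) -> bool:
    """Recognise a CUDA out-of-memory error across PyTorch versions."""
    oom_type = getattr(torch.cuda, "OutOfMemoryError", None)
    if oom_type is not None and isinstance(exc, oom_type):
        return True
    return isinstance(exc, RuntimeError) and "out of memory" in str(exc)


def run_and_flag_oom(fn, *args, **kwargs):
    """Call fn (e.g. an evaluate function); report an OOM explicitly, then re-raise."""
    try:
        return fn(*args, **kwargs)
    except Exception as exc:
        if is_cuda_oom(exc):
            print("CUDA out of memory inside", getattr(fn, "__name__", fn))
            if torch.cuda.is_available():
                torch.cuda.empty_cache()  # release cached blocks before the error propagates
        raise
```

Used as run_and_flag_oom(self.evaluate, database_loader, val_loader, self.data.topK, device), a hidden OOM during evaluation would stop the run with an explicit message instead of letting training continue.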