
My training function ends unexpectedly after only a few steps


I am trying to run Pix2Pix, however my training function suddenly stops during the first 1k steps with no errors. I have used PyTorch for creating the discriminator and the generator. Below is the code with 2 functions responsible for training, one for training each step and one for fitting the model.

Training Step Function:

def train_step(input_image, target, step):
    generator.train()
    discriminator.train()

    # Forward pass
    gen_output = generator(input_image)

    disc_real_output = discriminator(input_image, target)
    disc_generated_output = discriminator(input_image, gen_output)

    # Compute losses
    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(
        disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    # Backward pass
    generator_optimizer.zero_grad()
    discriminator_optimizer.zero_grad()

    gen_total_loss.backward(retain_graph=True)
    # Clear the generator gradients for the discriminator backward pass
    discriminator_optimizer.zero_grad()
    disc_loss.backward()

    # Update weights
    generator_optimizer.step()
    discriminator_optimizer.step()

    # Logging
    with torch.no_grad():
        writer.add_scalar('gen_total_loss', gen_total_loss.item(), global_step=step // 1000)
        writer.add_scalar('gen_gan_loss', gen_gan_loss.item(), global_step=step // 1000)
        writer.add_scalar('gen_l1_loss', gen_l1_loss.item(), global_step=step // 1000)
        writer.add_scalar('disc_loss', disc_loss.item(), global_step=step // 1000)

Fitting Function:

def fit(train_loader, test_loader, steps):
    example_target, example_input = next(iter(test_loader))
    start = time.time()

    for step, (target, input_image) in enumerate(train_loader):
        if step % 1000 == 0:
            display.clear_output(wait=True)

            if step != 0:
                print(f'Time taken for 1000 steps: {time.time()-start:.2f} sec\n')

            start = time.time()

            generate_images(generator, example_input, example_target)
            print(f"Step: {step//1000}k")

        # Training step
        train_step(input_image, target, step)

        if (step + 1) % 10 == 0:
            print('.', end='', flush=True)

        # Save (checkpoint) the model every 5k steps
        if (step + 1) % 5000 == 0:
            torch.save({
                'generator_state_dict': generator.state_dict(),
                'discriminator_state_dict': discriminator.state_dict(),
                'generator_optimizer_state_dict': generator_optimizer.state_dict(),
                'discriminator_optimizer_state_dict': discriminator_optimizer.state_dict(),
            }, f'checkpoint_{step + 1}.pt')

I am new to using GANs and I am not sure what the issue is here. I have tried to check if there is any exception that occurs during the training loop and print it but nothing is printed.


Solution

  • The problem is with your for loop that iterates over training data:

    for step, (target, input_image) in enumerate(train_loader):
    

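    The way it's written, it iterates over the data in train_loader exactly once (a single epoch) and then stops. The loop ends normally when the loader is exhausted, which is why no exception is raised. If train_loader yields fewer than 1000 batches, that happens within the first 1k steps. Note also that the steps argument of fit is never used.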

    Instead, you want something like:

    total_steps = 0
    max_steps = 1000000  # some large value
    while total_steps < max_steps:
        for target, input_image in train_loader:
            # do something, e.g. train_step(input_image, target, total_steps)
            total_steps += 1
            if total_steps >= max_steps:
                break
    

    This starts a new pass over train_loader whenever it is exhausted and terminates after exactly max_steps training steps.
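
    The same idea can be packaged so that the training loop is driven by a step count rather than by the length of the loader. The following is only a minimal sketch: infinite_batches and run_steps are hypothetical helper names, not part of PyTorch, and a plain list stands in for train_loader so that the snippet runs without a dataset. It assumes the loader can be iterated more than once, which is true of a DataLoader.

```python
from itertools import islice


def infinite_batches(loader):
    """Yield batches forever, starting a new pass each time the loader runs out.

    Hypothetical helper. Re-iterating the loader on every pass (rather than
    caching batches) lets a shuffling loader reshuffle each epoch.
    """
    while True:
        yielded = False
        for batch in loader:
            yielded = True
            yield batch
        if not yielded:
            # Guard against looping forever on an empty loader.
            raise ValueError("loader produced no batches")


def run_steps(loader, max_steps, train_step):
    """Call train_step exactly max_steps times, whatever the loader length."""
    batches = islice(infinite_batches(loader), max_steps)
    for step, (target, input_image) in enumerate(batches):
        train_step(input_image, target, step)


# Stand-in for a loader with only 3 batches per epoch (assumption for the demo).
fake_loader = [("t0", "x0"), ("t1", "x1"), ("t2", "x2")]
seen = []
run_steps(fake_loader, 7, lambda input_image, target, step: seen.append((step, input_image)))
print(len(seen))  # 7 steps were run even though one epoch has 3 batches
```

    To confirm that this is what is happening in your case, print len(train_loader) before training: for a map-style dataset it gives the number of batches in one epoch, and the original loop cannot run for more steps than that.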