Tags: deep-learning, pytorch, normalization, normalize

How to normalize PyTorch model output to be in range [0,1]


Let's say I have a model called UNet:

    output = UNet(input)

The output is a batch of grayscale images with shape (batch_size, 1, 128, 128).

What I want to do is normalize each image to the range [0,1].

I did it like this:

    for i in range(batch_size):
        output[i, :, :, :] = output[i, :, :, :] / torch.amax(output, dim=(1, 2, 3))[i]

Now every image in the output is normalized, but when I train the model, PyTorch complains that it cannot compute the gradients through this procedure, and I understand why.
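To see why the loop breaks backpropagation, here is a minimal sketch that reproduces it. It assumes a random tensor scaled by 3.0 stands in for the UNet output; loop_inplace, vectorized and backward_ok are hypothetical helper names. The indexed assignment modifies out in place after torch.amax has already saved out for the backward pass, so backward() raises a RuntimeError, while the out-of-place version does not.

```python
import torch


def loop_inplace(out: torch.Tensor) -> torch.Tensor:
    # The question's approach: indexed assignment overwrites `out` in place,
    # after torch.amax has saved `out` to compute its gradient later.
    for i in range(out.shape[0]):
        out[i, :, :, :] = out[i, :, :, :] / torch.amax(out, dim=(1, 2, 3))[i]
    return out


def vectorized(out: torch.Tensor) -> torch.Tensor:
    # Out-of-place: returns a new tensor and leaves the saved values untouched.
    return out / out.amax(dim=(1, 2, 3), keepdim=True)


def backward_ok(normalize) -> bool:
    # Hypothetical stand-in for UNet(input): a non-leaf tensor that requires grad.
    torch.manual_seed(0)
    x = torch.rand(2, 1, 4, 4, requires_grad=True)
    out = x * 3.0
    try:
        normalize(out).sum().backward()
    except RuntimeError as err:
        print("backward failed:", str(err).splitlines()[0])
        return False
    return x.grad is not None


print(backward_ok(loop_inplace))  # False
print(backward_ok(vectorized))    # True
```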

My question is: what is the right way to normalize the images without breaking the backpropagation flow? Something like:

    output = UNet(input)
    output = output.normalize  # pseudocode: some differentiable per-image normalization
    output2 = some_model(output)
    loss = ...
    loss.backward()
    optimizer.step()

My only option right now is adding a sigmoid activation at the end of the UNet, but I don't think it's a good idea.
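For comparison, the sigmoid option would look roughly like the sketch below, where a hypothetical two-layer backbone stands in for the UNet. A sigmoid acts on each pixel independently: it keeps every value in (0, 1) and is differentiable, but unlike per-image max normalization it does not make each image's maximum equal to 1, and its gradient becomes very small for large-magnitude inputs.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet: any module mapping (B, 1, H, W) -> (B, 1, H, W).
backbone = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),
    nn.ReLU(),
    nn.Conv2d(8, 1, kernel_size=3, padding=1),
)

# Appending a sigmoid squashes every pixel into (0, 1) element-wise.
model = nn.Sequential(backbone, nn.Sigmoid())

x = torch.randn(4, 1, 128, 128)
y = model(x)
y.mean().backward()  # gradients flow through the sigmoid into the backbone
print(y.shape)  # torch.Size([4, 1, 128, 128])
```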

Update, code (gen2 and disc are the UNet and discriminator models; est_bias is the output of gen2):


Update 2, code:

    with torch.no_grad():
        est_bias_for_disc = gen2(input_img)
        est_bias_for_disc /= est_bias_for_disc.amax(dim=(1,2,3), keepdim=True)
    disc_fake_hat = disc(est_bias_for_disc.detach())
    disc_fake_loss = BCE(disc_fake_hat, torch.zeros_like(disc_fake_hat))

    disc_real_hat = disc(bias_ref)
    disc_real_loss = BCE(disc_real_hat, torch.ones_like(disc_real_hat))
    disc_loss = (disc_fake_loss + disc_real_loss) / 2
    if epoch <= epochs_till_gen2_stop:
        disc_loss.backward(retain_graph=True)  # compute gradients
        opt_disc.step()  # update discriminator weights

Then there is a separate generator training step:

    opt_gen2.zero_grad()
    est_bias = gen2(input_img)
    est_bias /= est_bias.amax(dim=(1,2,3), keepdim=True)
    disc_fake = disc(est_bias)
    ADV_loss = BCE(disc_fake, torch.ones_like(disc_fake))
    gen2_loss = ADV_loss
    gen2_loss.backward()
    opt_gen2.step()
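A minimal sketch of that generator step with the in-place /= replaced by an out-of-place division. The tiny gen2 and disc modules, the 32x32 input size, and the use of nn.BCEWithLogitsLoss for BCE are assumptions for illustration only; the Softplus at the end of the hypothetical generator just keeps its output positive so that dividing by the maximum lands in (0, 1].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical tiny stand-ins for the question's gen2 (UNet) and disc models.
gen2 = nn.Sequential(nn.Conv2d(1, 1, kernel_size=3, padding=1), nn.Softplus())
disc = nn.Sequential(nn.Flatten(), nn.Linear(32 * 32, 1))  # outputs logits
BCE = nn.BCEWithLogitsLoss()  # assumption: the question's BCE takes logits
opt_gen2 = torch.optim.Adam(gen2.parameters(), lr=1e-3)

input_img = torch.randn(8, 1, 32, 32)

opt_gen2.zero_grad()
est_bias = gen2(input_img)
# Out-of-place division instead of `est_bias /= ...`: the tensors autograd saved
# during the forward pass (gen2's output, the input to amax) stay unmodified.
est_bias = est_bias / est_bias.amax(dim=(1, 2, 3), keepdim=True)
disc_fake = disc(est_bias)
gen2_loss = BCE(disc_fake, torch.ones_like(disc_fake))
gen2_loss.backward()
opt_gen2.step()
```

Because gen2_loss.backward() also accumulates gradients into disc's parameters, the discriminator step presumably needs opt_disc.zero_grad() before its own backward(); the question's discriminator snippet does not show that call.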

Solution

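  • Your loop overwrites the tensor's values in place through indexed assignment on the batch dimension (the /= in the updated code does the same). Autograd saved those values during the forward pass to compute gradients, so modifying them breaks the backward pass. Instead, perform the operation in vectorized, out-of-place form: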

    output = output / output.amax(dim=(1,2,3), keepdim=True)
    

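    The keepdim=True argument keeps the reduced dimensions as size 1, so torch.Tensor.amax returns shape (batch_size, 1, 1, 1), which broadcasts against output's shape (batch_size, 1, 128, 128). Assigning the result back to output creates a new tensor instead of modifying the old one in place, so backpropagation still works. Note that dividing by the maximum only yields values in [0,1] if the output is non-negative.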
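If the UNet output can take negative values, per-image min-max rescaling is a common alternative (a different technique from dividing by the maximum). A sketch, where minmax_normalize and its eps default are hypothetical choices and a random tensor stands in for UNet(input):

```python
import torch


def minmax_normalize(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Rescale each image in a (B, C, H, W) batch to [0, 1], out of place."""
    dims = tuple(range(1, x.dim()))  # reduce over everything except the batch dim
    x_min = x.amin(dim=dims, keepdim=True)  # shape (B, 1, 1, 1)
    x_max = x.amax(dim=dims, keepdim=True)  # shape (B, 1, 1, 1)
    # eps (hypothetical default) avoids division by zero for a constant image.
    return (x - x_min) / (x_max - x_min + eps)


output = torch.randn(4, 1, 128, 128, requires_grad=True)  # stand-in for UNet(input)
normalized = minmax_normalize(output)
normalized.mean().backward()  # gradients reach `output`
```

Both amin and amax are differentiable, so gradients flow through the shift and the scale as well as through x itself.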