Tags: python, tensorflow, typeerror, generative-adversarial-network

TypeError: Input 'y' of 'Sub' Op has type float32 that does not match type uint8 of argument 'x'


I'm working on a GAN with a generator and a discriminator. This is my training step:

@tf.function
def train_step(input_image, target, step):
  with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
    gen_output = generator(input_image, training=True)

    disc_real_output = discriminator([input_image, target], training=True)
    disc_generated_output = discriminator([input_image, gen_output], training=True)

    gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

  generator_gradients = gen_tape.gradient(gen_total_loss,
                                          generator.trainable_variables)
  discriminator_gradients = disc_tape.gradient(disc_loss,
                                               discriminator.trainable_variables)

  generator_optimizer.apply_gradients(zip(generator_gradients,
                                          generator.trainable_variables))
  discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                              discriminator.trainable_variables))

  with summary_writer.as_default():
    tf.summary.scalar('gen_total_loss', gen_total_loss, step=step//1000)
    tf.summary.scalar('gen_gan_loss', gen_gan_loss, step=step//1000)
    tf.summary.scalar('gen_l1_loss', gen_l1_loss, step=step//1000)
    tf.summary.scalar('disc_loss', disc_loss, step=step//1000)

This function throws an error:

TypeError: in user code:

    File "/tmp/ipykernel_34/3224399777.py", line 9, in train_step  *
        gen_total_loss, gen_gan_loss, gen_l1_loss = generator_loss(disc_generated_output, gen_output, target)
    File "/tmp/ipykernel_34/3072633757.py", line 5, in generator_loss  *
        l1_loss = tf.reduce_mean(tf.abs(target - gen_output))

    TypeError: Input 'y' of 'Sub' Op has type float32 that does not match type uint8 of argument 'x'.
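The error can be reproduced without the GAN. The sketch below is a minimal, assumed example (the function `l1_distance` and the all-zero tensors are hypothetical stand-ins, not the original code): inside a `tf.function`, subtracting a float32 tensor from a uint8 tensor fails while the graph is traced, because TensorFlow does not promote dtypes implicitly.

```python
import tensorflow as tf


@tf.function
def l1_distance(target, gen_output):
    # Same expression as in generator_loss. The Sub op needs both inputs to
    # share one dtype; TensorFlow will not promote uint8 to float32 for you.
    return tf.reduce_mean(tf.abs(target - gen_output))


target_uint8 = tf.zeros((1, 4, 4, 3), dtype=tf.uint8)  # like a raw decoded image
gen_output = tf.zeros((1, 4, 4, 3), dtype=tf.float32)  # like the generator output

try:
    l1_distance(target_uint8, gen_output)
    raised = False
except TypeError as err:
    # Graph-mode tracing raises TypeError, matching the traceback above.
    raised = True
    print(err)
```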

But when I subtract them manually, it works just fine; they are both float32:

target - gen_output
<tf.Tensor: shape=(1, 256, 256, 3), dtype=float32, numpy=
array([[[[185.98402 , 151.92749 ,  81.13361 ],
         [186.15788 , 151.78894 ,  80.930176],
         [185.86765 , 151.81358 ,  80.65687 ],
         ...,
         [183.64613 , 151.91382 ,  87.36469 ],
         [183.17218 , 152.08833 ,  86.43396 ],
         [183.51439 , 152.04149 ,  87.40147 ]],
...
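To check which dtype the `target` inside `train_step` actually has, one option is to inspect the dataset's `element_spec`. This is a sketch assuming the images come from a `tf.data.Dataset` of `(input_image, target)` pairs; `train_dataset` and the all-zero images are hypothetical stand-ins for the real input pipeline.

```python
import tensorflow as tf

# Hypothetical stand-in for the real input pipeline: images decoded with
# tf.io.decode_jpeg (or tf.io.decode_png with its default dtype) are uint8.
images = tf.zeros((2, 256, 256, 3), dtype=tf.uint8)
train_dataset = tf.data.Dataset.from_tensor_slices((images, images)).batch(1)

# element_spec describes every element the dataset yields, i.e. what a loop
# such as `for input_image, target in train_dataset` hands to train_step.
input_spec, target_spec = train_dataset.element_spec
print("input_image:", input_spec.dtype, "target:", target_spec.dtype)
```

Inside `train_step` itself, a plain Python `print(target.dtype)` reports the same information; because `train_step` is a `tf.function`, it prints once per trace rather than on every step.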

Solution

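  • Convert target to float32 before it reaches the loss. In the traceback, x is target (uint8) and y is gen_output (float32). TensorFlow does not promote dtypes implicitly, so the Sub op in target - gen_output requires both operands to have the same dtype. The manual subtraction presumably worked because the target tested in the notebook was already float32, while the one passed to train_step was still uint8. Note that tensors have no asarray method; use tf.cast:

    target = tf.cast(target, tf.float32)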
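A convenient place for the cast is the input pipeline, so every op in `train_step` and `generator_loss` sees float32. The sketch below assumes a `tf.data.Dataset` of uint8 `(input_image, target)` pairs; `to_float`, `train_dataset` and the constant images are hypothetical. It also assumes the generator ends in `tanh` (as in the TensorFlow pix2pix tutorial), which is why both images are rescaled from [0, 255] to [-1, 1]; the manual output above, with differences around 180, suggests the targets were still on the 0-255 scale.

```python
import tensorflow as tf


def to_float(input_image, target):
    # Cast once in the pipeline so the loss never mixes uint8 and float32.
    input_image = tf.cast(input_image, tf.float32)
    target = tf.cast(target, tf.float32)
    # Assumption: generator output lies in [-1, 1] (tanh), so rescale
    # the images from [0, 255] to the same range.
    input_image = input_image / 127.5 - 1.0
    target = target / 127.5 - 1.0
    return input_image, target


# Hypothetical uint8 dataset standing in for the real one (all pixels 255).
images = tf.fill((2, 8, 8, 3), tf.constant(255, dtype=tf.uint8))
train_dataset = (
    tf.data.Dataset.from_tensor_slices((images, images))
    .map(to_float)
    .batch(1)
)

for input_image, target in train_dataset.take(1):
    gen_output = tf.zeros_like(target)  # stand-in for generator(input_image)
    # The same expression that failed before now runs on two float32 tensors.
    l1_loss = tf.reduce_mean(tf.abs(target - gen_output))
    print(target.dtype, float(l1_loss))
```

If changing the pipeline is not an option, casting inside `generator_loss` with `target = tf.cast(target, gen_output.dtype)` also removes the mismatch, though the 0-255 versus [-1, 1] scale question remains.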