Feed image data without class label


I am trying to implement image super-resolution using SRGAN, with the DIV2K dataset (http://data.vision.ee.ethz.ch/cvl/DIV2K/DIV2K_train_HR.zip) as my source. I have previously worked on image classification with CNNs (using keras.layers.convolutional.Conv2D), but in this case my data source has no class labels.

I unzipped the file into D:\Unzipped\DIV2K_train_HR and then used the following command to read the files:

img_dataset = tensorflow.keras.utils.image_dataset_from_directory("D:\\unzipped")

Then I created the model as follows:

model = Sequential()
model.add(Conv2D(filters=64,kernel_size=(3,3),activation="relu",input_shape=(256,256,3)))
model.add(AveragePooling2D(pool_size=(2,2)))
model.add(Conv2D(filters=64,kernel_size=(3,3),activation="relu"))
model.add(MaxPooling2D(pool_size=(2,2)))

model.compile(optimizer='sgd', loss='mse')

model.fit(img_dataset,batch_size=32, epochs=10)

But I am getting a "Graph execution error" and I am unable to find its root cause. Does the error appear because the class labels are missing (I think that, with this code, DIV2K_train_HR is treated as a single class label)? Or is it because the images don't all have the same size?

Note: This code does not match the SRGAN architecture.


Solution

  • Yes, the error occurs because your dataset has no labels.
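For super-resolution, the "label" for each input is itself an image. A minimal sketch, assuming TensorFlow 2.x: passing `labels=None` to `image_dataset_from_directory` yields image batches only, and a `map` step can then pair each down-sampled image (input) with its original (target). The helper name `make_sr_dataset`, the 4x scale, and bicubic resizing are assumptions made for illustration and are not part of the question's code. Note that `image_dataset_from_directory` resizes every image to `image_size` (256x256 by default), so differing image sizes should not by themselves prevent batching.

```python
import tensorflow as tf


def make_sr_dataset(directory, hr_size=(256, 256), scale=4, batch_size=8):
    """Return a dataset of (low_res, high_res) image batches.

    Hypothetical helper: name, scale and resize method are illustrative choices.
    """
    # labels=None: yield images only, no class label is inferred from folders.
    hr_images = tf.keras.utils.image_dataset_from_directory(
        directory,
        labels=None,
        image_size=hr_size,   # every image is resized to this size
        batch_size=batch_size,
        shuffle=True,
    )

    lr_size = (hr_size[0] // scale, hr_size[1] // scale)

    def to_pair(hr):
        hr = hr / 255.0  # pixel values arrive as floats in [0, 255]
        lr = tf.image.resize(hr, lr_size, method="bicubic")
        return lr, hr    # (model input, training target)

    return hr_images.map(to_pair)
```

A dataset built this way can be passed to `model.fit(dataset, epochs=...)` for a model that maps the low-resolution shape to the high-resolution shape; `batch_size` is then set here rather than in `fit`.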

    As a first step in a GAN, you need to create a discriminator model: given an image, it should recognize whether the image is real or fake. Take images from your dataset and label them as 1 ("real images"). Then generate "fake images" by down-sampling and up-sampling images from your dataset, and label them as 0. Train the discriminator so that it can distinguish between original and processed images.
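The labelling step above can be sketched with NumPy alone. This is an illustration under stated assumptions: the function names `degrade` and `make_discriminator_batch` are hypothetical, the "fake" images are produced by simple striding followed by pixel repetition (a real pipeline would more likely use bicubic resizing), and random arrays stand in for DIV2K images.

```python
import numpy as np


def degrade(images, factor=4):
    """Down-sample by striding, then up-sample by repeating pixels.

    Assumes height and width are divisible by `factor`.
    """
    low = images[:, ::factor, ::factor, :]
    return low.repeat(factor, axis=1).repeat(factor, axis=2)


def make_discriminator_batch(real_images, factor=4, seed=0):
    """Stack real images (label 1) and degraded copies (label 0), shuffled."""
    fake_images = degrade(real_images, factor)
    x = np.concatenate([real_images, fake_images], axis=0)
    y = np.concatenate([np.ones(len(real_images)), np.zeros(len(fake_images))])
    order = np.random.default_rng(seed).permutation(len(x))
    return x[order], y[order]


# Random arrays stand in for a batch of real high-resolution images.
rng = np.random.default_rng(0)
real = rng.random((8, 32, 32, 3)).astype("float32")
x, y = make_discriminator_batch(real)
print(x.shape, y.shape, int(y.sum()))  # (16, 32, 32, 3) (16,) 8
```

The resulting `x` and `y` could then be used to train a binary classifier, for example with `discriminator.fit(x, y)`.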

    After that, create the generator model. The generator takes a down-sampled version of an image as input and produces an up-sampled version at the original resolution. The GAN model combines the two by passing the generator's output to the discriminator. The target label is 1, i.e. we want the generator to create up-sampled images that the discriminator can't distinguish from real ones. Now train the GAN (with 'trainable' set to False for the discriminator's weights, so that only the generator is updated).

    Once the generator produces images that the discriminator can't distinguish from real ones, take those images, label them as 0, and train the discriminator again. Then train the generator again, and so on.
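The generator and discriminator steps described above can be sketched as follows. This is a toy illustration, not the SRGAN architecture: the layer sizes, the 32x32 resolution, the Adam learning rates, and the random stand-in images are all assumptions. Instead of toggling `trainable`, the sketch uses `tf.GradientTape` and applies gradients only to the generator's weights in the generator step, which likewise leaves the discriminator unchanged during that step. It also alternates the two updates on every batch, a common way of implementing the alternation described above.

```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

HR, SCALE = 32, 4   # toy sizes so the sketch runs quickly
LR = HR // SCALE

# Generator: low-resolution image in, high-resolution image out.
generator = keras.Sequential([
    keras.Input(shape=(LR, LR, 3)),
    layers.Conv2D(16, 3, padding="same", activation="relu"),
    layers.UpSampling2D(size=SCALE),
    layers.Conv2D(3, 3, padding="same", activation="sigmoid"),
])

# Discriminator: high-resolution image in, probability of "real" out.
discriminator = keras.Sequential([
    keras.Input(shape=(HR, HR, 3)),
    layers.Conv2D(16, 3, strides=2, padding="same", activation="relu"),
    layers.GlobalAveragePooling2D(),
    layers.Dense(1, activation="sigmoid"),
])

bce = keras.losses.BinaryCrossentropy()
d_opt = keras.optimizers.Adam(1e-4)
g_opt = keras.optimizers.Adam(1e-4)


def train_step(hr_batch):
    """One discriminator update followed by one generator update."""
    lr_batch = tf.image.resize(hr_batch, (LR, LR))

    # 1) Discriminator step: real images -> 1, generated images -> 0.
    fake = generator(lr_batch, training=True)
    with tf.GradientTape() as tape:
        real_pred = discriminator(hr_batch, training=True)
        fake_pred = discriminator(fake, training=True)
        d_loss = (bce(tf.ones_like(real_pred), real_pred)
                  + bce(tf.zeros_like(fake_pred), fake_pred))
    grads = tape.gradient(d_loss, discriminator.trainable_weights)
    d_opt.apply_gradients(zip(grads, discriminator.trainable_weights))

    # 2) Generator step: target label 1; only generator weights are updated.
    with tf.GradientTape() as tape:
        fake_pred = discriminator(generator(lr_batch, training=True),
                                  training=True)
        g_loss = bce(tf.ones_like(fake_pred), fake_pred)
    grads = tape.gradient(g_loss, generator.trainable_weights)
    g_opt.apply_gradients(zip(grads, generator.trainable_weights))

    return float(d_loss), float(g_loss)


# Random tensors stand in for a batch of real images scaled to [0, 1].
hr_images = tf.random.uniform((8, HR, HR, 3))
for step in range(3):
    d_loss, g_loss = train_step(hr_images)
    print(f"step {step}: d_loss={d_loss:.3f} g_loss={g_loss:.3f}")
```

SRGAN itself adds a content (perceptual) loss to the adversarial loss for the generator; that term is left out here to keep the sketch short.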

    The process continues until the discriminator can no longer distinguish fake images from real ones (i.e. its accuracy doesn't exceed 0.5).
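As a small illustration of that stopping criterion, the discriminator's accuracy can be computed from its output scores on real and generated images. The function name `discriminator_accuracy`, the 0.5 decision threshold, and the example scores are assumptions made for this sketch.

```python
import numpy as np


def discriminator_accuracy(real_scores, fake_scores, threshold=0.5):
    """Fraction of images classified correctly by the discriminator.

    A real image counts as correct when its score is >= threshold,
    a generated image when its score is < threshold.
    """
    real_correct = np.asarray(real_scores) >= threshold
    fake_correct = np.asarray(fake_scores) < threshold
    return float(np.concatenate([real_correct, fake_correct]).mean())


# Early in training: the discriminator separates the two groups well.
early = discriminator_accuracy([0.9, 0.8, 0.7], [0.1, 0.2, 0.6])
print(early)  # 5 of 6 correct, about 0.83

# Near equilibrium: scores look the same for both groups.
late = discriminator_accuracy([0.6, 0.4], [0.6, 0.4])
print(late)   # 2 of 4 correct, i.e. 0.5
```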

    For a simple example, see the "Generative Adversarial Networks" section of this notebook: https://github.com/ageron/handson-ml3/blob/main/17_autoencoders_gans_and_diffusion_models.ipynb The code is explained in chapter 17 of the book "Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow" (3rd edition) by Aurélien Géron.