I am currently training a CNN on the ASL dataset https://www.kaggle.com/datamunge/sign-language-mnist.
To improve my accuracy I used the ImageDataGenerator from Keras. I want to print out the results of the data augmentation (each image before and after augmentation), but I don't understand how to plot the results from datagen. This is my code:
```python
datagen = keras.preprocessing.image.ImageDataGenerator(
    featurewise_center=False, samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False, rotation_range=10,
    zoom_range=0.1, width_shift_range=0.1,
    height_shift_range=0.1, horizontal_flip=False,
    vertical_flip=False)
datagen.fit(train_data)
result_data = datagen.flow(train_data, train_label, batch_size=128)
print(result_data)
```
train_data is a NumPy array of shape (20, 28, 28, 1) and train_label has shape (20, 1): there are 20 images of 28*28 pixels, plus a trailing channel dimension so the data can be used in a CNN.
I would like to plot the results with matplotlib, but I'm also happy with anything else (e.g. a NumPy array of the pixels).
It would also be awesome if someone could tell me how to print the amount of data the datagen generated.
Thank you in advance for your help.
First, you can create a default ImageDataGenerator (no augmentation) to plot the original images easily:
```python
datagenOrj = keras.preprocessing.image.ImageDataGenerator()
```
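You can flow a small, fixed sample, such as the first five images, into both generators. Note that flow() shuffles the samples by default (shuffle=True), so pass shuffle=False to keep the original and augmented images in the same order; the augmentation itself stays random. For a large dataset, comparing a small fixed sample like this is much easier than inspecting whole batches.
```python
result_data = datagen.flow(train_data[0:5], train_label[0:5], batch_size=128, shuffle=False)
result_data_orj = datagenOrj.flow(train_data[0:5], train_label[0:5], batch_size=128, shuffle=False)
```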
Calling next() on the iterator returns its next batch (the first batch on the first call) as a tuple containing the train data and the train labels, which you can access by index:
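```python
import matplotlib.pyplot as plt

def getSamplesFromDataGen(resultData):
    x = next(resultData)  # fetch the next (here: the first) batch
    a = x[0]  # train data, shape (batch, 28, 28, 1)
    b = x[1]  # train labels
    for i in range(len(a)):  # the 5-image sample fits in a single batch
        plt.imshow(a[i].squeeze(), cmap="gray")  # drop the channel axis for imshow
        plt.title(str(b[i]))
        plt.show()

getSamplesFromDataGen(result_data_orj)  # original images
getSamplesFromDataGen(result_data)  # augmented images
```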
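To see the before/after pairs next to each other instead of in separate figures, both batches can be drawn on one grid. This is a sketch, assuming `original` and `augmented` are NumPy arrays of shape (n, 28, 28, 1) holding the same samples in the same order (which needs `shuffle=False` on both `flow()` calls); `plot_before_after` is a hypothetical helper name:

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_before_after(original, augmented, n=5):
    """Draw original images in the top row and augmented ones in the bottom row."""
    n = min(n, len(original), len(augmented))
    # squeeze=False keeps `axes` 2-D even when n == 1
    fig, axes = plt.subplots(2, n, figsize=(2 * n, 4), squeeze=False)
    for i in range(n):
        for row, batch, name in ((0, original, "before"), (1, augmented, "after")):
            ax = axes[row, i]
            ax.imshow(np.squeeze(batch[i]), cmap="gray")  # (28, 28, 1) -> (28, 28)
            ax.set_title(f"{name} {i}")
            ax.axis("off")
    fig.tight_layout()
    return fig
```

For example, fetch `x_orj, _ = next(result_data_orj)` and `x_aug, _ = next(result_data)`, then call `plot_before_after(x_orj, x_aug)` followed by `plt.show()`.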
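Be careful when plotting: you may need to rescale your data. ImageDataGenerator returns float32 batches, and for RGB images matplotlib's imshow expects float values between 0 and 1 (integer values would be between 0 and 255). A single-channel image drawn with a colormap, like the ones here, is normalized automatically. To get values between 0 and 1 you can use the rescale argument (apply the same rescale to datagen so both generators are comparable):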
```python
datagenOrj = keras.preprocessing.image.ImageDataGenerator(rescale=1.0/255.0)
```
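If you have already pulled batches out of a generator that had no rescale, the same scaling can be applied afterwards in NumPy. A minimal sketch, assuming raw pixel values between 0 and 255; `to_display_range` is a hypothetical helper name:

```python
import numpy as np


def to_display_range(batch):
    """Map raw 0-255 pixel values to float32 values in [0, 1] for imshow."""
    scaled = np.asarray(batch, dtype=np.float32) / 255.0  # same factor as rescale=1.0/255.0
    # clip anything outside [0, 1], e.g. negative values after featurewise centering
    return np.clip(scaled, 0.0, 1.0)
```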
I tried this on my own dataset and it works.
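On the question of how much data the generator produces: flow() does not build a fixed, larger dataset. It applies new random transformations every time a batch is requested and cycles over the input indefinitely, so one pass (an epoch) yields as many images as you passed in, split into batches. The iterator reports both numbers: `print(result_data.n, len(result_data))` prints the sample count and the number of batches per epoch. The sketch below mirrors that arithmetic and shows how to stack a fixed number of batches into NumPy arrays of augmented pixels. Both helper names are hypothetical, and the iterator is assumed to yield (images, labels) tuples, as flow() does when labels are passed.

```python
import math

import numpy as np


def augmentation_counts(n_samples, batch_size, epochs=1):
    """Return (batches per epoch, augmented images produced over all epochs)."""
    batches_per_epoch = math.ceil(n_samples / batch_size)  # last batch may be partial
    images_produced = n_samples * epochs  # each epoch transforms every sample once
    return batches_per_epoch, images_produced


def collect_batches(iterator, n_batches):
    """Pull n_batches (images, labels) tuples from an iterator and stack them."""
    images, labels = [], []
    for _ in range(n_batches):
        x, y = next(iterator)
        images.append(x)
        labels.append(y)
    return np.concatenate(images), np.concatenate(labels)


print(augmentation_counts(20, 128))  # (1, 20): all 20 images fit in one batch
```

With the iterator from the question, `collect_batches(result_data, len(result_data))` should return one epoch of augmented images as an array of shape (20, 28, 28, 1), together with the matching labels.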
Best