
Image augmentation with TensorFlow so that all classes have exactly the same number of images


I want to do multi-class image classification of animals. The problem is that my dataset has a different number of images for each class, and the differences are quite large. For example:

In this example, the dataset contains 320 images in 3 classes: class A has 125 images, class B has 170 images, and class C has only 25 images. I want to augment these classes so that each one has 200 images, which gives 600 images distributed uniformly across the 3 classes.
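The arithmetic behind this example is simply `target - count` for each class. A minimal sketch using the class names and counts above (the function name `images_needed` is hypothetical):

```python
def images_needed(counts, target):
    """Return how many augmented images each class needs to reach `target`.

    Classes that already have `target` or more images get 0 here; they would
    have to be downsampled instead of augmented.
    """
    return {cls: max(target - n, 0) for cls, n in counts.items()}


counts = {"A": 125, "B": 170, "C": 25}
needed = images_needed(counts, target=200)
print(needed)                                        # {'A': 75, 'B': 30, 'C': 175}
print(sum(counts.values()))                          # 320 images before augmentation
print(sum(counts.values()) + sum(needed.values()))   # 600 images after augmentation
```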

However, my dataset has 60 classes. How can I augment all of them so that every class ends up with exactly the same number of images?
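With 60 classes, the first step is to count the images in each class automatically rather than by hand. A minimal sketch, assuming the common layout of one subdirectory per class under a root training directory, and assuming the image extensions listed in `IMAGE_EXTS` (both the constant and the function name `count_images_per_class` are hypothetical):

```python
import os

# Assumption: which file extensions count as images; adjust for your dataset.
IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp"}


def count_images_per_class(root):
    """Return {class_name: number_of_images} for each subdirectory of `root`."""
    counts = {}
    for cls in sorted(os.listdir(root)):
        cls_dir = os.path.join(root, cls)
        if not os.path.isdir(cls_dir):
            continue  # ignore stray files sitting next to the class folders
        counts[cls] = sum(
            1
            for f in os.listdir(cls_dir)
            if os.path.splitext(f)[1].lower() in IMAGE_EXTS
        )
    return counts
```

One common choice of target is the size of the largest class, `max(counts.values())`, so that no class has to lose images; the difference between the target and each class count is then the number of augmented images that class needs.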


Solution

  • It takes a fair amount of code, but you can use Keras's ImageDataGenerator to produce augmented images and save them to a specified directory (see the ImageDataGenerator documentation, in particular the save_to_dir argument of its flow methods). Alternatively, you can use libraries such as cv2 (OpenCV) or PIL, which provide functions to transform images. Below is code you can use with cv2. It deletes surplus files from classes that have more than the target number and writes augmented copies into classes that have fewer. Look up the cv2 documentation to see how to specify the image transforms mentioned in the code comment. Code is below:

    import os
    import cv2

    file_number = 130  # set this to the number of files you want in every class
    sdir = r'C:\Temp\dummydogs\train'  # main directory that contains your class directories
    label = 'aug'  # string that will be part of the augmented images' file names

    for klass in os.listdir(sdir):
        class_path = os.path.join(sdir, klass)
        if not os.path.isdir(class_path):
            continue  # skip stray files that are not class directories
        filelist = os.listdir(class_path)
        file_count = len(filelist)
        if file_count > file_number:
            # too many files: delete the surplus from this class directory
            delta = file_count - file_number
            for i in range(delta):
                os.remove(os.path.join(class_path, filelist[i]))
        elif 0 < file_count < file_number:
            # too few files: add augmented copies made with cv2 image transforms
            delta = file_number - file_count
            for i in range(delta):
                # cycle through the originals when more copies are needed than exist
                file = filelist[i % file_count]
                fname, ext = os.path.splitext(file)
                fnew_name = fname + '-' + str(i) + '-' + label + ext
                # cv2.imread returns BGR and cv2.imwrite expects BGR, so no color conversion
                img = cv2.imread(os.path.join(class_path, file))
                if img is None:
                    continue  # not a readable image file
                # look up the cv2 documentation and apply image transformation code here
                cv2.imwrite(os.path.join(class_path, fnew_name), img)
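cv2.imread returns a NumPy array, so the transformation placeholder can be filled with plain NumPy operations. The sketch below is one possible choice, not the answer's own method: the helper name `random_augment`, the choice of transforms (horizontal flip, brightness scaling, translation with black padding) and the default ranges are all assumptions to tune for your data. Random parameters matter because, when a class needs more new images than it has originals, each original is augmented several times and fixed transforms would produce identical copies.

```python
import numpy as np


def random_augment(img, rng, max_shift=0.1, brightness=(0.8, 1.2)):
    """Return a randomly transformed copy of an H x W (x C) uint8 image.

    Hypothetical helper. Applies, in order: a horizontal flip with probability
    0.5, a brightness scale drawn from `brightness`, and a translation of up to
    `max_shift` (a fraction below 1) of the image size, padding with black.
    The input array is not modified.
    """
    # Random horizontal flip (reverses the width axis).
    out = img[:, ::-1] if rng.random() < 0.5 else img

    # Scale brightness in float, then clip back to the valid uint8 range.
    factor = rng.uniform(*brightness)
    out = np.clip(out.astype(np.float32) * factor, 0, 255).astype(np.uint8)

    # Random translation by (dx, dy) pixels; uncovered pixels stay black.
    h, w = out.shape[:2]
    dx = int(rng.integers(-int(max_shift * w), int(max_shift * w) + 1))
    dy = int(rng.integers(-int(max_shift * h), int(max_shift * h) + 1))
    shifted = np.zeros_like(out)
    shifted[max(dy, 0):h + min(dy, 0), max(dx, 0):w + min(dx, 0)] = \
        out[max(-dy, 0):h + min(-dy, 0), max(-dx, 0):w + min(-dx, 0)]
    return shifted
```

To use it in the loop, create `rng = np.random.default_rng()` once before the loop and replace the placeholder comment with `img = random_augment(img, rng)`. Flips, brightness changes and shifts do not depend on channel order, so the BGR arrays from cv2.imread can be passed in directly.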