
Split a dataset of images into train and test sets for a CNN


I'm training a CNN on Kaggle, and my data consists of two things: one CSV file of labels and one folder of images. How can I split the data on Kaggle into training and test sets? Thanks.


Here is one example image:

[Image: one example image from the dataset]

and the associated label (from the CSV):

[Image: the row of the CSV file holding this image's label]


Solution

  • The function below creates train, test, and validation generators. Its parameters are:

    source_dir - full path to the directory containing all the images; source_dir/filename must be the path to each image file
    cvs_path - path to the CSV file, which has a column (x_col) containing each image's filename as a string and a column (y_col) containing the name of that image's class as a string
    max_batch_size - the largest batch size you allow, based on memory constraints
    train_split - float between 0 and 1 specifying the fraction of images used for training
    test_split - float between 0 and 1 specifying the fraction of images used for testing; the validation fraction is calculated internally as 1 - train_split - test_split
    x_col, y_col - names of the filename and class columns in the CSV
    class_mode - see the Keras flow_from_dataframe documentation for details; typically 'categorical'
    target_size - tuple (height, width) that input images are resized to
    scale - float; each pixel value is multiplied by scale (typically 1/255)

    For the test and validation sets, the function automatically chooses the batch size (the largest divisor of the set size that does not exceed max_batch_size) and the matching number of steps to use in model.fit and model.evaluate, so each test or validation image is used exactly once per pass. The training generator simply uses max_batch_size, and its step count is the number of training images // max_batch_size, so up to max_batch_size - 1 training images are not covered by the steps of an epoch.

    import pandas as pd
    from tensorflow.keras.preprocessing.image import ImageDataGenerator

    def get_bs(length, b_max):
        # Largest batch size <= b_max that divides length evenly, so that
        # steps * batch_size == length and every image is used exactly once
        batch_size = max(length // n for n in range(1, length + 1)
                         if length % n == 0 and length // n <= b_max)
        steps = length // batch_size
        return batch_size, steps

    def train_test_valid_split(source_dir, cvs_path, max_batch_size, train_split, test_split,
                               x_col, y_col, class_mode, target_size, scale):
        data = pd.read_csv(cvs_path)
        # One shared, sorted class list keeps label indices identical in all three generators
        classes = sorted(data[y_col].unique())
        # test_split is a fraction of the whole dataset; convert it to a fraction of what remains after training
        te_split = test_split / (1 - train_split)
        train_df = data.sample(frac=train_split, replace=False, random_state=123)
        tr_batch_size = max_batch_size
        # Floor division: up to tr_batch_size - 1 training images fall outside the steps of an epoch
        tr_steps = len(train_df) // tr_batch_size
        dummy_df = data.drop(train_df.index)  # rows not used for training
        test_df = dummy_df.sample(frac=te_split, replace=False, random_state=123)
        te_batch_size, te_steps = get_bs(len(test_df), max_batch_size)
        valid_df = dummy_df.drop(test_df.index)  # everything left over is validation data
        v_batch_size, v_steps = get_bs(len(valid_df), max_batch_size)
        gen = ImageDataGenerator(rescale=scale)
        # Training data is shuffled; test and validation data keep their order so predictions line up with rows
        train_gen = gen.flow_from_dataframe(dataframe=train_df, directory=source_dir, x_col=x_col, y_col=y_col,
                                            classes=classes, target_size=target_size, class_mode=class_mode,
                                            batch_size=tr_batch_size, seed=123, validate_filenames=False)
        test_gen = gen.flow_from_dataframe(dataframe=test_df, directory=source_dir, x_col=x_col, y_col=y_col,
                                           classes=classes, target_size=target_size, class_mode=class_mode,
                                           batch_size=te_batch_size, shuffle=False, validate_filenames=False)
        valid_gen = gen.flow_from_dataframe(dataframe=valid_df, directory=source_dir, x_col=x_col, y_col=y_col,
                                            classes=classes, target_size=target_size, class_mode=class_mode,
                                            batch_size=v_batch_size, shuffle=False, validate_filenames=False)
        return train_gen, tr_steps, test_gen, te_steps, valid_gen, v_steps
    
    

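    data.sample draws rows at random without looking at the class, so with many classes a rare class can end up under-represented (or missing) in the test or validation set. If that matters for your data, a stratified variant that samples the same fraction from every class is one option. The sketch below is not part of the answer's code: stratified_split is a hypothetical helper name, it assumes pandas 1.1 or newer (for GroupBy.sample) and a DataFrame with a unique index (which pd.read_csv gives you), and its three outputs could be passed to flow_from_dataframe in place of train_df, test_df and valid_df.

```python
import pandas as pd


def stratified_split(data, y_col, train_split, test_split, seed=123):
    """Hypothetical helper: train/test/validation split that samples each class separately."""
    # Fraction of the rows left after the training draw that goes to the test set
    te_split = test_split / (1 - train_split)
    # Take train_split of the rows of every class (GroupBy.sample needs pandas >= 1.1)
    train_df = data.groupby(y_col).sample(frac=train_split, random_state=seed)
    rest = data.drop(train_df.index)
    # From the remaining rows, take te_split of every class for testing
    test_df = rest.groupby(y_col).sample(frac=te_split, random_state=seed)
    # Whatever is left becomes the validation set
    valid_df = rest.drop(test_df.index)
    return train_df, test_df, valid_df


if __name__ == "__main__":
    df = pd.DataFrame({
        "file_id": [f"{i:05d}.jpg" for i in range(20)],
        "class_id": ["CRANE"] * 10 + ["HERON"] * 10,
    })
    train_df, test_df, valid_df = stratified_split(df, "class_id", 0.8, 0.1)
    print(train_df["class_id"].value_counts())
    print(len(train_df), len(test_df), len(valid_df))
```

    pandas rounds frac * (class size) to a whole number of rows per class, so very small classes may get zero test or validation rows; check the per-class counts before relying on the split.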
    The CSV file has this form:

        file_id     class_id
    0   00000.jpg   AFRICAN CROWNED CRANE
    1   00001.jpg   AFRICAN CROWNED CRANE
    2   00002.jpg   AFRICAN CROWNED CRANE
    3   00003.jpg   AFRICAN CROWNED CRANE
    4   00004.jpg   AFRICAN CROWNED CRANE
    5   00005.jpg   AFRICAN CROWNED CRANE
    6   00006.jpg   AFRICAN CROWNED CRANE
    7   00007..jpg  AFRICAN CROWNED CRANE
    8   00008..jpg  AFRICAN CROWNED CRANE
    
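    Because every generator is created with validate_filenames=False, a CSV row whose image file does not exist (for example because of a typo in a file name) is not reported when the generator is built; it only fails when that image is loaded during training or evaluation. A quick pre-check like the following sketch can catch such rows early. find_missing_files is a hypothetical helper, not a Keras function; it assumes, as the answer does, that source_dir/filename is the full path of each image.

```python
import os
import tempfile

import pandas as pd


def find_missing_files(data, source_dir, x_col):
    """Hypothetical helper: return the CSV rows whose image file is not in source_dir."""
    # Build source_dir/filename for every row, the same join flow_from_dataframe performs
    paths = data[x_col].map(lambda name: os.path.join(source_dir, name))
    exists = paths.map(os.path.isfile).astype(bool)
    return data[~exists]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        # One real (empty) file and one row that points at nothing
        open(os.path.join(d, "00000.jpg"), "wb").close()
        df = pd.DataFrame({"file_id": ["00000.jpg", "missing.jpg"],
                           "class_id": ["AFRICAN CROWNED CRANE"] * 2})
        print(find_missing_files(df, d, "file_id"))
```

    With the real data you would call it as find_missing_files(pd.read_csv(cvs_path), source_dir, x_col) and fix or drop any rows it returns.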

    Below is an example of its use:

    source_dir=r'c:\temp\birds\consolidated_images'
    cvs_path=r'c:\temp\birds\birds.csv'
    train_split=.8
    test_split=.1
    x_col='file_id'
    y_col='class_id'
    target_size=(224,224)
    scale=1/255  # ImageDataGenerator multiplies every pixel by this value
    max_batch_size=32
    class_mode='categorical'
    train_gen, train_steps, test_gen, test_steps, valid_gen, valid_steps=train_test_valid_split(source_dir,
                    cvs_path, max_batch_size, train_split, test_split, x_col, y_col, class_mode, target_size, scale)
    print ('train steps: ', train_steps, '  test steps: ', test_steps, '  valid steps: ', valid_steps)
    

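    ImageDataGenerator's rescale only multiplies each pixel by a constant, so no single scale value maps pixels from 0..255 to -1..1; an expression such as 1/127.5-1 evaluates to about -0.992 and would simply flip the sign of every pixel. If your model expects inputs in [-1, 1] (as some pretrained Keras models, e.g. MobileNetV2, do), one option is to leave rescale unset and pass a function through ImageDataGenerator's preprocessing_function argument instead. The sketch below shows only that function, checked with NumPy; scale_to_minus1_1 is a hypothetical name. Inside train_test_valid_split you would then build the generator as ImageDataGenerator(preprocessing_function=scale_to_minus1_1) rather than ImageDataGenerator(rescale=scale).

```python
import numpy as np


def scale_to_minus1_1(image):
    """Hypothetical preprocessing function: map pixel values 0..255 to -1..1."""
    # x / 127.5 maps 0..255 to 0..2; subtracting 1 shifts that range to -1..1
    return image.astype("float32") / 127.5 - 1.0


if __name__ == "__main__":
    pixels = np.array([[0.0, 127.5, 255.0]], dtype="float32")
    print(scale_to_minus1_1(pixels))
    # A single multiplicative factor like this cannot do the same job
    print(1 / 127.5 - 1)
```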
    Running the example call to train_test_valid_split prints:

    Found 30172 non-validated image filenames belonging to 250 classes.
    Found 3772 non-validated image filenames belonging to 250 classes.
    Found 3771 non-validated image filenames belonging to 250 classes.
    train steps:  942   test steps:  164   valid steps:  419
    

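    The step counts in this output follow directly from the batch-size rules. The sketch below re-implements the same largest-divisor search as get_bs (standalone, so it runs without TensorFlow) and reproduces the printed numbers from the three split sizes. largest_batch_size and step_counts are hypothetical names. It also shows the floor-divided training step count the answer uses next to a math.ceil alternative; Keras generators do yield a smaller final batch, so ceil-based steps would cover every training image, but that is a possible change, not what the answer's code does.

```python
import math


def largest_batch_size(length, b_max):
    """Same rule as get_bs: the largest divisor of length that is <= b_max."""
    return max(length // n for n in range(1, length + 1)
               if length % n == 0 and length // n <= b_max)


def step_counts(n_train, n_test, n_valid, b_max):
    """Hypothetical helper: batch sizes and steps for each split."""
    te_bs = largest_batch_size(n_test, b_max)
    v_bs = largest_batch_size(n_valid, b_max)
    return {
        "train_floor": n_train // b_max,           # what the answer's code uses
        "train_ceil": math.ceil(n_train / b_max),  # would include the partial last batch
        "test": (te_bs, n_test // te_bs),
        "valid": (v_bs, n_valid // v_bs),
    }


if __name__ == "__main__":
    print(step_counts(30172, 3772, 3771, 32))
```

    Note the consequence of the divisor rule: 3771 = 9 × 419, so the validation batch size drops to 9, and if a split size were a prime larger than max_batch_size the batch size would fall all the way to 1.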
    Now use these generators to train the model:

    epochs= 20 # set to what you want
    history=model.fit(x=train_gen, epochs=epochs,steps_per_epoch=train_steps,   
                validation_data=valid_gen, validation_steps=valid_steps,
                shuffle=False,  verbose=1)
    

    After training, evaluate the model on the test set:

    accuracy=model.evaluate(test_gen, steps=test_steps)[1]*100
    print ('Model accuracy on test set is', accuracy)
    

    or make predictions:

    predictions=model.predict(test_gen, steps=test_steps, verbose=1)
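    Because test_gen is created with shuffle=False, the rows of predictions come out in the same order as test_gen.filenames and test_gen.classes, and test_gen.class_indices maps each class name to its column in the output. The sketch below turns the probability matrix into class names and a percentage accuracy. It is written against plain NumPy arrays so it runs on its own; predictions_to_labels is a hypothetical helper name, and with the real generator you would call it as predictions_to_labels(predictions, test_gen.class_indices, test_gen.classes).

```python
import numpy as np


def predictions_to_labels(probs, class_indices, true_indices):
    """Hypothetical helper: convert softmax outputs to class names and accuracy (%)."""
    # class_indices maps name -> column; invert it to map column -> name
    index_to_name = {idx: name for name, idx in class_indices.items()}
    pred_idx = np.argmax(probs, axis=1)  # most likely column for each image
    pred_names = [index_to_name[int(i)] for i in pred_idx]
    accuracy = float(np.mean(pred_idx == np.asarray(true_indices))) * 100
    return pred_names, accuracy


if __name__ == "__main__":
    probs = np.array([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
    names, acc = predictions_to_labels(probs, {"CRANE": 0, "HERON": 1}, [0, 1, 1])
    print(names, acc)
```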