Tags: tensorflow, machine-learning, keras, multiclass-classification

Binary and multi-class classification code change


I am using almost the same code as the one I found here:

https://towardsdatascience.com/classify-butterfly-images-with-deep-learning-in-keras-b3101fe0f98

The example is a binary classification problem, but my data calls for multi-class classification. I suspect I need to change the activation and loss functions. Can I reuse the code found here if I have more than two classes?
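The activation and the loss change together, and a small plain-Python sketch shows why. It does not use Keras, and the helper names are illustrative rather than Keras APIs. A binary model ends in one sigmoid unit scored with binary cross-entropy. A multi-class model ends in one unit per class with softmax, scored with categorical cross-entropy. With two classes, softmax over the logits [0, z] gives exactly sigmoid(z), so the binary setup is a special case of the multi-class one.

```python
import math


def sigmoid(z: float) -> float:
    """Single output unit, as used for binary classification."""
    return 1.0 / (1.0 + math.exp(-z))


def softmax(logits: list[float]) -> list[float]:
    """One probability per class; the outputs sum to 1."""
    m = max(logits)  # subtract the max logit for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def binary_crossentropy(y: int, p: float) -> float:
    """Loss for one sigmoid output p, where the true label y is 0 or 1."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


def categorical_crossentropy(one_hot: list[int], probs: list[float]) -> float:
    """Loss for a softmax output; one_hot has a single 1 at the true class."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


# Two classes: softmax over [0, z] matches sigmoid(z), and the losses agree.
z = 1.5
print(softmax([0.0, z])[1], sigmoid(z))
print(categorical_crossentropy([0, 1], softmax([0.0, z])),
      binary_crossentropy(1, sigmoid(z)))

# Three classes, where the true class is index 2.
probs = softmax([0.2, 1.0, 2.5])
print(probs, categorical_crossentropy([0, 0, 1], probs))
```

Each softmax output is the probability of one class, so the final layer needs as many units as there are classes.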

https://github.com/bertcarremans/Vlindervinder/blob/master/model/CNN.ipynb


Update: I have one more question. Is data augmentation necessary if I have enough data?


Solution

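  • Change binary_crossentropy to categorical_crossentropy. The output layer must also produce one probability per class. If the model ends in a single sigmoid unit (the usual binary setup), replace it with Dense(NUM_CLASSES, activation='softmax'), where NUM_CLASSES is your number of classes:

    # categorical_crossentropy compares the softmax output with one-hot labels
    cnn.compile(loss='categorical_crossentropy',
                optimizer='rmsprop',
                metrics=['accuracy'])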
    

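    flow_from_directory takes the labels from the subdirectory names, so also set class_mode='categorical' in both generators. It then yields one-hot encoded labels, which is the format categorical_crossentropy expects. (class_mode='sparse' would instead yield integer labels, for use with sparse_categorical_crossentropy.)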

    # One subdirectory per class under data/train and data/validation
    train_generator = train_datagen.flow_from_directory(
        'data/train',
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical')  # yields one-hot labels
    
    validation_generator = validation_datagen.flow_from_directory(
        'data/validation',
        target_size=(IMG_SIZE, IMG_SIZE),
        batch_size=BATCH_SIZE,
        class_mode='categorical')
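The following sketch illustrates what class_mode='categorical' produces by mimicking the documented behavior of flow_from_directory. Each subdirectory becomes a class, and the classes are indexed in alphanumeric order. Each image then gets a one-hot label. This is a hypothetical stand-in written in plain Python, not Keras's own implementation. The butterfly folder names are made up for the example.

```python
import os
import tempfile


def infer_class_indices(directory: str) -> dict[str, int]:
    """Map each subdirectory name to an integer label, in sorted order."""
    classes = sorted(
        name for name in os.listdir(directory)
        if os.path.isdir(os.path.join(directory, name))
    )
    return {name: i for i, name in enumerate(classes)}


def one_hot(index: int, num_classes: int) -> list[float]:
    """Label vector for one image of class `index`."""
    vec = [0.0] * num_classes
    vec[index] = 1.0
    return vec


# Hypothetical layout: <root>/<class_name>/<images>
with tempfile.TemporaryDirectory() as root:
    for name in ["monarch", "admiral", "swallowtail"]:
        os.makedirs(os.path.join(root, name))
    indices = infer_class_indices(root)

print(indices)
print(one_hot(indices["monarch"], len(indices)))
```

In Keras itself, train_generator.class_indices holds the same name-to-index mapping. The number of units in the final softmax layer should equal len(train_generator.class_indices).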