neural-network categorical-data loss-function multiclass-classification

Poor performance help: multi-class classification by ANN


I'm implementing a 7-class classification task with normalised features and one-hot encoded labels. However, the training and validation accuracies have been extremely poor.

I normalised the features with StandardScaler(), so each feature vector is a 54-dimensional numpy array. I also one-hot encoded the labels, so the label arrays are numpy arrays of shape (num_y, 7).
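A minimal sketch of the preprocessing described above, written with NumPy only so it runs without scikit-learn. The helper names `standardise` and `one_hot` and the random toy data are hypothetical; `standardise` mirrors what `StandardScaler()` does by default (subtract each feature's mean, divide by its population standard deviation), and `one_hot` assumes integer labels numbered 0 to 6.

```python
import numpy as np

def standardise(X):
    """Zero-mean, unit-variance scaling per feature (column)."""
    mean = X.mean(axis=0)
    std = X.std(axis=0)      # population std (ddof=0), as StandardScaler uses
    std[std == 0] = 1.0      # leave constant features unscaled instead of dividing by 0
    return (X - mean) / std

def one_hot(labels, num_classes):
    """Turn integer labels 0..num_classes-1 into an (n, num_classes) one-hot array."""
    encoded = np.zeros((labels.shape[0], num_classes))
    encoded[np.arange(labels.shape[0]), labels] = 1.0
    return encoded

# Hypothetical toy data: 10 samples with 54 features and labels in 0..6.
rng = np.random.default_rng(0)
X = rng.normal(loc=5.0, scale=3.0, size=(10, 54))
y = rng.integers(0, 7, size=10)

X_scaled = standardise(X)
Y = one_hot(y, 7)
print(X_scaled.shape, Y.shape)  # (10, 54) (10, 7)
```

If the original labels are numbered 1 to 7, subtract 1 before encoding so that they index columns 0 to 6.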

I'm wondering whether the poor results of my network have something to do with the choice of loss function (I've been using categorical cross-entropy).
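For reference, categorical cross-entropy is the usual loss for a softmax output trained against one-hot labels. A small NumPy sketch (the function name, the clipping constant `eps` and the example numbers are illustrative assumptions) shows that with a one-hot target the loss reduces to minus the log of the probability assigned to the true class:

```python
import numpy as np

def categorical_cross_entropy(y_true, y_pred, eps=1e-7):
    """Batch mean of -sum_c y_true[c] * log(y_pred[c])."""
    y_pred = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-np.sum(y_true * np.log(y_pred), axis=1)))

# One sample whose true class is the first one.
y_true = np.array([[1, 0, 0, 0, 0, 0, 0]], dtype=float)
y_pred = np.array([[0.7, 0.0, 0.02, 0.02, 0.02, 0.04, 0.2]])

loss = categorical_cross_entropy(y_true, y_pred)
print(round(loss, 4))  # 0.3567, i.e. -log(0.7)
```

Only the true-class term contributes, because every other entry of the one-hot target is 0.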

I appreciate any response from you. Thanks a lot!


Solution

  • The way you compute accuracy is most likely wrong. The code that computes it is not included in your question, but I suspect you are comparing the true labels directly with your model outputs. Your model probably returns a 7-dimensional vector that is a probability distribution over the classes (because of the softmax activation in your final layer), like this:

    model returns: (0.7 0 0.02 0.02 0.02 0.04 0.2) -- they sum to 1 because they represent probabilities

    and then you are comparing these numbers with: (1 0 0 0 0 0 0)

    What you have to do is translate the model output into the corresponding predicted label: (0.7 0 0.02 0.02 0.02 0.04 0.2) corresponds to (1 0 0 0 0 0 0) because the first output neuron has the largest value (0.7). You can do that by taking the argmax of the model outputs (the index of the largest value) and comparing it with the argmax of the one-hot labels.

    To make sure that's what's wrong with your problem formulation, print the vectors you are comparing with the true labels when computing accuracy, and check whether each one consists of 7 numbers that sum to 1.
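A minimal NumPy sketch of this check and fix, assuming `probs` holds the model's softmax outputs with shape `(n, 7)` and `Y` holds the one-hot labels. Both arrays below are made-up toy values, not data from the question. The sketch runs the sanity check (7 numbers per row summing to 1), then contrasts a naive element-wise comparison with an argmax-based accuracy:

```python
import numpy as np

# Hypothetical softmax outputs for 3 samples (each row sums to 1).
probs = np.array([
    [0.7, 0.0, 0.02, 0.02, 0.02, 0.04, 0.2],
    [0.1, 0.6, 0.1, 0.05, 0.05, 0.05, 0.05],
    [0.3, 0.1, 0.1, 0.1, 0.1, 0.1, 0.2],
])
# Matching one-hot labels: the true classes are 0, 1 and 6.
Y = np.array([
    [1, 0, 0, 0, 0, 0, 0],
    [0, 1, 0, 0, 0, 0, 0],
    [0, 0, 0, 0, 0, 0, 1],
])

# Sanity check: each row should hold 7 probabilities that sum to 1.
print(probs.shape, probs.sum(axis=1))

# Wrong: raw probabilities almost never equal a one-hot vector exactly.
naive_acc = np.mean(np.all(probs == Y, axis=1))

# Right: compare predicted class indices with true class indices.
pred_classes = probs.argmax(axis=1)
true_classes = Y.argmax(axis=1)
accuracy = np.mean(pred_classes == true_classes)
print(naive_acc, accuracy)
```

Here the naive comparison scores 0 even though the model picks the right class for two of the three samples, and the argmax-based accuracy is 2/3.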