Classification with a neural network in MATLAB: get the probability of an element belonging to the i-th class


I would like to solve a classification problem with MATLAB. My dataset consists of 3 classes and 1900 samples. Each sample is described by 10 features, and there are 900 samples in class '1', 500 in class '2' and 500 in class '3'.

I tried training a neural network with MATLAB's standard patternnet tool. I ran several tests with the number of neurons ranging from 1 to 100, but the classification performance was always poor.

Looking at the confusion matrix, I noticed that the problem is that the classifier confuses classes '2' and '3'. Next, I tried creating two neural networks:

  1. The first is a two-class classifier separating class '1' from class '23' (the union of classes '2' and '3'). Its accuracy is good enough for me (around 90%).
  2. The second is again a two-class classifier, trained only on elements of classes '2' and '3'. Unfortunately, its accuracy is quite poor, around 55%.
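
If both networks produce outputs that can be read as probabilities (an assumption, e.g. softmax outputs), the two stages can be recombined into three per-class probabilities with the chain rule: the second network only ever sees classes '2' and '3', so its output estimates P(2 | 23). A minimal sketch with made-up values; p1 and p2given23 are hypothetical names for the two networks' outputs.

```matlab
% Made-up outputs of the two networks for 3 samples (column vectors)
p1        = [0.95; 0.20; 0.10];   % stage 1: P(class '1'); 1 - p1 = P(class '23')
p2given23 = [0.50; 0.85; 0.30];   % stage 2: P(class '2' | class '23')

% Chain rule: P(2) = P(23) * P(2 | 23) and P(3) = P(23) * P(3 | 23)
prob = [p1, (1 - p1) .* p2given23, (1 - p1) .* (1 - p2given23)];

% Each row now sums to 1; pick the most probable class per sample
[~, predicted] = max(prob, [], 2);
```

The resulting N-by-3 prob matrix can then be thresholded exactly like the output of a single three-class model.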

So I am still struggling to improve the classification accuracy, and I would like to run some tests to see whether it can be improved. My idea is to look at the probability of each element belonging to each class, and then do one of the following:

  1. Change the threshold that determines the class of a sample. This would work if, for example, all elements with a probability > 70% of being class '3' really are class '3', while elements with a probability between 50% and 70% are generally class '2' (I am making up numbers just to explain what I would like to test).
  2. Create a class '4' for samples that are too difficult to classify. Again, this would work if, for example, elements with a probability > 70% of being class '3' really are class '3', and I assign elements with a probability < 70% to class '4'. If this works, some elements would end up in the "unknown" class '4', but elements classified as '2' or '3' would be correct with a high degree of accuracy.
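
Both ideas only need an N-by-3 matrix of class probabilities. With patternnet, recent MATLAB releases use a softmax output layer by default (worth checking net.layers{end}.transferFcn), in which case net(x) returns one column of class scores per sample that sum to 1, so its transpose gives such a matrix. The sketch below assumes that matrix is already available; the probabilities are made up, the variable names are hypothetical, and the 0.7 threshold is the illustrative value from the question, which in practice should be tuned on a validation set.

```matlab
% Made-up probabilities: row i holds P(class 1..3) for sample i
prob = [0.10 0.80 0.10;    % confidently class 2
        0.05 0.40 0.55;    % class 3, but below the threshold
        0.02 0.18 0.80];   % confidently class 3
threshold = 0.7;           % illustrative 70% value from the question

% Standard decision: the class with the highest probability
[maxProb, predicted] = max(prob, [], 2);

% Idea 1 (shifted threshold): predict class 3 only above the threshold,
% otherwise fall back to class 2
predictedShifted = predicted;
predictedShifted(predicted == 3 & prob(:, 3) < threshold) = 2;

% Idea 2 (reject option): samples whose best probability is below the
% threshold go to an extra "unknown" class (class 4 with 3 classes)
rejectLabel = size(prob, 2) + 1;
predictedWithReject = predicted;
predictedWithReject(maxProb < threshold) = rejectLabel;
```

Comparing predictedWithReject against the true labels, ignoring the rejected samples, shows whether the remaining predictions become more accurate and how many samples end up in class '4'.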

So, first, I would like to know whether it is possible to retrieve the probability of each element belonging to a specific class and, second, whether there is a standard method in MATLAB to implement either of the two tests above. (Of course, if someone has a better idea, I am happy to try it.) Sorry for the long description; I hope I have at least explained my problem clearly.


Solution

  • @MeSS83: to give a proper example (with code and everything), I had to write a full answer. The easiest way to perform this multi-class classification with SVMs is LibSVM, a free SVM library (downloadable from the LibSVM website) that can also be installed and used within MATLAB. After unzipping the package, you'll find a matlab folder containing the installation guide and everything else you need.

    Basically, what you want is the One-vs-All SVM approach: you train N SVMs (where N is the number of classes), each trained to separate a given class i from all the others (the i-th class is positive and all other classes are negative). Suppose TrainingSet, TrainingLabels, ValidationSet and ValidationLabels are your dataset (the names are self-explanatory), with one sample per row as LibSVM expects (patternnet, by contrast, expects one sample per column), and numLabels is the number of labels (3 in your case).
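
    To make the recoding concrete, this small sketch builds all numLabels binary label vectors at once with base MATLAB; the toy labels are made up. Column k of OneVsAll is a double column vector with +1 for class k and -1 for every other class, which is the shape LibSVM's svmtrain expects for its label argument.

    ```matlab
    TrainingLabels = [1; 2; 3; 2; 1];   % toy labels (hypothetical data)
    numLabels = 3;

    % Compare every label with every class index 1..numLabels,
    % giving an N-by-numLabels logical matrix
    isClassK = bsxfun(@eq, TrainingLabels(:), 1:numLabels);

    % Map true/false to +1/-1: column k is the target of the k-th SVM
    OneVsAll = 2 * double(isClassK) - 1;
    ```

    Each row has exactly one +1, since every sample belongs to exactly one class.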

    You can train these SVMs as follows:

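    model = cell(numLabels, 1);   % one binary SVM per class
    for k = 1:numLabels
        % k-th class positive, all the other classes negative;
        % LibSVM expects the labels as a double column vector
        LabelsRecoded = 2 * double(TrainingLabels(:) == k) - 1;
        % LibSVM's svmtrain (not the Statistics Toolbox function of the same name)
        model{k} = svmtrain(LabelsRecoded, TrainingSet, '-c 1 -b 1 -t 0');
    end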
    

    In this code, '-c 1 -b 1 -t 0' are the LibSVM options for the SVM: -c 1 sets the cost parameter C (the penalty on training errors, which controls the amount of regularization) to 1, -b 1 asks LibSVM to also train the model to output probability estimates (these are not the same as the raw decision values), and -t 0 selects a linear kernel. More information can be found in the README inside the LibSVM package. Finally, model is a cell array whose k-th element contains the structure of the SVM trained to separate the k-th class from all the others.

    The prediction phase has the following structure:

    prob = zeros(size(ValidationSet, 1), numLabels);   % one column per class
    for k = 1:numLabels
        % same recoding as before, but with the validation labels
        LabelsRecoded = 2 * double(ValidationLabels(:) == k) - 1;

        % third output: probability estimates, columns ordered as model{k}.Label
        [~, ~, p] = svmpredict(LabelsRecoded, ValidationSet, model{k}, '-b 1');
        prob(:, k) = p(:, model{k}.Label == 1);
    end
    

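    prob now has numLabels columns (3 in your case): column k contains, for each validation sample, the probability estimated by the k-th SVM that the sample belongs to class k. The model{k}.Label==1 indexing is needed because LibSVM orders the columns of p according to model{k}.Label, so the positive class is not necessarily the first column. You can then obtain the predicted labels by taking, for each sample, the class with the highest probability: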

    [~,PredictedLabels] = max(prob,[],2);
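
    Because the numLabels SVMs are trained independently, the rows of prob are not guaranteed to sum to 1, which matters if you want to apply a single probability threshold as described in the question. One simple option, a heuristic rather than something LibSVM does for you, is to rescale each row to sum to 1; dividing a row by a positive constant does not change which column is largest, so PredictedLabels stays the same and only the probability values change. The values below are made up. (Alternatively, LibSVM's built-in multi-class mode, i.e. a single svmtrain call on the original labels 1..3 with -b 1, returns per-sample probability estimates that already sum to 1.)

    ```matlab
    % Made-up one-vs-all output: prob(i,k) is the k-th SVM's probability
    % that sample i belongs to class k (rows need not sum to 1)
    prob = [0.90 0.30 0.20;
            0.10 0.35 0.45];

    % Heuristic: rescale each row so that it sums to 1
    probNorm = bsxfun(@rdivide, prob, sum(prob, 2));

    % Same argmax rule, now on the normalized probabilities
    [maxProb, PredictedLabels] = max(probNorm, [], 2);
    ```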
    

    With both PredictedLabels and ValidationLabels available, you can evaluate the accuracy with the standard formula: the number of correctly classified samples divided by the total number of samples.
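
    A minimal sketch of that formula, plus a confusion matrix like the one used in the question to spot the '2'/'3' confusion. The labels are toy values, not results on the question's data; accumarray is used instead of a toolbox function such as confusionmat so that the sketch runs in base MATLAB.

    ```matlab
    % Toy ground truth and predictions (column vectors of class indices)
    ValidationLabels = [1; 1; 2; 2; 3; 3];
    PredictedLabels  = [1; 1; 2; 3; 3; 2];
    numLabels = 3;

    % Overall accuracy: fraction of correctly classified samples
    accuracy = mean(PredictedLabels == ValidationLabels);

    % Confusion matrix: C(i,j) counts samples of true class i predicted as j
    C = accumarray([ValidationLabels, PredictedLabels], 1, [numLabels numLabels]);

    % Per-class recall: correct predictions of class i over all samples of class i
    recall = diag(C) ./ sum(C, 2);
    ```

    The off-diagonal entries C(2,3) and C(3,2) directly show how often classes '2' and '3' are confused.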