Tags: r, classification, prediction, som

How can I use the SOM algorithm for classification prediction?


I would like to see if the SOM algorithm can be used for classification prediction. I used the code below, but the classification results are far from right: in the test dataset, I get many more distinct values than the 3 classes present in the training target variable. How can I build a prediction model whose output is aligned with the training target variable?

    library(kohonen)
    library(HDclassif)
    data(wine)
    set.seed(7)

    training <- sample(nrow(wine), 120)
    Xtraining <- scale(wine[training, ])
    Xtest <- scale(wine[-training, ],
                   center = attr(Xtraining, "scaled:center"),
                   scale = attr(Xtraining, "scaled:scale"))

    som.wine <- som(Xtraining, grid = somgrid(5, 5, "hexagonal"))

    som.prediction$pred <- predict(som.wine, newdata = Xtest,
                                   trainX = Xtraining,
                                   trainY = factor(Xtraining$class))

And the result:

    $unit.classif

     [1]  7  7  1  7  1 11  6  2  2  7  7 12 11 11 12  2  7  7  7  1  2  7  2 16 20 24 25 16 13 17 23 22
    [33] 24 18  8 22 17 16 22 18 22 22 18 23 22 18 18 13 10 14 15  4  4 14 14 15 15  4

Solution

  • This might help:

    • SOM is an unsupervised algorithm, so you shouldn't train it on a dataset that contains the class labels: if you do, the model will need that information to map new data, which makes it useless on unlabelled datasets
    • The idea is that it "converts" an input numeric vector into a network unit number (try running your code again with a 1-by-3 grid and you'll get the output you expected)
    • You'll then need to translate those network unit numbers back into the categories you are looking for; that is the key part missing in your code (see the quick check right after this list)
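
    For illustration, the following quick check (a sketch reusing the same data, seed, and a 1-by-3 grid, as in the full example further down) cross-tabulates the training classes against the units they map to; the majority class of each column is the label that unit should be translated back into:

    library(kohonen)
    library(HDclassif)
    data(wine)
    set.seed(7)
    training <- sample(nrow(wine), 120)
    som_model <- som(scale(wine[training, -1]),
                     grid = somgrid(1, 3, "hexagonal"))

    #Rows are the true classes, columns are SOM units; each unit would be
    #labelled with the class that dominates its column
    table(class = wine[training, 1], unit = som_model$unit.classif)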

    The reproducible example below outputs a classification accuracy. It includes one implementation option for the "convert back" step that is missing from your original post.

    Note, though, that for this particular dataset the model overfits pretty quickly: 3 units already give the best results.

    #Load the required packages
    library(kohonen)
    library(HDclassif)

    #Set and scale a training set (-1 drops the class column)
    data(wine)
    set.seed(7)
    training <- sample(nrow(wine), 120)
    Xtraining <- scale(wine[training, -1])
    
    #Scale a test set (-1 to drop the classes)
    Xtest <- scale(wine[-training, -1],
                   center = attr(Xtraining, "scaled:center"),
                   scale = attr(Xtraining, "scaled:scale"))
    
    #Set 2D grid resolution
    #WARNING: it overfits pretty quickly
    #Accuracies are 36% for 1 unit, 63% for 2, 93% for 3, 89% for 4
    som_grid <- somgrid(xdim = 1, ydim=3, topo="hexagonal")
    
    #Create a trained model
    som_model <- som(Xtraining, som_grid)
    
    #Make a prediction on test data
    som.prediction <- predict(som_model, newdata = Xtest)
    
    #Put together original classes and SOM classifications
    error.df <- data.frame(real = wine[-training, 1],
                           predicted = som.prediction$unit.classif)
    
    #For each unit, return the training-set category most strongly
    #associated with it (0 stands for ambiguous or empty units)
    switch <- sapply(unique(som_model$unit.classif), function(x, df){
      cat <- as.numeric(names(which.max(table(
        df[df$predicted==x, 1]))))
      if(length(cat)<1){
        cat <- 0
      }
      return(c(x, cat))
    }, df = data.frame(real = wine[training, 1], predicted = som_model$unit.classif))
    
    #Translate unit numbers into classes
    error.df$corrected <- apply(error.df, MARGIN = 1, function(x, switch){
      cat <- switch[2, which(switch[1,] == x["predicted"])]
      if(length(cat)<1){
        cat <- 0
      }
      return(cat)
    }, switch = switch)
    
    #Compute the classification accuracy (fraction of correct predictions)
    sum(error.df$corrected == error.df$real)/length(error.df$real)
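
    As a side note: if you want the class labels involved in training itself, the kohonen package also provides supervised SOMs. Below is a minimal sketch of that route (assuming kohonen >= 3.0, where supersom() accepts a factor layer and predict() can map new data through the measurements layer alone):

    library(kohonen)
    library(HDclassif)
    data(wine)
    set.seed(7)
    training <- sample(nrow(wine), 120)
    Xtraining <- scale(wine[training, -1])
    Xtest <- scale(wine[-training, -1],
                   center = attr(Xtraining, "scaled:center"),
                   scale = attr(Xtraining, "scaled:scale"))

    #Train on two layers: the measurements and the class labels
    som_sup <- supersom(list(measurements = Xtraining,
                             class = factor(wine[training, 1])),
                        grid = somgrid(5, 5, "hexagonal"))

    #Predict the class layer from the measurements layer only
    pred <- predict(som_sup, newdata = list(measurements = Xtest),
                    whatmap = "measurements")

    #Confusion matrix of true vs predicted classes
    table(real = wine[-training, 1], predicted = pred$predictions[["class"]])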