
How to list probability for each prediction in model output


I'm testing some models that predict the species on a modified Iris dataset, limited to SVM and Random Forest for now. I'm running this in RStudio.

Abbreviated Set-up:

    library(caret)  # method = "rf" also needs randomForest; method = "svmRadial" needs kernlab

    #data
    data(iris)

    #rename
    dataset <- iris

    #smaller sample (seeded so the subsample is reproducible)
    set.seed(1)
    sample_data <- dataset[sample(nrow(dataset), 60), ]

    #create some noise so the model is less than perfect:
    #16 rows with virginica-like petals labeled setosa,
    #6 rows with setosa-like petals labeled virginica
    noise_df <- data.frame(
          Sepal.Length = rep(c(5.7, 7.0, 5.0), times = c(8, 8, 6)),
          Sepal.Width  = rep(c(3.8, 2.7, 2.8, 3.1), times = c(8, 3, 5, 6)),
          Petal.Length = rep(c(5.2, 5.3, 5.4, 5.5, 5.6, 5.7, 5.8, 1.3), times = c(2, 2, 4, 3, 2, 2, 1, 6)),
          Petal.Width  = rep(c(1.8, 1.9, 2.0, 0.2), times = c(3, 3, 10, 6)),
          #same factor levels as iris$Species so rbind() keeps a clean factor
          Species      = factor(rep(c("setosa", "virginica"), times = c(16, 6)),
                                levels = levels(iris$Species))
          )

    #combine sample with noise
    dataset2 <- rbind(sample_data, noise_df)

    #split data into train/test
    set.seed(7)
    validation_index <- createDataPartition(dataset2$Species, p=0.70, list=FALSE)
    test_set <- dataset2[-validation_index,]
    train_set <- dataset2[validation_index,]

    #====================
    #build models
    #====================

    control <- trainControl(method="cv", number=10)
    metric <- "Accuracy"

    #random forest model
    set.seed(3)
    fit.rf <- train(Species~., data=train_set, method="rf", metric=metric, trControl=control)

    #svm model
    set.seed(3)
    fit.svm <- train(Species~., data=train_set, method="svmRadial", metric=metric, trControl=control)


    #====================
    #run model on test
    #====================
    predictions <- predict(fit.svm, test_set)
    confusionMatrix(predictions, test_set$Species)

Confusion Matrix output:

                Reference
    Prediction   setosa versicolor virginica
      setosa         11          0         3
      versicolor      0          3         0
      virginica       0          1         5

I'm wondering if it's possible to list the probability for each prediction. For example:

        setosa  versicolor  virginica   predicted
    1   0.9     0.0         0.1         setosa
    2   0.1     0.8         0.1         versicolor
    3   0.33    0.33        0.33        virginica

I would guess that Random Forest just outputs 0 or 1 per class, but I'm wondering whether SVM has an option to break out the probabilities like the example above. If so, I'm not sure how to shape my data or which functions to use. Is there an equivalent of a decision_function or predict_proba function? I'm not clear on how to do this correctly in R.


Solution

For random forest, the probability is the proportion of decision trees that vote for each label, and you can get it with predict(..., type = "prob"):

    data.frame(predict(fit.rf,type="prob", newdata=test_set),
               predicted=predict(fit.rf, newdata=test_set))
    
           setosa versicolor virginica  predicted
    147  0.016      0.002     0.982  virginica
    15   0.908      0.068     0.024     setosa
    103  0.486      0.000     0.514  virginica
    118  0.416      0.056     0.528  virginica
    129  0.344      0.000     0.656  virginica
    39   0.388      0.080     0.532  virginica
    

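To check the "proportion of trees" interpretation directly, here is a minimal sketch using the randomForest package, which caret's method = "rf" wraps. To stay self-contained it fits on the plain iris data rather than the question's train_set, and the three rows scored (new_obs) are an arbitrary choice. predict.all = TRUE returns each tree's individual vote, so the vote shares can be compared with what type = "prob" reports.

```r
library(randomForest)

data(iris)
set.seed(3)
# Fit directly with randomForest (the engine behind caret's method = "rf")
rf <- randomForest(Species ~ ., data = iris, ntree = 500)

# Three arbitrary observations to score (hypothetical choice for illustration)
new_obs <- iris[c(1, 51, 101), 1:4]
lev <- levels(iris$Species)

# Class probabilities as returned by predict(type = "prob")
p <- predict(rf, newdata = new_obs, type = "prob")

# predict.all = TRUE also returns each tree's predicted label in $individual
# (one row per observation, one column per tree)
votes <- predict(rf, newdata = new_obs, predict.all = TRUE)$individual

# Share of trees voting for each class, per observation
vote_share <- t(apply(votes, 1, function(v) {
  table(factor(v, levels = lev)) / length(v)
}))

# The two matrices should agree
all.equal(unname(p[, lev, drop = FALSE]), unname(vote_share))
```

Because each row is a share of 500 tree votes, these probabilities move in steps of 1/500, which is why the random forest output above shows values such as 0.486 rather than a smooth score.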
For the kernlab SVM (caret's method = "svmRadial"), you need to set prob.model = TRUE; train() passes it through to kernlab::ksvm():

    set.seed(3)
    fit.svm <- train(Species~., data=train_set, method="svmRadial", metric=metric, trControl=control, prob.model = TRUE)
    
    data.frame(predict(fit.svm,newdata=test_set,type="prob"),
               predicted=predict(fit.svm,newdata=test_set))
    
            setosa  versicolor  virginica  predicted
    1  0.129916071 0.051873046 0.81821088  virginica
    2  0.884025291 0.030853736 0.08512097     setosa
    3  0.129054108 0.006256384 0.86468951  virginica
    4  0.104952659 0.124066424 0.77098092  virginica
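Within caret itself, the usual way to request probabilities is classProbs = TRUE in trainControl(). My understanding (an assumption worth checking against caret's svmRadial model code) is that this is what turns on kernlab's prob.model when prob.model is not passed to train() explicitly. The sketch below uses the plain iris data so it runs on its own; object names mirror the question's. kernlab fits its probability model on an internal cross-validation of the training data, separately from the SVM decision function, so the label from predict() and the most probable column may occasionally disagree; the last two lines measure how often they match.

```r
library(caret)  # method = "svmRadial" also requires the kernlab package

data(iris)
set.seed(7)
idx <- createDataPartition(iris$Species, p = 0.70, list = FALSE)
train_set <- iris[idx, ]
test_set  <- iris[-idx, ]

# classProbs = TRUE asks caret to compute class probabilities
control <- trainControl(method = "cv", number = 10, classProbs = TRUE)

set.seed(3)
fit.svm <- train(Species ~ ., data = train_set, method = "svmRadial",
                 metric = "Accuracy", trControl = control)

# One column of probabilities per class, plus the predicted label
probs <- predict(fit.svm, newdata = test_set, type = "prob")
out <- data.frame(probs, predicted = predict(fit.svm, newdata = test_set))
head(out)

# Class with the highest probability, for comparison with 'predicted'
top_class <- colnames(probs)[max.col(as.matrix(probs), ties.method = "first")]
mean(top_class == as.character(out$predicted))
```

Since the probability model is fit on a random internal split, calling set.seed() before train() (as above) keeps the probabilities reproducible between runs.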