Tags: r, decision-tree, cross-validation, kaggle, precision-recall

How to get F1, Precision and Recall for a Cross-Validated Data Set in R


I have two data sets.

train <- read.csv("train.csv")
test  <- read.csv("test.csv")

The data in the train set looks like this.

> str(train)
'data.frame':   891 obs. of  12 variables:
 $ PassengerId: int  1 2 3 4 5 6 7 8 9 10 ...
 $ Survived   : Factor w/ 2 levels "0","1": 1 2 2 2 1 1 1 1 2 2 ...
 $ Pclass     : int  3 1 3 1 3 3 1 3 3 2 ...
 $ Name       : Factor w/ 891 levels "Abbing, Mr. Anthony",..: 109 191 358 277 16 559 520 629 417 581 ...
 $ Sex        : Factor w/ 2 levels "female","male": 2 1 1 1 2 2 2 2 1 1 ...
 $ Age        : num  22 38 26 35 35 NA 54 2 27 14 ...
 $ SibSp      : int  1 1 0 1 0 0 0 3 0 1 ...
 $ Parch      : int  0 0 0 0 0 0 0 1 2 0 ...
 $ Ticket     : Factor w/ 681 levels "110152","110413",..: 524 597 670 50 473 276 86 396 345 133 ...
 $ Fare       : num  7.25 71.28 7.92 53.1 8.05 ...
 $ Cabin      : Factor w/ 148 levels "","A10","A14",..: NA 83 NA 57 NA NA 131 NA NA NA ...
 $ Embarked   : Factor w/ 4 levels "","C","Q","S": 4 2 4 4 4 3 4 4 4 2 ...

The data in the test set looks like this.

> str(test)
'data.frame':   418 obs. of  11 variables:
 $ PassengerId: int  892 893 894 895 896 897 898 899 900 901 ...
 $ Pclass     : int  3 3 2 3 3 3 3 2 3 3 ...
 $ Name       : Factor w/ 418 levels "Abbott, Master. Eugene Joseph",..: 210 409 273 414 182 370 85 58 5 104 ...
 $ Sex        : Factor w/ 2 levels "female","male": 2 1 2 2 1 2 1 2 1 2 ...
 $ Age        : num  34.5 47 62 27 22 14 30 26 18 21 ...
 $ SibSp      : int  0 1 0 0 1 0 0 1 0 2 ...
 $ Parch      : int  0 0 0 0 1 0 0 1 0 0 ...
 $ Ticket     : Factor w/ 363 levels "110469","110489",..: 153 222 74 148 139 262 159 85 101 270 ...
 $ Fare       : num  7.83 7 9.69 8.66 12.29 ...
 $ Cabin      : Factor w/ 77 levels "","A11","A18",..: 1 1 1 1 1 1 1 1 1 1 ...
 $ Embarked   : Factor w/ 3 levels "C","Q","S": 2 3 2 3 3 3 2 3 1 3 ... 

I am using a decision tree as my classifier. I want to use 10-fold cross-validation to train and evaluate on the train set. For that I am using the caret package.

library(caret)
tc <- trainControl("cv",10)
rpart.grid <- expand.grid(.cp=0.2)

(train.rpart <- train(  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare 
                       + Embarked, 
                       data=train, 
                       method="rpart",
                       trControl=tc,
                       na.action = na.omit,
                       tuneGrid=rpart.grid))

From here, I am able to get a value for the accuracy of the cross validation.

712 samples
  7 predictor
  2 classes: '0', '1' 

No pre-processing
Resampling: Cross-Validated (10 fold) 
Summary of sample sizes: 641, 641, 640, 640, 641, 641, ... 
Resampling results:

  Accuracy   Kappa    
  0.7794601  0.5334528

Tuning parameter 'cp' was held constant at a value of 0.2 

My question: how can I get precision, recall and F1 for the 10-fold cross-validated data set in a similar manner?


Solution

  • If Survived is read as an integer (the default when read.csv sees a 0/1 column), rpart performs regression rather than classification, so recode it as a factor first.

    Evaluation metrics such as precision, recall, and F1 are reported by caret's confusionMatrix function when it is called with mode = "everything".

    library(caret)
    train <- read.csv("train.csv")
    test  <- read.csv("test.csv")
    tc <- trainControl("cv",10)
    rpart.grid <- expand.grid(.cp=0.2)
    
    # Convert variable interpreted as integer to factor
    train$Survived <- as.factor(train$Survived)
    
    (train.rpart <- train(  Survived ~ Pclass + Sex + Age + SibSp + Parch + Fare 
                        + Embarked, 
                        data=train, 
                        method="rpart",
                        trControl=tc,
                        na.action = na.omit,
                        tuneGrid=rpart.grid))
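    # Predict on the training data (these are not the held-out fold predictions)
    pred <- predict(train.rpart, train) 
    
    # Produce confusion matrix from prediction and data used for training.
    # mode = "everything" adds Precision, Recall and F1 to the printed statistics.
    # By default the first factor level ("0") is treated as the positive class;
    # pass positive = "1" to report the metrics for survivors instead.
    cf <- confusionMatrix(pred, train.rpart$trainingData$.outcome, mode = "everything")
    print(cf)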
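
The confusion matrix above is built from predictions on the same rows the model was fitted on, so it describes the training fit rather than the cross-validation folds. One way to get precision, recall and F1 from the held-out folds is to ask caret to keep the out-of-fold predictions with savePredictions = "final" and pass those to confusionMatrix. The sketch below is only an illustration under some assumptions: the data frame toy is synthetic and merely stands in for train.csv, cp = 0.01 is chosen so that the toy tree actually splits, and class "1" is taken as the positive class.

```r
library(caret)

set.seed(42)

# Synthetic stand-in for the Titanic training data (an assumption, not the real file)
n <- 300
toy <- data.frame(
  Pclass = sample(1:3, n, replace = TRUE),
  Sex    = factor(sample(c("female", "male"), n, replace = TRUE)),
  Age    = round(runif(n, 1, 70))
)
p_survive <- plogis(1.5 * (toy$Sex == "female") - 0.5 * toy$Pclass + 0.3)
toy$Survived <- factor(rbinom(n, 1, p_survive), levels = c(0, 1))

# Keep the held-out predictions of the final model for every fold
tc <- trainControl(method = "cv", number = 10, savePredictions = "final")

fit <- train(Survived ~ Pclass + Sex + Age,
             data      = toy,
             method    = "rpart",
             trControl = tc,
             tuneGrid  = expand.grid(cp = 0.01))

# fit$pred holds one row per held-out observation: predicted class (pred),
# observed class (obs), the fold label (Resample) and the original row index
cv_cf <- confusionMatrix(fit$pred$pred, fit$pred$obs,
                         mode = "everything", positive = "1")

# Precision, recall and F1 pooled over all ten held-out folds
cv_metrics <- cv_cf$byClass[c("Precision", "Recall", "F1")]
print(cv_metrics)
```

This pools the predictions from all folds into a single matrix. If per-fold values are preferred, the same call could be applied to each subset of fit$pred split by its Resample column and the results averaged.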